Pytorch识别LeNet模型怎样实现的
Admin 2022-09-06 群英技术资讯 478 次浏览
LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下
from PIL import Image import cv2 import matplotlib.pyplot as plt import torchvision from torchvision import transforms import torch from torch.utils.data import DataLoader import torch.nn as nn import numpy as np import tqdm as tqdm class LeNet(nn.Module): def __init__(self) -> None: super().__init__() self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(), nn.AvgPool2d(kernel_size=2,stride=2), nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(), nn.AvgPool2d(kernel_size=2,stride=2), nn.Flatten(), nn.Linear(16*25,120),nn.Sigmoid(), nn.Linear(120,84),nn.Sigmoid(), nn.Linear(84,10)) def forward(self,x): return self.sequential(x) class MLP(nn.Module): def __init__(self) -> None: super().__init__() self.sequential = nn.Sequential(nn.Flatten(), nn.Linear(28*28,120),nn.Sigmoid(), nn.Linear(120,84),nn.Sigmoid(), nn.Linear(84,10)) def forward(self,x): return self.sequential(x) epochs = 15 batch = 32 lr=0.9 loss = nn.CrossEntropyLoss() model = LeNet() optimizer = torch.optim.SGD(model.parameters(),lr) device = torch.device('cuda') root = r"./" trans_compose = transforms.Compose([transforms.ToTensor(), ]) train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True) test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True) train_loader = DataLoader(train_data,batch_size=batch,shuffle=True) test_loader = DataLoader(test_data,batch_size=batch,shuffle=False) model.to(device) loss.to(device) # model.apply(init_weights) for epoch in range(epochs): train_loss = 0 test_loss = 0 correct_train = 0 correct_test = 0 for index,(x,y) in enumerate(train_loader): x = x.to(device) y = y.to(device) predict = model(x) L = loss(predict,y) optimizer.zero_grad() L.backward() optimizer.step() train_loss = train_loss + L correct_train += (predict.argmax(dim=1)==y).sum() acc_train = correct_train/(batch*len(train_loader)) with torch.no_grad(): for index,(x,y) in enumerate(test_loader): [x,y] = [x.to(device),y.to(device)] predict = model(x) L1 = loss(predict,y) test_loss = test_loss + L1 correct_test += (predict.argmax(dim=1)==y).sum() acc_test = correct_test/(batch*len(test_loader)) print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')
epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229
找了一张图片,将其分割成只含一个数字的图片进行测试
images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE) h,w = images_np.shape images_np = np.array(255*torch.ones(h,w))-images_np#图片反色 images = Image.fromarray(images_np) plt.figure(1) plt.imshow(images) test_images = [] for i in range(10): for j in range(16): test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16]) sample = test_images[77] sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28)) predict = model(sample_tensor) output = predict.argmax() print(output) plt.figure(2) plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))
此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。
将所有用来泛化性测试的图片进行准确率测试:
correct = 0 i = 0 cnt = 1 for sample in test_images: sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28)) predict = model(sample_tensor) output = predict.argmax() if(output==i): correct+=1 if(cnt%16==0): i+=1 cnt+=1 acc_g = correct/len(test_images) print(f'acc_g:{acc_g}')
如果不反色,acc_g=0.15
acc_g:0.50625
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
集合(set)是一个无序的不重复元素序列。因此在每次运行的时候集合的运行结果的内容都是相同的,但元素的排列顺序却不是固定的,所以本章中部分案例的运行结果会出现与给出结果不同的情况(运行结果不唯一)可以使用大括号{}或者set()函数创建集合,注意:创建一个空集合必须用set()而不是{},因为{}是用来创建一个空字典
目录pack常用属性pack类提供了下列函数(使用组件实例对象调用)grid属性设置grid类提供了下列函数(使用组件实例对象调用)place属性设置place类提供了下列函数(使用组件实例对象调
这篇文章主要介绍了利用 Python 开发一个 Python 解释器,在本文中,我们将设计一个可以执行算术运算的解释器。下面我们大家一起来看看吧
这篇文章主要为大家介绍了Python上下文管理器,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助<BR>
这篇文章主要为大家介绍了Python变量的作用域,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008