Pytorch识别LeNet模型怎样实现的
Admin 2022-09-06 群英技术资讯 697 次浏览
LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下
from PIL import Image import cv2 import matplotlib.pyplot as plt import torchvision from torchvision import transforms import torch from torch.utils.data import DataLoader import torch.nn as nn import numpy as np import tqdm as tqdm class LeNet(nn.Module): def __init__(self) -> None: super().__init__() self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(), nn.AvgPool2d(kernel_size=2,stride=2), nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(), nn.AvgPool2d(kernel_size=2,stride=2), nn.Flatten(), nn.Linear(16*25,120),nn.Sigmoid(), nn.Linear(120,84),nn.Sigmoid(), nn.Linear(84,10)) def forward(self,x): return self.sequential(x) class MLP(nn.Module): def __init__(self) -> None: super().__init__() self.sequential = nn.Sequential(nn.Flatten(), nn.Linear(28*28,120),nn.Sigmoid(), nn.Linear(120,84),nn.Sigmoid(), nn.Linear(84,10)) def forward(self,x): return self.sequential(x) epochs = 15 batch = 32 lr=0.9 loss = nn.CrossEntropyLoss() model = LeNet() optimizer = torch.optim.SGD(model.parameters(),lr) device = torch.device('cuda') root = r"./" trans_compose = transforms.Compose([transforms.ToTensor(), ]) train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True) test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True) train_loader = DataLoader(train_data,batch_size=batch,shuffle=True) test_loader = DataLoader(test_data,batch_size=batch,shuffle=False) model.to(device) loss.to(device) # model.apply(init_weights) for epoch in range(epochs): train_loss = 0 test_loss = 0 correct_train = 0 correct_test = 0 for index,(x,y) in enumerate(train_loader): x = x.to(device) y = y.to(device) predict = model(x) L = loss(predict,y) optimizer.zero_grad() L.backward() optimizer.step() train_loss = train_loss + L correct_train += (predict.argmax(dim=1)==y).sum() acc_train = correct_train/(batch*len(train_loader)) with torch.no_grad(): for index,(x,y) in enumerate(test_loader): [x,y] = [x.to(device),y.to(device)] predict = model(x) L1 = loss(predict,y) test_loss = test_loss + L1 correct_test += (predict.argmax(dim=1)==y).sum() acc_test = correct_test/(batch*len(test_loader)) print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')
epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229
找了一张图片,将其分割成只含一个数字的图片进行测试
images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE) h,w = images_np.shape images_np = np.array(255*torch.ones(h,w))-images_np#图片反色 images = Image.fromarray(images_np) plt.figure(1) plt.imshow(images) test_images = [] for i in range(10): for j in range(16): test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16]) sample = test_images[77] sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28)) predict = model(sample_tensor) output = predict.argmax() print(output) plt.figure(2) plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))
此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。
模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。
将所有用来泛化性测试的图片进行准确率测试:
correct = 0 i = 0 cnt = 1 for sample in test_images: sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28)) predict = model(sample_tensor) output = predict.argmax() if(output==i): correct+=1 if(cnt%16==0): i+=1 cnt+=1 acc_g = correct/len(test_images) print(f'acc_g:{acc_g}')
如果不反色,acc_g=0.15
acc_g:0.50625
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
CSV(以逗号分隔的值)是用于存储表格数据的纯文本文件格式(如电子表格或数据库)。其主要存储的表格数据包括数字和纯文本。多数联机服务允许用户以CSV文件格式导出网站中的数据。通常在Excel中打开CSV文件,并且几乎所有数据库都有不同的特定工具来允许同一文件的导入。
这篇文章主要介绍了用Python写一个简易版弹球游戏,文中有很多实用代码,对正在学习python的小伙伴们有很大的帮助.需要的朋友可以参考下
今天学习了一下外键查询的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
这篇文章主要介绍了python计算机视觉OpenCV入门讲解,关于图像处理的相关简单操作,包括读入图像、显示图像及图像相关理论知识
Pygame是被设计用来写游戏的python模块集合,Pygame是在优秀的SDL库之上开发的功能性包。本文将利用Pygame制作简易的动画,感兴趣的可以学习一下
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008