PyTorch卷积神经网络的重要基础函数有哪些
Admin 2022-09-17 群英技术资讯 276 次浏览
nn.Conv2d在pytorch中用于实现卷积。
nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, )
1、in_channels为输入通道数。
2、out_channels为输出通道数。
3、kernel_size为卷积核大小。
4、stride为步数。
5、padding为padding情况。
6、dilation表示空洞卷积情况。
nn.MaxPool2d在pytorch中用于实现最大池化。
具体使用方式如下:
MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
1、kernel_size为池化核的大小
2、stride为步长
3、padding为填充情况
nn.ReLU()用来实现Relu函数,实现非线性。
x.view用于reshape特征层的形状。
这是一个简单的CNN模型,用于预测mnist手写体。
import os import numpy as np import torch import torch.nn as nn import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # 循环世代 EPOCH = 20 BATCH_SIZE = 50 # 下载mnist数据集 train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,) # (60000, 28, 28) print(train_data.train_data.size()) # (60000) print(train_data.train_labels.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 测试集 test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) # (2000, 1, 28, 28) # 标准化 test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255. test_y = test_data.test_labels[:2000] # 建立pytorch神经网络 class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() #----------------------------# # 第一部分卷积 #----------------------------# self.conv1 = nn.Sequential( nn.Conv2d( in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2, dilation=1 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) #----------------------------# # 第二部分卷积 #----------------------------# self.conv2 = nn.Sequential( nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) #----------------------------# # 全连接+池化+全连接 #----------------------------# self.ful1 = nn.Linear(64 * 7 * 7, 512) self.drop = nn.Dropout(0.5) self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax()) #----------------------------# # 前向传播 #----------------------------# def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) x = self.ful1(x) x = self.drop(x) output = self.ful2(x) return output cnn = CNN() # 指定优化器 optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3) # 指定loss函数 loss_func = nn.CrossEntropyLoss() for epoch in range(EPOCH): for step, (b_x, b_y) in enumerate(train_loader): #----------------------------# # 计算loss并修正权值 #----------------------------# output = cnn(b_x) loss = loss_func(output, b_y) optimizer.zero_grad() loss.backward() optimizer.step() #----------------------------# # 打印 #----------------------------# if step % 50 == 0: test_output = cnn(test_x) pred_y = torch.max(test_output, 1)[1].data.numpy() accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0)) print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
python中,while循环与for循环是经常使用的循环语句,一直到的到结果才会循环结束。但是,也会有一直循环,无法计算出结果的情况出现,这时我们就要跳出循环。本文介绍使用break跳出for循环的两种方法:1、提前定义一个变量,让其为空字符串;2、使用for…else…实现break跳出嵌套的for循环。
这篇文章主要介绍了Python开发网站目录扫描器的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
通常需要对前端传递过来的参数进行校验,校验的方式有多种,本文主要介绍了Python中rapidjson参数校验实现,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
PIL网上有很多介绍,这里不再讲解。直接操作,读取一张图片,将其转换为灰度图像,并打印出来。
目录pack常用属性pack类提供了下列函数(使用组件实例对象调用)grid属性设置grid类提供了下列函数(使用组件实例对象调用)place属性设置place类提供了下列函数(使用组件实例对象调
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008