使用pytorch的dataloader报错怎么办,如何解决
Admin 2022-07-30 群英技术资讯 487 次浏览
使用pytorch的dataloader报错:
RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1
报错定位:位于定义dataset的代码中
def __getitem__(self, index): ... return y #此处报错
报错内容
File "D:\python\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1
把前一行的报错带上能够更清楚地明白问题在哪里.
从报错可以看到,是在代码中执行torch.stack时发生了报错.因此必须要明白在哪里执行了stack操作.
通过调试可以发现,在通过loader加载一个batch数据的时候,是通过每一次给一个随机的index取出相应的向量.那么最终要形成一个batch的数据就必须要进行拼接操作,而torch.stack就是进行这里所说的拼接.
再来看看具体报的什么错: 说是stack的向量维度不同. 这说明在每次给出一个随机的index,返回的y向量的维度应该是相同的,而我们这里是不同的.
这样解决方法也就明确了:使返回的向量y的维度固定下来.
为什么我会出现这样的一个问题,是因为我的特征向量中存在multi-hot特征.而为了节省空间,我是用一个列表存储这个特征的.示例如下:
feature=[[1,3,5], [0,2], [1,2,5,8]]
这就导致了我每次返回的向量的维度是不同的.因此可以采用向量补全的方法,把不同长度的向量补全成等长的.
# 把所有向量的长度都补为6 multi = np.pad(multi, (0, 6-multi.shape[0]), 'constant', constant_values=(0, -1))
在构建dataset重写的__getitem__方法中要返回相同长度的tensor.
可以使用向量补全的方法来解决这个问题.
补充:pytorch学习笔记:torch.utils.data下的TensorDataset和DataLoader的使用
对给定的tensor数据(样本和标签),将它们包装成dataset。注意,如果是numpy的array,或者Pandas的DataFrame需要先转换成Tensor。
''' data_tensor (Tensor) - 样本数据 target_tensor (Tensor) - 样本目标(标签) ''' dataset=torch.utils.data.TensorDataset(data_tensor, target_tensor)
我们先定义一下样本数据和标签数据,一共有1000个样本
import torch import numpy as np num_inputs = 2 num_examples = 1000 true_w = [2, -3.4] true_b = 4.2 features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float) labels = true_w[0] * features[:, 0] + \ true_w[1] * features[:, 1] + true_b labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float) print(features.shape) print(labels.shape) ''' 输出:torch.Size([1000, 2]) torch.Size([1000]) '''
然后我们使用TensorDataset来生成数据集
import torch.utils.data as Data # 将训练数据的特征和标签组合 dataset = Data.TensorDataset(features, labels)
数据加载器,组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。它可以对我们上面所说的数据集Dataset作进一步的设置。
dataset (Dataset) �C 加载数据的数据集。
batch_size (int, optional) �C 每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) �C 设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) �C 定义从数据集中提取样本的策略。如果指定,则shuffle必须设置成False。
num_workers (int, optional) �C 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
pin_memory:内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
drop_last (bool, optional) �C 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
timeout:是用来设置数据读取的超时时间的,如果超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
data_iter=torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
上面对一些重要常用的参数做了说明,其中有一个参数是sampler,下面我们对它有哪些具体取值再做一下说明。只列出几个常用的取值:
torch.utils.data.sampler.SequentialSampler(dataset)
样本元素按顺序采样,始终以相同的顺序。
torch.utils.data.sampler.RandomSampler(dataset)
样本元素随机采样,没有替换。
torch.utils.data.sampler.SubsetRandomSampler(indices)
样本元素从指定的索引列表中随机抽取,没有替换。
下面就来看一个例子,该例子使用的dataset就是上面所生成的dataset
data_iter=Data.DataLoader(dataset, batch_size=10, shuffle=False, sampler=torch.utils.data.sampler.RandomSampler(dataset)) for X, y in data_iter: print(X,"\n", y) break ''' 输出: tensor([[-1.6338, 0.8451], [ 0.7245, -0.7387], [ 0.4672, 0.2623], [-1.9082, 0.0980], [-0.3881, 0.5138], [-0.6983, -0.4712], [ 0.1400, 0.7489], [-0.7761, -0.4596], [-2.2700, -0.2532], [-1.2641, -2.8089]]) tensor([-1.9451, 8.1587, 4.2374, 0.0519, 1.6843, 4.3970, 1.9311, 4.1999,0.5253, 11.2277]) '''
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了Python数据类型、进制转换、字符串格式化,Python2中区分整型int、长整型long,Python3中只有统称为整型int,本文给大家介绍的非常详细,需要的朋友参考下吧
这篇文章主要介绍了python-pandas创建Series数据类型的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
这篇文章主要介绍了python出现RuntimeError错误问题及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
python中如何求取字典最大值?方法一,通过sorted()函数排序所有的value;方法二,通过max函数取字典中value所对应的key值;
整个Spark 框架模块包含:Spark Core、 Spark SQL、 Spark Streaming、 Spark GraphX、 Spark MLlib,而后四项的能力都是建立在核心引擎之上。Spark Core:Spark的核心,Spark核心功能均由Spark Core模块提供...
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008