pytorch如何实现rnn变长输入序列,方法是什么
Admin 2022-07-23 群英技术资讯 408 次浏览
输入数据是长度不固定的序列数据,主要讲解两个部分
1、Data.DataLoader的collate_fn用法,以及按batch进行padding数据
2、pack_padded_sequence和pad_packed_sequence来处理变长序列
Dataloader的collate_fn参数,定义数据处理和合并成batch的方式。
由于pack_padded_sequence用到的tensor必须按照长度从大到小排过序的,所以在Collate_fn中,需要完成两件事,一是把当前batch的样本按照当前batch最大长度进行padding,二是将padding后的数据从大到小进行排序。
def pad_tensor(vec, pad): """ args: vec - tensor to pad pad - the size to pad to return: a new tensor padded to 'pad' """ return torch.cat([vec, torch.zeros(pad - len(vec), dtype=torch.float)], dim=0).data.numpy() class Collate: """ a variant of callate_fn that pads according to the longest sequence in a batch of sequences """ def __init__(self): pass def _collate(self, batch): """ args: batch - list of (tensor, label) reutrn: xs - a tensor of all examples in 'batch' before padding like: ''' [tensor([1,2,3,4]), tensor([1,2]), tensor([1,2,3,4,5])] ''' ys - a LongTensor of all labels in batch like: ''' [1,0,1] ''' """ xs = [torch.FloatTensor(v[0]) for v in batch] ys = torch.LongTensor([v[1] for v in batch]) # 获得每个样本的序列长度 seq_lengths = torch.LongTensor([v for v in map(len, xs)]) max_len = max([len(v) for v in xs]) # 每个样本都padding到当前batch的最大长度 xs = torch.FloatTensor([pad_tensor(v, max_len) for v in xs]) # 把xs和ys按照序列长度从大到小排序 seq_lengths, perm_idx = seq_lengths.sort(0, descending=True) xs = xs[perm_idx] ys = ys[perm_idx] return xs, seq_lengths, ys def __call__(self, batch): return self._collate(batch)
定义完collate类以后,在DataLoader中直接使用
train_data = Data.DataLoader(dataset=train_dataset, batch_size=32, num_workers=0, collate_fn=Collate())
pack_padded_sequence将一个填充过的变长序列压紧。输入参数包括
input(Variable)- 被填充过后的变长序列组成的batch data
lengths (list[int]) - 变长序列的原始序列长度
batch_first (bool,optional) - 如果是True,input的形状应该是(batch_size,seq_len,input_size)
返回值:一个PackedSequence对象,可以直接作为rnn,lstm,gru的传入数据。
用法:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence # x是填充过后的batch数据,seq_lengths是每个样本的序列长度 packed_input = pack_padded_sequence(x, seq_lengths, batch_first=True)
定义了一个单向的LSTM模型,因为处理的是变长序列,forward函数传入的值是一个PackedSequence对象,返回值也是一个PackedSequence对象
class Model(nn.Module): def __init__(self, in_size, hid_size, n_layer, drop=0.1, bi=False): super(Model, self).__init__() self.lstm = nn.LSTM(input_size=in_size, hidden_size=hid_size, num_layers=n_layer, batch_first=True, dropout=drop, bidirectional=bi) # 分类类别数目为2 self.fc = nn.Linear(in_features=hid_size, out_features=2) def forward(self, x): ''' :param x: 变长序列时,x是一个PackedSequence对象 :return: PackedSequence对象 ''' # lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size) lstm_out, _ = self.lstm(x) return lstm_out model = Model() lstm_out = model(packed_input)
这个操作和pack_padded_sequence()是相反的,把压紧的序列再填充回来。因为前面提到的LSTM模型传入和返回的都是PackedSequence对象,所以我们如果想要把返回的PackedSequence对象转换回Tensor,就需要用到pad_packed_sequence函数。
参数说明:
sequence (PackedSequence) �C 将要被填充的 batch
batch_first (bool, optional) �C 如果为True,返回的数据的形状为(batch_size,seq_len,input_size)
返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。
用法:
# 此处lstm_out是一个PackedSequence对象 output, _ = pad_packed_sequence(lstm_out)
返回的output是一个形状为(batch_size,seq_len,input_size)的tensor。
1、pytorch在自定义dataset时,可以在DataLoader的collate_fn参数中定义对数据的变换,操作以及合成batch的方式。
2、处理变长rnn问题时,通过pack_padded_sequence()将填充的batch数据转换成PackedSequence对象,直接传入rnn模型中。通过pad_packed_sequence()来将rnn模型输出的PackedSequence对象转换回相应的Tensor。
补充:pytorch实现不定长输入的RNN / LSTM / GRU
As we all know,RNN循环神经网络(及其改进模型LSTM、GRU)可以处理序列的顺序信息,如人类自然语言。但是在实际场景中,我们常常向模型输入一个批次(batch)的数据,这个批次中的每个序列往往不是等长的。
pytorch提供的模型(nn.RNN,nn.LSTM,nn.GRU)是支持可变长序列的处理的,但条件是传入的数据必须按序列长度排序。本文针对以下两种场景提出解决方法。
1、每个样本只有一个序列:(seq,label),其中seq是一个长度不定的序列。则使用pytorch训练时,我们将按列把一个批次的数据输入网络,seq这一列的形状就是(batch_size, seq_len),经过编码层(如word2vec)之后的形状是(batch_size, seq_len, emb_size)。
2、情况1的拓展:每个样本有两个(或多个)序列,如(seq1, seq2, label)。这种样本形式在问答系统、推荐系统多见。
定义ImprovedRnn类。与nn.RNN,nn.LSTM,nn.GRU相比,除了此两点【①forward函数多一个参数lengths表示每个seq的长度】【②初始化函数(__init__)第一个参数module必须指定三者之一】外,使用方法完全相同。
import torch from torch import nn class ImprovedRnn(nn.Module): def __init__(self, module, *args, **kwargs): assert module in (nn.RNN, nn.LSTM, nn.GRU) super().__init__() self.module = module(*args, **kwargs) def forward(self, input, lengths): # input shape(batch_size, seq_len, input_size) if not hasattr(self, '_flattened'): self.module.flatten_parameters() setattr(self, '_flattened', True) max_len = input.shape[1] # enforce_sorted=False则自动按lengths排序,并且返回值package.unsorted_indices可用于恢复原顺序 package = nn.utils.rnn.pack_padded_sequence(input, lengths.cpu(), batch_first=self.module.batch_first, enforce_sorted=False) result, hidden = self.module(package) # total_length参数一般不需要,因为lengths列表中一般含最大值。但分布式训练时是将一个batch切分了,故一定要有! result, lens = nn.utils.rnn.pad_packed_sequence(result, batch_first=self.module.batch_first, total_length=max_len) return result[package.unsorted_indices], hidden # output shape(batch_size, seq_len, rnn_hidden_size)
使用示例:
class TestNet(nn.Module): def __init__(self, word_emb, gru_in, gru_out): super().__init__() self.encode = nn.Embedding.from_pretrained(torch.Tensor(word_emb)) self.rnn = ImprovedRnn(nn.RNN, input_size=gru_in, hidden_size=gru_out, batch_first=True, bidirectional=True) def forward(self, seq1, seq1_lengths, seq2, seq2_lengths): seq1_emb = self.encode(seq1) seq2_emb = self.encode(seq2) rnn1, hn = self.rnn(seq1_emb, seq1_lengths) rnn2, hn = self.rnn(seq2_emb, seq2_lengths) """ 此处略去rnn1和rnn2的后续计算,当前网络最后计算结果记为prediction """ return prediction
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
这篇文章主要介绍了Python 类方法和静态方法之间的区别,静态方法并不是真正意义上的类方法,它只是一个被放到类里的函数而已,更多内容需要的朋友可以参考一下
这篇文章主要介绍python稀疏矩阵用法,下文有具体的介绍和示例,有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章能有所收获,接下来小编带着大家一起了解看看吧。
现在重新稍微系统的介绍一下ResNet网络结构。 ResNet结构首先通过一个卷积层然后有一个池化层,然后通过一系列的残差结构,最后再通过一个平均池化下采样操作,以及一个全连接层的得到了一个输出。ResNet网络可以达到很深的层数的原因就是不断的堆叠残差结构而来的。
Pandas是Python语言的一个扩展程序库,提供高性能、易于使用的数据结构和数据分析工具,下面这篇文章主要给大家介绍了关于如何使用pandas对超大csv文件进行快速拆分的相关资料,需要的朋友可以参考下
内存映射通常可以提高I/O的性能,因为使用内存映射时,不需要对每个访问都建立一个单独的系统调用,也不需要在缓冲区之间复制数据,内核和用户都能很方便的直接访问内存。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008