tensorflow与pytorch的转换是怎样的
Admin 2022-08-08 群英技术资讯 530 次浏览
由于本人只熟悉pytorch,而对tensorflow一知半解,而代码经常遇到tensorflow,而我希望使用pytorch,因此简单介绍一下tensorflow转pytorch,可能存在诸多错误,希望轻喷~
在TensorFlow的世界里,变量的定义和初始化是分开的。
tensorflow中一般都是在开头预定义变量,声明其数据类型、形状等,在执行的时候再赋具体的值,如下图所示,而pytorch用到时才会定义,定义和变量初始化是合在一起的。
tensorflow中利用tf.Variable创建变量并进行初始化,而pytorch中使用torch.tensor创建变量并进行初始化,如下图所示。
在TensorFlow的世界里,变量的定义和初始化是分开的,所有关于图变量的赋值和计算都要通过tf.Session的run来进行。
sess.run([G_solver, G_loss_temp, MSE_loss], feed_dict = {X: X_mb, M: M_mb, H: H_mb})
而在pytorch中,并不需要通过run进行,赋值完了直接计算即可。
pytorch运算时要创建完的numpy数组转为tensor,如下:
if use_gpu is True: X_mb = torch.tensor(X_mb, device="cuda") M_mb = torch.tensor(M_mb, device="cuda") H_mb = torch.tensor(H_mb, device="cuda") else: X_mb = torch.tensor(X_mb) M_mb = torch.tensor(M_mb) H_mb = torch.tensor(H_mb)
最后运行完还要将tensor数据类型转换回numpy数组:
if use_gpu is True: imputed_data=imputed_data.cpu().detach().numpy() else: imputed_data=imputed_data.detach().numpy()
而tensorflow中不需要这种操作。
在tensorflow中包含诸多函数是pytorch中没有的,但是都可以在其他库中找到类似,具体如下表所示。
tensorflow中函数 | pytorch中代替(所在库) | 参数区别 |
---|---|---|
tf.sqrt | np.sqrt(numpy) | 完全相同 |
tf.random_normal | np.random.normal(numpy) | tf.random_normal(shape = size, stddev = xavier_stddev) np.random.normal(size = size, scale = xavier_stddev) |
tf.concat | torch.cat(torch) | inputs = tf.concat(values = [x, m], axis = 1) inputs = torch.cat(dim=1, tensors=[x, m]) |
tf.nn.relu | F.relu(torch.nn.functional) | 完全相同 |
tf.nn.sigmoid | torch.sigmoid(torch) | 完全相同 |
tf.matmul | torch.matmul(torch) | 完全相同 |
tf.reduce_mean | torch.mean(torch) | 完全相同 |
tf.log | torch.log(torch) | 完全相同 |
tf.zeros | np.zeros | 完全相同 |
tf.train.AdamOptimizer | torch.optim.Adam(torch) | optimizer_D = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) optimizer_D = torch.optim.Adam(params=theta_D) |
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
今天看到了一个比较诡异的写法,for后直接跟了else语句,起初还以为是没有缩进好,查询后发现果然有这种语法,特此分享。之前写过c++和Java,在for后接else还是第一次见。
这篇文章主要介绍了python中的opencv 图像分割与提取,图像中将前景对象作为目标图像分割或者提取出来。对背景本身并无兴趣分水岭算法及GrabCut算法对图像进行分割及提取。具体实现过程需要的朋友可以参考下面文章详细介绍
制表符也属于“写法是两个字符的组合,但含义上只是一个字符”的情形。它的写法是“\t”,是反斜杠和t字母的组合,t取的是table之意。它的含义是一个字符,叫做制表符。它的作用是对齐表格数据的各列。
这篇文章主要介绍了python编程开发之类型转换convert用法,结合实例形式分析了Python中常见的数据类型及类型转换convert的具体使用方法,需要的朋友可以参考下
不少朋友应该都有玩过2048游戏吧,就是合并和消除数字的一款游戏。那么我们如果使用python,怎么写一个2048游戏呢?下面就给大家分享使用Python实现2048游戏代码,感兴趣的朋友可以参考学习。
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008