生成和读取tfrecord文件的操作是什么?
Admin 2021-05-22 群英技术资讯 653 次浏览
tfrecord是tensorflow中常用的数据打包格式,这篇文章给大家介绍的就是关于tfrecord文件的生成和读取,本文有具体以及步骤,具有的一定的参考价值,需要的朋友可以参考学习。
训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型。为进一步理解tfrecord文件,本例先将6幅图像及其标签转换为tfrecord文件,然后读取tfrecord文件,重现6幅图像及其标签。
1、生成tfrecord文件
import os import numpy as np import tensorflow as tf from PIL import Image filenames = [ 'images/cat/1.jpg', 'images/cat/2.jpg', 'images/dog/1.jpg', 'images/dog/2.jpg', 'images/pig/1.jpg', 'images/pig/2.jpg',] labels = {'cat':0, 'dog':1, 'pig':2} def int64_feature(values): if not isinstance(values, (tuple, list)): values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) with tf.Session() as sess: output_filename = os.path.join('images/train.tfrecords') with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: for filename in filenames: #读取图像 image_data = Image.open(filename) #图像灰度化 image_data = np.array(image_data.convert('L')) #将图像转化为bytes image_data = image_data.tobytes() #读取label label = labels[filename.split('/')[-2]] #生成protocol数据类型 example = tf.train.Example(features=tf.train.Features(feature={'image': bytes_feature(image_data), 'label': int64_feature(label)})) tfrecord_writer.write(example.SerializeToString())
2、读取tfrecord文件
import tensorflow as tf import matplotlib.pyplot as plt from PIL import Image # 根据文件名生成一个队列 filename_queue = tf.train.string_input_producer(['images/train.tfrecords']) reader = tf.TFRecordReader() # 返回文件名和文件 _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) # 获取图像数据 image = tf.decode_raw(features['image'], tf.uint8) # 恢复图像原始尺寸[高,宽] image = tf.reshape(image, [60, 160]) # 获取label label = tf.cast(features['label'], tf.int32) with tf.Session() as sess: # 创建一个协调器,管理线程 coord = tf.train.Coordinator() # 启动QueueRunner, 此时文件名队列已经进队 threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(6): image_b, label_b = sess.run([image, label]) img = Image.fromarray(image_b, 'L') plt.imshow(img) plt.axis('off') plt.show() print(label_b) # 通知其他线程关闭 coord.request_stop() # 其他所有线程关闭之后,这一函数才能返回 coord.join(threads)
以上就是关于怎样实现tfrecord文件生成与读取的操作介绍,希望文本对大家学习有帮助,想要了解更多tfrecord文件生成与读取的内容大家可以关注其他相关文章。
文本转载自脚本之家免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
python高并发是什么?某个时间段内,数据涌来,这就是并发。如果数据量很大,就是高并发。那么python高并发怎么解决呢?
本篇文章给大家带来了关于Python的相关知识,numpy里有一个非常神奇的函数叫做np.where()函数,下面这篇文章主要给大家介绍了关于Python np.where()的详解以及代码应用的相关资料,文中通过实例代码介绍的非常详细,希望对大家有帮助。
Python提供了许多操作Excel的模块,能够让我们从繁琐的工作中腾出双手。本文主要为大家介绍的是openpyxl模块,它的功能相对与其他模块更为齐全,感兴趣的小伙伴快来学习一下吧
这篇文章主要为大家介绍了python之异步编程,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助<BR>
本文主要介绍了Python实现批量压缩文件/文件夹zipfile的使用,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008