篇首语:本文由小编为大家整理,主要介绍了tensorflow中tfrecords使用介绍相关的知识,希望对你有一定的参考价值。
这篇文章主要讲一下如何用Tensorflow中的标准数据读取方式简单的实现对自己数据的读取操作.
主要分为以下两个步骤:(1)将自己的数据集转化为 xx.tfrecords的形式;(2):在自己的程序中读取并使用.tfrecords进行操作.
数据集转换:为了便于讲解,我们简单制作了一个数据,如下图所示:
程序:
[python] view plain copy- import tensorflow as tf
- import numpy as np
- import os
- from PIL import Image
- def _int64_feature(value):
- return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
- def _bytes_feature(value):
- return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
- def img_to_tfrecord(data_path):
- rows = 256
- cols = 256
- depth = 3
- writer = tf.python_io.TFRecordWriter("test.tfrecords")
- labelfile=open("random.txt")
- lines=labelfile.readlines()
- for line in lines:
- #print line
- img_name = line.split(" ")[0]#name
- label = line.split(" ")[1]#label
- img_path = data_path+img_name
- img = Image.open(img_path)
- img = img.resize((rows,cols))
- #img_raw = img.tostring()
- img_raw = img.tobytes()
- example = tf.train.Example(features = tf.train.Features(feature =
- "height": _int64_feature(rows),
- "weight": _int64_feature(cols),
- "depth": _int64_feature(depth),
- "image_raw": _bytes_feature(img_raw),
- "label": _bytes_feature(label)))
- writer.write(example.SerializeToString())
- writer.close()
- if __name__ == "__main__":
- current_dir = os.getcwd()
- data_path = current_dir + "/data/"
- #name = current_dir + "/data"
- print("Convert start")
- img_to_tfrecord(data_path)
- print("done!")
运行该段程序可以看到在dataset_tfrecord文件夹下面有test.tfrecord文件生成。 在TF的Session中调用这个生成的文件:
[python] view plain copy
- #encoding=utf-8
- # 设置utf-8编码,方便在程序中加入中文注释.
- import os
- import scipy.misc
- import tensorflow as tf
- import numpy as np
- from test import *
- import matplotlib.pyplot as plt
- def read_and_decode(filename_queue):
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- features = tf.parse_single_example(serialized_example,features =
- "image_raw":tf.FixedLenFeature([], tf.string))
- image = tf.decode_raw(features["image_raw"], tf.uint8)
- image = tf.reshape(image, [OUTPUT_SIZE, OUTPUT_SIZE, 3])
- image = tf.cast(image, tf.float32)
- #image = image / 255.0
- return image
- data_dir = "/home/sanyuan/dataset_animal/dataset_tfrecords/"
- filenames = [os.path.join(data_dir,"train%d.tfrecords" % ii) for ii in range(1)] #如果有多个文件,直接更改这里即可
- filename_queue = tf.train.string_input_producer(filenames)
- image = read_and_decode(filename_queue)
- with tf.Session() as sess:
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(coord=coord)
- for i in xrange(2):
- img = sess.run([image])
- print(img[0].shape) # 设置batch_size等于1.每次读出来只有一张图
- plt.imshow(img[0])
- plt.show()
- coord.request_stop()
- coord.join(threads)
程序到这里就已经处理完成了,当然在decorde的过程中也是可以进行一些预处理操作的,不过建议还是在制作数据集的时候进行,TFrecord使用的是队列的方式进行读取数据,这个对于多线程操作来说还是很方便的,只需要设置好格式,每次直接读取就可以了.
以上是关于tensorflow中tfrecords使用介绍的主要内容,如果未能解决你的问题,请参考以下文章