Training an image classifier on my own data with TensorFlow, Inception and TFRecords

Time: 2017-09-14 13:10:58

Tags: image tensorflow classification

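I followed the GitHub tutorial on how to train on your own data: https://github.com/tensorflow/models/tree/master/inception#how-to-construct-a-new-dataset-for-retraining. I split my data into training and validation sets, created the labels file, and managed to build TFRecords using bazel-bin. Everything worked, and I now have my own data as TFRecords.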

Now I want to train my image classifier from scratch with the Inception-v3 model. It seems I should use the script inception_train.py, but I'm not sure. Is that right? https://github.com/tensorflow/models/blob/master/inception/inception/inception_train.py

If so, I have two questions: 1) How do I use my TFRecords for training? An example would be great. 2) Can I run the training on a CPU, or only on a GPU?

Thank you very much.

1 answer:

Answer 0 (score: 0)

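Try the following sample code to read images and labels from your TFRecords:

```python
import glob
import os

import tensorflow as tf  # TensorFlow 1.x API (queue runners, tf.Session)
from matplotlib import pyplot as plt


def read_and_decode_file(filename_queue):
    # Create an instance of the TFRecord reader
    reader = tf.TFRecordReader()
    # Read one serialized example from the filename queue
    _, serialized_example = reader.read(filename_queue)
    # Extract the features you need from the TFRecord using their keys.
    # In my example, all images were written under the 'image' key and labels
    # under 'labels'. Records built with Inception's build_image_data.py use
    # 'image/encoded' and 'image/class/label' instead.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            # FixedLenFeature only supports tf.string, tf.int64 and tf.float32
            'labels': tf.FixedLenFeature([], tf.int64),
        })
    # The image is stored as encoded bytes; decode it into a uint8 tensor.
    # This assumes JPEG-encoded images; use tf.image.decode_png or
    # tf.decode_raw (plus a reshape) if your records store another format.
    img = tf.image.decode_jpeg(features['image'], channels=3)
    img_out = tf.image.resize_image_with_crop_or_pad(
        img, target_height=128, target_width=128)
    # Similarly extract the label; be careful with the type
    label = features['labels']
    return img_out, label


if __name__ == "__main__":
    tf.reset_default_graph()
    # Path to your TFRecords
    path_to_tf_records = os.path.join(os.getcwd(), '*.tfrecords')
    # Collect all TFRecord files in the folder using glob
    list_of_tfrecords = sorted(glob.glob(path_to_tf_records))
    # Build a TensorFlow filename queue from the list of TFRecords.
    # Shuffling your data before it reaches the network is recommended.
    filename_queue = tf.train.string_input_producer(list_of_tfrecords, shuffle=False)
    # Pass the filename queue to the function above
    image, label = read_and_decode_file(filename_queue)

    # Create a session to read the data
    with tf.Session() as sess:
        # Start the queue-runner threads that fill the filename queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Arbitrary number of iterations
        for i in range(50):
            img, lbl = sess.run([image, label])
            # Show the image together with its label
            plt.imshow(img)
            plt.title(str(lbl))
            plt.show()
        # Stop the reader threads cleanly
        coord.request_stop()
        coord.join(threads)
```

TensorFlow also provides tf.train.shuffle_batch, which uses multiple CPU threads to do this work and returns shuffled images and labels in batches of a size you specify. You need to set up the data pipeline and the training pipeline so that they stay synchronized and run concurrently.

To answer your second question: yes, you can train your model on the CPU alone, but it will be slow, and it may take hours or even days to get decent results. Remove the GPU device-placement decorator before creating the Inception model, and TensorFlow will create the model on the CPU.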
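To make the tf.train.shuffle_batch point above more concrete, here is a plain-Python sketch of the buffering idea behind it. It is only an illustration under simplifying assumptions: it runs in a single thread, whereas TensorFlow fills a RandomShuffleQueue from several threads, and the function name shuffle_batches is hypothetical, not part of TensorFlow. The min_after_dequeue parameter mirrors the shuffle_batch argument of the same name: an element is only taken out once more than that many are buffered, so each batch is drawn from a reasonably large random pool.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shuffle_batches(examples: Iterable[T], batch_size: int,
                    min_after_dequeue: int, seed: int = 0) -> Iterator[List[T]]:
    """Yield shuffled batches using a bounded random-shuffle buffer.

    Hypothetical helper illustrating the idea behind tf.train.shuffle_batch;
    it is not TensorFlow's implementation.
    """
    rng = random.Random(seed)
    buffer: List[T] = []
    batch: List[T] = []

    def pop_random() -> T:
        # Swap a random element to the end and pop it (O(1) removal).
        i = rng.randrange(len(buffer))
        buffer[i], buffer[-1] = buffer[-1], buffer[i]
        return buffer.pop()

    for example in examples:
        buffer.append(example)
        # Only dequeue while more than min_after_dequeue elements are buffered,
        # so every draw comes from a reasonably large pool.
        if len(buffer) > min_after_dequeue:
            batch.append(pop_random())
            if len(batch) == batch_size:
                yield batch
                batch = []

    # Input exhausted (like a closed queue): drain what is left in the buffer.
    while buffer:
        batch.append(pop_random())
        if len(batch) == batch_size:
            yield batch
            batch = []

    # tf.train.shuffle_batch drops a final partial batch unless
    # allow_smaller_final_batch=True; this sketch keeps it.
    if batch:
        yield batch
```

For example, list(shuffle_batches([(i, i % 3) for i in range(10)], batch_size=4, min_after_dequeue=3)) yields three batches of sizes 4, 4 and 2 that together contain all ten (id, label) pairs. A larger min_after_dequeue gives better mixing at the cost of more memory and a slower start, which is the same trade-off you tune in tf.train.shuffle_batch.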

I hope this explanation helps.