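I followed the tutorial on GitHub for training on my own data: https://github.com/tensorflow/models/tree/master/inception#how-to-construct-a-new-dataset-for-retraining. I split my data into training and validation sets, created the labels file, and managed to create TFRecords with bazel-bin. Everything worked, and now I have my own data as TFRecords.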
Now I want to train my image classifier from scratch with the inception-v3 model. It seems I should use the script inception_train.py, but I'm not sure. Is that right? https://github.com/tensorflow/models/blob/master/inception/inception/inception_train.py
If so, I have two questions: 1) How do I use my TFRecords for training? An example would be great. 2) Can I run it on a CPU, or only on a GPU?
Thank you very much.
Answer 0 (score: 0)
Try the following sample code to read images and labels from your tfrecords:
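```python
import os
import glob

import tensorflow as tf
from matplotlib import pyplot as plt


def read_and_decode_file(filename_queue):
    # Create an instance of the TFRecord reader
    reader = tf.TFRecordReader()
    # Read one serialized example from the filename queue
    _, serialized_example = reader.read(filename_queue)
    # Extract the features you need from the record using their keys.
    # In my example, all images were written with the 'image' key and the
    # labels with the 'labels' key. (Records written by the inception
    # build_image_data script use 'image/encoded' and 'image/class/label'.)
    features = tf.parse_single_example(
        serialized_example, features={
            'image': tf.FixedLenFeature([], tf.string),
            # FixedLenFeature only supports tf.string, tf.int64 and tf.float32
            'labels': tf.FixedLenFeature([], tf.int64),
        })
    # The image is stored as encoded bytes; this assumes JPEG encoding
    # (as build_image_data produces), so decode it before resizing
    img = tf.image.decode_jpeg(features['image'], channels=3)
    img_out = tf.image.resize_image_with_crop_or_pad(img, target_height=128, target_width=128)
    # Similarly extract the labels; be careful with the type
    label = features['labels']
    return img_out, label


if __name__ == "__main__":
    tf.reset_default_graph()
    # Path to your tfrecords
    path_to_tf_records = os.path.join(os.getcwd(), '*.tfrecords')
    # Collect all tfrecords present in the records folder using glob
    list_of_tfrecords = sorted(glob.glob(path_to_tf_records))
    # Generate a TensorFlow-readable filename queue from the list of
    # tfrecords; it is recommended to shuffle your data before feeding
    # it into the network (shuffle=True here, or shuffle_batch later)
    filename_queue = tf.train.string_input_producer(list_of_tfrecords, shuffle=False)
    # Pass the filename queue to the custom function above
    image, label = read_and_decode_file(filename_queue)
    with tf.Session() as sess:
        # Start the background threads that fill the filename queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Arbitrary number of iterations
        for i in range(50):
            img, lbl = sess.run([image, label])
            # Show the image together with its label
            plt.imshow(img)
            plt.title(str(lbl))
            plt.show()
        coord.request_stop()
        coord.join(threads)
```

There is also a function called tf.train.shuffle_batch, which spawns multiple CPU threads to do this reading and returns images and labels in shuffled batches of a user-specified size. You need to set up the data pipeline and the training pipeline so that they run concurrently, with the input threads filling the queue while the training loop consumes batches.

To answer your second question: yes, you can train your model on the CPU alone, but it will be slow, and it may take hours or even days to get decent results. Create the inception model inside a `with tf.device('/cpu:0'):` block, and TensorFlow will place the model on the CPU.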
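To make the batching step more concrete: tf.train.shuffle_batch keeps a queue of examples and only dequeues a batch when at least `min_after_dequeue` examples would remain afterwards, so every batch is drawn at random from a reasonably mixed pool. The pure-Python sketch below illustrates just that buffering idea under a single-producer assumption; it is not TensorFlow's implementation, and the names `shuffle_batches` and `_pop_random` are hypothetical.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def _pop_random(buffer: List[T], rng: random.Random) -> T:
    # Swap a random element to the end and pop it (O(1) removal).
    i = rng.randrange(len(buffer))
    buffer[i], buffer[-1] = buffer[-1], buffer[i]
    return buffer.pop()


def shuffle_batches(examples: Iterable[T], batch_size: int,
                    min_after_dequeue: int, seed: int = 0,
                    allow_smaller_final_batch: bool = False) -> Iterator[List[T]]:
    """Conceptual sketch of shuffle_batch-style buffering (not TF code)."""
    if batch_size < 1 or min_after_dequeue < 0:
        raise ValueError("need batch_size >= 1 and min_after_dequeue >= 0")
    rng = random.Random(seed)
    buffer: List[T] = []
    for example in examples:
        buffer.append(example)
        # Like a random shuffle queue: dequeue a batch only if at least
        # min_after_dequeue examples would still be buffered afterwards.
        if len(buffer) >= min_after_dequeue + batch_size:
            yield [_pop_random(buffer, rng) for _ in range(batch_size)]
    # Input exhausted ("queue closed"): drain the rest, still at random.
    while len(buffer) >= batch_size:
        yield [_pop_random(buffer, rng) for _ in range(batch_size)]
    if buffer and allow_smaller_final_batch:
        yield [_pop_random(buffer, rng) for _ in range(len(buffer))]


batches = list(shuffle_batches(range(10), batch_size=3, min_after_dequeue=2))
all_batches = list(shuffle_batches(range(10), batch_size=3, min_after_dequeue=2,
                                   allow_smaller_final_batch=True))
```

In this single-threaded sketch the buffer never holds more than `min_after_dequeue + batch_size` items, which plays the role of TensorFlow's `capacity` bound. In TensorFlow 1.x the real call would look something like `images, labels = tf.train.shuffle_batch([image, label], batch_size=32, capacity=2000, min_after_dequeue=1000, num_threads=4)`, with the numbers chosen only for illustration; shuffle_batch needs every input tensor to have a fully defined static shape (for images, a fixed height, width and channel count).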
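A minimal sketch of the CPU device-placement idea, assuming TensorFlow with the `tf.compat.v1` API available (late 1.x or 2.x). The toy two-class linear model and the names `build_cpu_model`, `images` and `weights` are hypothetical stand-ins for the inception network; the point is only the `with tf.device('/cpu:0'):` wrapper around model construction.

```python
import numpy as np
import tensorflow as tf

# Assumption: tf.compat.v1 exists (TF 1.13+ or TF 2.x); on older 1.x
# releases the same calls are available directly as tf.placeholder etc.
tf1 = tf.compat.v1


def build_cpu_model(graph):
    """Build a toy linear classifier with every op pinned to the CPU."""
    with graph.as_default():
        with tf.device('/cpu:0'):
            images = tf1.placeholder(tf.float32, shape=[None, 4], name='images')
            weights = tf1.get_variable(
                'weights', shape=[4, 2], initializer=tf1.zeros_initializer())
            logits = tf.matmul(images, weights, name='logits')
    return images, logits


graph = tf.Graph()
images, logits = build_cpu_model(graph)
with graph.as_default():
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    # Zero-initialized weights, so the logits are all zeros here.
    out = sess.run(logits, feed_dict={images: np.ones((2, 4), np.float32)})
```

To check where ops actually run, you can pass `config=tf1.ConfigProto(log_device_placement=True)` when creating the session, and TensorFlow logs the device chosen for each op.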
Hope this explanation helps.