我正试图围绕着自定义数据集。 Tensorflow tutorial提供了一个对象,我可以在其中请求另一批训练数据并将其输入,因此我尝试做类似的事情。我创建了一个数据集对象ds,它公开了函数get_alphabet,它使用WholeFileReader读取png,并使用tf.image.decode_png来获取Tensor。它返回一个Tensors列表作为数据(train_x),一个张量列表作为one_hot标签(train_y)。问题是评估这些图像花费了大量的时间,最终达到了我在等待10秒后看到打印“1”后打印出“2”的程度。我如何加快速度?
coord = tf.train.Coordinator()
for i in range(NUM_ITERATIONS):
train_x, train_y = ds.get_alphabet(BATCH_SIZE)
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
inputs = np.ndarray(shape=(BATCH_SIZE, FEATURE_SIZE), dtype=int)
outputs = np.ndarray(shape=(BATCH_SIZE, NUM_LABELS), dtype=int)
print('1')
for image in range(BATCH_SIZE):
inputs[image] = train_x[image].eval(session=sess).flatten()
outputs[image] = train_y[image].eval(session=sess).flatten()
print('2')
if i % 10 == 0:
current_accuracy = accuracy.eval(session=sess, feed_dict={ x: inputs, y: outputs })
print("accuracy of {0:.2f}% on {1}th iteration".format(current_accuracy * 100, i))
if i == 0:
print(inputs.shape)
print(outputs)
coord.request_stop()
coord.join(threads)
答案 0 :(得分:0)
-TensorFlow's own binary format
-Its format uses a mixture of its Records format and and Protocol Buffers(or Protobuf)
-Record is simply a binary file that contains tf.train.Example Protobuf Objects
-Protocol Buffers, it as a way to serialize data structures, given some schema describing what the data is.
-They make better use of disk cache.
-They are faster to move around.
-They can store data of different types (so you can put both images and labels in one place)
首先将所有内容转换为TFRecords(即图像和相应标签都转换为TFrecord)文件,然后您可以使用TFRecordDataset API从TFrecord文件批量检索数据
创建TFrecords的步骤
1)使用tf.python_io.TFRecordWriter打开tfrecords文件并开始写入
2)在写入tfrecords文件之前,应将图像数据和标签数据转换为正确的数据类型(如浮点数)
3)现在数据类型转换为tf.train.Feature
4)最后使用tf.Example创建一个示例协议缓冲区,并将转换后的特征用于其中
5)使用serialize()函数序列化Example
6)写出序列化的例子
例如代码:
from matplotlib.image import imread
tfrecord_filename = 'TFrecord.tfrecord'
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
sess = tf.Session()
input_files = _get_file_names() # get list of image file names
for input_file in input_files:
img = imread(input_file)
#tf.image.resize_image_with_crop_or_pad - used to resize image of different size
img = sess.run(tf.image.resize_image_with_crop_or_pad(img,
target_height=480, target_width=640))
label = <get_correspondinglabels>
input_feature = { 'label': _int64_feature(label),
'image': _bytes_feature(img.tostring()) }
feature = tf.train.Features(feature=input_feature)
example = tf.train.Example(features=feature)
writer.write(example.SerializeToString())
writer.close()
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
从TFRecords阅读的步骤:
1)读取TFRecord文件名
2)通过提供TFrecord文件名
来创建TFRecordDataset
3)创建解码的Parse函数,并在输入数据中进行任何预处理工作
4)使用先前步骤中创建的数据集创建批处理,重复(没有纪元)和改组
5)创建迭代器以批量获取所需的输入(即小批量)
例如代码:
def input_model_function():
dataset = tf.data.TFRecordDataset(tfrecord_filename)
dataset = dataset.map(_parse_function)
dataset = dataset.batch(20)# you can use any number of batching
iterator = dataset.make_one_shot_iterator()
sess = tf.Session()
batch_images, batch_labels = sess.run(iterator.get_next())
return {'x':batch_images}, batch_labels
def _parse_function(example_proto):
parsed_features = tf.parse_single_example(example_proto, tfrecord_features)
# Get the image as raw bytes.
image_raw = parsed_features['image']
# Decode the raw bytes so it becomes a tensor with type.
image = tf.decode_raw(image_raw, tf.uint8)
# The type is now uint8 but we need it to be float.
image = tf.cast(image, tf.float32)
#Get the label associated with the image.
label = parsed_features['label']
return image, label
最后将批量数据集输入模型(使用任何预制估算器或自定义估算器API创建)