TensorFlow with images is very slow

Date: 2016-12-24 09:54:36

Tags: python tensorflow

I am trying to wrap my head around training with a custom dataset. The TensorFlow tutorial provides an object from which I can request another batch of training data and feed it in, so I tried to do something similar. I created a dataset object ds that exposes a function get_alphabet, which reads PNGs with a WholeFileReader and decodes them with tf.image.decode_png to get Tensors. It returns a list of Tensors as the data (train_x) and a list of Tensors as the one-hot labels (train_y). The problem is that evaluating these images takes an enormous amount of time, to the point where '1' is printed and then I wait about ten seconds before '2' appears. How can I speed this up?

coord = tf.train.Coordinator()
for i in range(NUM_ITERATIONS):
    train_x, train_y = ds.get_alphabet(BATCH_SIZE) 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    inputs = np.ndarray(shape=(BATCH_SIZE, FEATURE_SIZE), dtype=int)
    outputs = np.ndarray(shape=(BATCH_SIZE, NUM_LABELS), dtype=int)

    print('1')
    for image in range(BATCH_SIZE):
        inputs[image] = train_x[image].eval(session=sess).flatten()
        outputs[image] = train_y[image].eval(session=sess).flatten()
    print('2')

    if i % 10 == 0:
        current_accuracy = accuracy.eval(session=sess, feed_dict={ x: inputs, y: outputs })
        print("accuracy of {0:.2f}% on {1}th iteration".format(current_accuracy * 100, i))
    if i == 0:
        print(inputs.shape)
        print(outputs)

coord.request_stop()
coord.join(threads)
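
Each .eval() call in the inner loop is a full Session.run(), so one batch costs 2 * BATCH_SIZE separate graph executions. For reference, the whole batch can be fetched in a single run (a sketch, assuming every tensor in train_x and train_y has the same shape; note that tf.stack here still adds new ops to the graph on every call):

    # One Session.run() for the whole batch instead of 2 * BATCH_SIZE runs.
    batch_x = tf.stack([tf.reshape(t, [FEATURE_SIZE]) for t in train_x])
    batch_y = tf.stack(train_y)
    inputs, outputs = sess.run([batch_x, batch_y])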

1 Answer:

Answer 0 (score: 0)

Use TFRecords together with the new Dataset API (released as part of TF 1.4) to solve your problem.

TFRecords:

-TensorFlow's own binary format
-The format combines TensorFlow's Records format with Protocol Buffers (Protobuf)
-A TFRecord file is simply a binary file that contains tf.train.Example Protobuf objects
-Protocol Buffers are a way to serialize data structures, given a schema describing what the data is (see the sketch below)
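
A minimal sketch of one such record built by hand (the field names 'label' and 'image' and their values are just placeholders):

import tensorflow as tf

# One record = one tf.train.Example Protobuf holding named features.
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'\x00\x01\x02'])),
}))
print(example)                            # human-readable Protobuf text format
serialized = example.SerializeToString()  # the bytes a TFRecord file stores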

Advantages of TFRecords:

-They make better use of disk cache.
-They are faster to move around.
-They can store data of different types (so you can put both images and labels in one place).


First convert everything into TFRecord files (i.e., both the images and the corresponding labels go into the TFRecord), and then use the TFRecordDataset API to retrieve the data from the TFRecord files in batches.

Steps to create TFRecords:

1) Open a tfrecords file with tf.python_io.TFRecordWriter and start writing
2) Before writing to the tfrecords file, convert the image data and the label data to the proper data type (e.g. float)
3) Now convert the data into tf.train.Feature
4) Finally, create an Example Protocol Buffer with tf.train.Example and put the converted features into it
5) Serialize the Example with its SerializeToString() method
6) Write out the serialized Example

Example code:

from matplotlib.image import imread
import numpy as np
import tensorflow as tf

# Helpers must be defined before the loop below uses them.
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

tfrecord_filename = 'TFrecord.tfrecord'

writer = tf.python_io.TFRecordWriter(tfrecord_filename)
sess = tf.Session()
input_files = _get_file_names()  # get list of image file names
for input_file in input_files:
    img = imread(input_file)
    # tf.image.resize_image_with_crop_or_pad - unifies images of different sizes.
    # (Note: building this op inside the loop grows the graph on every iteration;
    # for many images, build it once with a placeholder and reuse it.)
    img = sess.run(tf.image.resize_image_with_crop_or_pad(img,
                   target_height=480, target_width=640))
    # matplotlib reads PNGs as float32 in [0, 1]; the reader below decodes uint8,
    # so store the pixels as uint8 bytes.
    img = (img * 255).astype(np.uint8)
    label = get_corresponding_label(input_file)  # placeholder: integer label for this image
    input_feature = {'label': _int64_feature(label),
                     'image': _bytes_feature(img.tostring())}
    feature = tf.train.Features(feature=input_feature)
    example = tf.train.Example(features=feature)
    writer.write(example.SerializeToString())
writer.close()
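
To sanity-check the written file, you can iterate over it and parse a record back (a sketch using tf.python_io.tf_record_iterator from the same TF 1.x API):

# Read back the first record to confirm the file is well-formed.
for serialized in tf.python_io.tf_record_iterator(tfrecord_filename):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    label = example.features.feature['label'].int64_list.value[0]
    image_bytes = example.features.feature['image'].bytes_list.value[0]
    print(label, len(image_bytes))
    break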

Steps to read from TFRecords:

1) Read the TFRecord file name(s)
2) Create a TFRecordDataset by supplying the TFRecord file name(s)
3) Create a parse function that decodes the records and does any preprocessing of the input data
4) From the dataset created in the previous steps, create batches, repeat (indefinitely, i.e. no epoch limit) and shuffle
5) Create an iterator to fetch the required inputs batch by batch (i.e., mini-batches)

Example code:

import tensorflow as tf

# Schema used when parsing each serialized tf.train.Example;
# it must match the features written above.
tfrecord_features = {
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64)
}

def input_model_function():
    dataset = tf.data.TFRecordDataset(tfrecord_filename)
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(20)  # you can use any batch size
    iterator = dataset.make_one_shot_iterator()
    sess = tf.Session()
    batch_images, batch_labels = sess.run(iterator.get_next())
    return {'x': batch_images}, batch_labels

def _parse_function(example_proto):
    parsed_features = tf.parse_single_example(example_proto, tfrecord_features)
    # Get the image as raw bytes.
    image_raw = parsed_features['image']
    # Decode the raw bytes so it becomes a tensor with a type.
    image = tf.decode_raw(image_raw, tf.uint8)
    # The type is now uint8 but we need it to be float.
    image = tf.cast(image, tf.float32)
    # decode_raw yields a flat tensor; reshape if the model needs a fixed shape,
    # e.g. tf.reshape(image, [480 * 640 * 3]) to match the size written above.
    # Get the label associated with the image.
    label = parsed_features['label']
    return image, label

Finally, feed the batched dataset into your model (created with any premade Estimator or with the custom Estimator API).
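
For instance, with a premade Estimator the input function should return the iterator's tensors directly rather than running them in its own session (a sketch; DNNClassifier, hidden_units, n_classes and the 480 * 640 * 3 feature shape are placeholder choices, and it assumes _parse_function reshapes each image to that fixed flat shape):

import tensorflow as tf

def train_input_fn():
    dataset = tf.data.TFRecordDataset(tfrecord_filename)
    dataset = dataset.map(_parse_function)
    dataset = dataset.shuffle(1000).repeat().batch(20)
    images, labels = dataset.make_one_shot_iterator().get_next()
    # Return tensors, not numpy arrays: the Estimator runs them itself.
    return {'x': images}, labels

# The shape must match the images written above (480x640, flattened).
feature_columns = [tf.feature_column.numeric_column('x', shape=[480 * 640 * 3])]
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                        hidden_units=[256], n_classes=10)
classifier.train(input_fn=train_input_fn, steps=1000)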