使用存储在TFRecords文件中的高度,宽度信息来设置Tensor的形状

时间:2016-10-26 09:47:55

标签: tensorflow

我已将图片及其标签的目录转换为TFRecords文件,功能图包括image_rawlabelheightwidth和{{1} }。功能如下:

depth

现在,我想阅读这个TFRecords文件来提供输入管道。但是,由于def convert_to_tfrecords(data_samples, filename): def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) writer = tf.python_io.TFRecordWriter(filename) for fname, lb in data_samples: im = cv2.imread(fname, cv2.IMREAD_UNCHANGED) image_raw = im.tostring() feats = tf.train.Features( feature = { 'image_raw': _bytes_feature(image_raw), 'label': _int64_feature(int(lb)), 'height': _int64_feature(im.shape[0]), 'width': _int64_feature(im.shape[1]), 'depth': _int64_feature(im.shape[2]) } ) example = tf.train.Example(features=feats) writer.write(example.SerializeToString()) writer.close() 已被展平,我们需要将其重塑为原始image_raw大小。那么如何从TFRecords文件中获取[height, width, depth]heightwidth的值?似乎以下代码无法工作,因为depth是没有值的Tensor。

height

当我阅读Tensorflow的官方文档时,我发现它们通常会传到一个已知的大小,说def read_and_decode(filename_queue): reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) feats = { 'image_raw': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64), 'height': tf.FixedLenFeature([], tf.int64), 'width': tf.FixedLenFeature([], tf.int64), 'depth': tf.FixedLenFeature([], tf.int64) } features = tf.parse_single_example(serialized_example, features=feats) image = tf.decode_raw(features['image_raw'], tf.uint8) label = tf.cast(features['label'], tf.int32) height = tf.cast(features['height'], tf.int32) width = tf.cast(features['width'], tf.int32) depth = tf.cast(features['depth'], tf.int32) image = tf.reshape(image, [height, width, depth]) # <== not work image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 return image, label 。但是,我不喜欢它,因为这些信息已存储到TFRecords文件中,并且手动传入固定大小无法确保大小与文件中存储的数据一致。

所有想法?

3 个答案:

答案 0 :(得分:1)

height返回的tf.parse_single_example是Tensor,获取其价值的唯一方法是在其上调用session.run()或类似内容。但是,我觉得这太过分了。

由于Tensorflow示例只是一个协议缓冲区(请参阅文档),因此您不必使用tf.parse_single_example来读取它。您可以自己解析它并直接读取您想要的形状。

您可能还会考虑在Tensorflow的github问题跟踪器上提交功能请求---我同意这个API对于这个用例似乎有些尴尬。

答案 1 :(得分:1)

'tf.reshape'函数仅接受张量,而不接受张量列表,因此您可以使用以下代码:

image = tf.reshape(image, tf.stack([height, width, depth]))

答案 2 :(得分:0)

您还可以从张量中获取numpy数组,并使用np.resize()进行整形,将尺寸作为参数传递