读写tfrecords二进制文件(类型missmatch)

时间:2017-01-13 18:07:00

标签: image-processing binary tensorflow

您好我正在尝试构建图像输入管道。我的预处理培训数据存储在我用以下代码行创建的tfrecords文件中:

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

...

img_raw = img.tostring()                                        # typeof(img) = np.Array with shape (50, 80) dtype float64
img_label_text_raw = str.encode(img_lable)
example = tf.train.Example(features=tf.train.Features(feature={
    'height': _int64_feature(height),                           #heigth (integer)
    'width': _int64_feature(width),                             #width (integer)
    'depth': _int64_feature(depth),                             #num of rgb channels (integer)
    'image_data': _bytes_feature(img_raw),                      #raw image data (byte string)
    'label_text': _bytes_feature(img_label_text_raw),           #raw image_lable_text (byte string)
    'lable': _int64_feature(lable_txt_to_int[img_lable])}))     #label index (integer)

writer.write(example.SerializeToString())

现在我尝试读取二进制数据以重构其中的张量:

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'image_data': tf.FixedLenFeature([], tf.string)
        })

    label = features['label']
    height = tf.cast(features['height'], tf.int64)
    width = tf.cast(features['width'], tf.int64)
    depth = tf.cast(features['depth'], tf.int64)

    image_shape = tf.pack([height, width, depth])
    image = tf.decode_raw(features['image_data'], tf.float64)
    image = tf.reshape(image, image_shape)


    images, labels = tf.train.shuffle_batch([image, label],     batch_size=2,
                                                 capacity=30,
                                                 num_threads=1,
                                                 min_after_dequeue=10)
    return images, labels

可悲的是,这不起作用。我得到这个错误消息:

  

ValueError:Tensor转换请求Tensor的dtype字符串   dtype int64:' Tensor(" ParseSingleExample / Squeeze_label:0",shape =(),   D型= int64类型)'   ...

     

TypeError:输入'字节' ' DecodeRaw' Op的类型为int64,与预期的字符串类型不匹配。

有些人可以给我一些如何解决这个问题的提示吗?

提前致谢!

更新:完整的代码清单" read_and_decode"

@mmry非常感谢你。现在我的代码打破了批处理。用:

  

ValueError:必须完全定义所有形状:   [TensorShape([Dimension(None),Dimension(None),Dimension(None)]),   TensorShape([])]

有什么建议吗?

1 个答案:

答案 0 :(得分:4)

此行中无需使用tf.decode_raw()操作:

label = tf.decode_raw(features['label'], tf.int64)

相反,你应该能够写:

label = features['label']

tf.decode_raw() op只接受tf.string张量,并将一些张量数据的二进制表示(作为可变长度字符串)转换为类型表示(作为特定类型元素的向量) )。但是,您已将功能'label'定义为类型tf.int64,因此如果您要将其用作tf.int64,则无需转换该功能。