如何使用FixedLengthRecordReader读取Tensorflow中的自定义数据格式?

时间:2016-06-28 23:52:26

标签: python tensorflow

我在使用 Tensorflow 队列读取数据时正在做一些实验,我想问一个简单的案例如下。

假设我的二进制数据文件包含float值(4字节)和boolean标签(1字节)的列表。例如

3.4 true 2.1 false 0.3 true ..."

以下是我的试用版。我想削减一个4字节的部分(转换为float)和1字节的部分(转换为bool)。但是,在这种情况下,我不确定如何使用decode_raw()。如果我使用uint8,它会削减每个字节。

如果使用FixedLengthRecordReader是错误的,有什么简单的方法可以做到这一点?谁能帮我?

# Dimensions of data
data_bytes = 4
label_bytes = 1  
record_bytes = data_bytes + label_bytes

reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)

value = tf.decode_raw(value, tf.uint8)
data_part = tf.cast(tf.slice(value, [0], [data_bytes]), tf.float32)
label_part = tf.cast(tf.slice(value, [data_bytes], [label_bytes]), tf.bool)     

1 个答案:

答案 0 :(得分:4)

尝试这样的事情:

data_part = tf.bitcast(tf.slice(value, [0], [data_bytes]), tf.float32)

即,使用tf.bitcast将4个uint8转换为float32,而不是使用tf.cast