我有一个TFRecord数据集,其中每个记录包含19个uint8,后跟36个字节,代表9个小尾数float32。
前19个字节是示例,后10个浮点数是标签。我只想相应地重新解释这些字节。
当我索引这样的字符串时,Tensorflow不喜欢它:
def parse(serialized):
return tf.decode_raw(serialized[0:19], tf.uint8), tf.decode_raw(serialized[19:], tf.float32)
*** ValueError:使用输入昏暗0索引超出范围;输入的'strided_slice'(op:'StridedSlice')的调光只有0个,输入形状为[], [1],[1],[1]并具有计算的输入张量:input [3] = <1>。
接下来,我尝试将序列化解释为字节,然后重新解释切片:
def parse(serialized):
expanded = tf.decode_raw(serialized, tf.uint8)
return tf.cast(expanded[0:19], tf.uint8), tf.cast(expanded[19:], tf.float32)
这适用于uint8,但对于float32,它将每个字节解释为自己的float32:
(Pdb) sess.run(label_it)
array([205., 204., 204., 60., 154., 153., 153., 60., 102., 102., 166.,
61., 10., 215., 35., 60., 184., 30., 45., 63., 51., 51.,
51., 61., 133., 235., 81., 61., 92., 143., 66., 61., 164.,
112., 61., 61.], dtype=float32)
我只是真的想将此字符串在第19个字节和第20个字节之间分成两部分,然后对其进行解码。但是,我有一段时间了。在这一点上,我很高兴听到这样做的任何解决方案,尽管很棘手。
谢谢!
答案 0 :(得分:0)
嘿,我认为问这个问题可以帮助我向自己澄清这个问题,并且能够找到答案。我的答案是tf.strings.substr()
def parse(serialized):
example_str = tf.strings.substr(serialized, 0, 19)
label_str = tf.strings.substr(serialized, 19, -1)
return tf.decode_raw(example_str, tf.uint8), tf.decode_raw(label_str, tf.float32)
但是我很乐意看到更好的发布方式。