TensorFlow:将字符串解析为两个张量

时间:2019-05-25 02:34:19

标签: tensorflow tensorflow-datasets

我有一个TFRecord数据集,其中每个记录包含19个uint8,后跟36个字节,代表9个小尾数float32。

前19个字节是示例,后10个浮点数是标签。我只想相应地重新解释这些字节。

当我索引这样的字符串时,Tensorflow不喜欢它:

def parse(serialized):
    return tf.decode_raw(serialized[0:19], tf.uint8), tf.decode_raw(serialized[19:], tf.float32)

  

*** ValueError:使用输入昏暗0索引超出范围;输入的'strided_slice'(op:'StridedSlice')的调光只有0个,输入形状为[],   [1],[1],[1]并具有计算的输入张量:input [3] = <1>。

接下来,我尝试将序列化解释为字节,然后重新解释切片:

def parse(serialized):
    expanded = tf.decode_raw(serialized, tf.uint8)

    return tf.cast(expanded[0:19], tf.uint8), tf.cast(expanded[19:], tf.float32)

这适用于uint8,但对于float32,它将每个字节解释为自己的float32:

(Pdb) sess.run(label_it)
array([205., 204., 204.,  60., 154., 153., 153.,  60., 102., 102., 166.,
        61.,  10., 215.,  35.,  60., 184.,  30.,  45.,  63.,  51.,  51.,
        51.,  61., 133., 235.,  81.,  61.,  92., 143.,  66.,  61., 164.,
       112.,  61.,  61.], dtype=float32)

我只是真的想将此字符串在第19个字节和第20个字节之间分成两部分,然后对其进行解码。但是,我有一段时间了。在这一点上,我很高兴听到这样做的任何解决方案,尽管很棘手。

谢谢!

1 个答案:

答案 0 :(得分:0)

嘿,我认为问这个问题可以帮助我向自己澄清这个问题,并且能够找到答案。我的答案是tf.strings.substr()

def parse(serialized):
    example_str = tf.strings.substr(serialized, 0, 19)
    label_str = tf.strings.substr(serialized, 19, -1)

    return tf.decode_raw(example_str, tf.uint8), tf.decode_raw(label_str, tf.float32)

但是我很乐意看到更好的发布方式。