TensorFlow字符串:它们是什么以及如何使用它们

时间:2016-08-11 17:14:04

标签: python string numpy tensorflow tfrecord

当我使用tf.read_file读取文件时,我得到类型为tf.string的内容。文档仅表示它是"可变长度字节数组。 Tensor的每个元素都是一个字节数组。" (https://www.tensorflow.org/versions/r0.10/resources/dims_types.html)。我不知道如何解释这个。

我对这种类型无能为力。在通常的python中,您可以通过索引获取元素my_string[:4],但是当我运行以下代码时,我会收到错误。

import tensorflow as tf
import numpy as np

x = tf.constant("This is string")
y = x[:4]


init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
result = sess.run(y)
print result

它说

  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 621, in assert_has_rank
    raise ValueError("Shape %s must have rank %d" % (self, rank))
ValueError: Shape () must have rank 1

此外,我无法将我的字符串转换为tf.float32张量。它是.flo文件,它有魔术标题" PIEH"。这个numpy代码成功地将这样的头转换为数字(参见这里的例子https://stackoverflow.com/a/28016469/4744283),但我不能用tensorflow做到这一点。我试过tf.string_to_number(string, out_type=tf.float32),但它说

tensorflow.python.framework.errors.InvalidArgumentError: StringToNumberOp could not correctly convert string: PIEH

那么,字符串是什么?它的形状是什么?我怎么能至少得到一部分字符串?我想如果我可以参与其中,我可以跳过" PIEH"一部分。

UPD :我忘了说tf.slice(string, [0], [4])也没有处理同样的错误。

1 个答案:

答案 0 :(得分:15)

与Python不同,可以将字符串视为切片等字符列表,TensorFlow的tf.string是不可分割的值。例如,下面的xTensor,其形状为(2,),其每个元素都是一个可变长度的字符串。

x = tf.constant(["This is a string", "This is another string"])

但是,要实现您的目标,TensorFlow会提供tf.decode_raw运算符。它需要tf.string张量作为输入,但可以将字符串解码为任何其他原始数据类型。例如,要将字符串解释为字符张量,可以执行以下操作:

x = tf.constant("This is string")
x = tf.decode_raw(x, tf.uint8)
y = x[:4]
sess = tf.InteractiveSession()
print(y.eval())
# prints [ 84 104 105 115]