如何获得TensorFlow字符串的长度?

时间:2016-07-05 21:55:42

标签: tensorflow

有没有办法在TensorFlow中获取TensorFlow字符串的长度?例如,是否有任何函数将a = tf.constant("Hello everyone", tf.string)的长度返回为14而不将字符串传递回Python。

3 个答案:

答案 0 :(得分:5)

这对我有用:

x = tf.constant("Hello everyone")

# Launch the default graph.
with tf.Session() as sess:
    print(tf.size(tf.string_split([x],"")).eval())

答案 1 :(得分:1)

TensorFlow版本0.9不存在此类功能。但是,您可以使用tf.py_func在TensorFlow张量上运行任意Python函数。以下是获取TensorFlow字符串长度的一种方法:

def string_length(t):
  return tf.py_func(lambda p: [len(x) for x in p], [t], [tf.int64])[0]

a = tf.constant(["Hello everyone"], tf.string)
sess = tf.InteractiveSession()
sess.run(string_length(a))

答案 2 :(得分:1)

另一个次优选项是将字符串转换为稀疏字符串:

strings = ['Why hello','world','!']
chars = tf.string_split(strings,"")

然后计算每行的最大索引+1

line_number = chars.indices[:,0]
line_position = chars.indices[:,1]
lengths = tf.segment_max(data = line_position, 
                         segment_ids = line_number) + 1

with tf.Session() as sess:
    print(lengths.eval())

[9 5 1]