有没有办法在TensorFlow中获取TensorFlow字符串的长度?例如,是否有任何函数将a = tf.constant("Hello everyone", tf.string)
的长度返回为14
而不将字符串传递回Python。
答案 0 :(得分:5)
这对我有用:
x = tf.constant("Hello everyone")
# Launch the default graph.
with tf.Session() as sess:
print(tf.size(tf.string_split([x],"")).eval())
答案 1 :(得分:1)
TensorFlow版本0.9不存在此类功能。但是,您可以使用tf.py_func
在TensorFlow张量上运行任意Python函数。以下是获取TensorFlow字符串长度的一种方法:
def string_length(t):
return tf.py_func(lambda p: [len(x) for x in p], [t], [tf.int64])[0]
a = tf.constant(["Hello everyone"], tf.string)
sess = tf.InteractiveSession()
sess.run(string_length(a))
答案 2 :(得分:1)
另一个次优选项是将字符串转换为稀疏字符串:
strings = ['Why hello','world','!']
chars = tf.string_split(strings,"")
然后计算每行的最大索引+1
line_number = chars.indices[:,0]
line_position = chars.indices[:,1]
lengths = tf.segment_max(data = line_position,
segment_ids = line_number) + 1
with tf.Session() as sess:
print(lengths.eval())
[9 5 1]