当从tf.py_func
调用的函数返回的numpy字符串数组将具有固定的字符串长度,其尾随'\ x00'字符而不是“自然”变量字符串长度而没有填充。
以下是一个例子:
import tensorflow as tf
import numpy as np
def main():
def foo(x):
a = np.asarray(['abc', 'd'], dtype=np.string_)
return a
with tf.Session() as sess:
f = tf.py_func(foo, [tf.constant(1)], (tf.string))
f = tf.Print(f, [f, tf.shape(f)])
actual = sess.run(f)
print actual
打印出来:
[abc d\000\000][2]
我使用的一个小解决方法是:
f = tf.string_split(f, delimiter='\x00', skip_empty=True).values
这是TF问题还是我做错了什么?
答案 0 :(得分:2)
返回列表而不是ndarray似乎有效:
def foo(x):
a = [['abc', 'd']]
return a
答案 1 :(得分:0)
你应该改变
a = np.asarray(['abc', 'd'], dtype=np.string_)
到
a = np.asarray(['abc', 'd'], dtype=object)