当从tf.py_func返回numpy字符串数组时,在Tensor中尾随'\ x00'字符

时间:2018-03-07 21:15:09

标签: python tensorflow

当从tf.py_func调用的函数返回的numpy字符串数组将具有固定的字符串长度,其尾随'\ x00'字符而不是“自然”变量字符串长度而没有填充。

以下是一个例子:

import tensorflow as tf
import numpy as np

def main():
    def foo(x):
        a = np.asarray(['abc', 'd'], dtype=np.string_)
        return a

    with tf.Session() as sess:

        f = tf.py_func(foo, [tf.constant(1)], (tf.string))
        f = tf.Print(f, [f, tf.shape(f)])

        actual = sess.run(f)
        print actual

打印出来:

[abc d\000\000][2]

我使用的一个小解决方法是:

f = tf.string_split(f, delimiter='\x00', skip_empty=True).values

这是TF问题还是我做错了什么?

2 个答案:

答案 0 :(得分:2)

返回列表而不是ndarray似乎有效:

def foo(x): 
  a = [['abc', 'd']]
  return a

答案 1 :(得分:0)

你应该改变

a = np.asarray(['abc', 'd'], dtype=np.string_)

a = np.asarray(['abc', 'd'], dtype=object)