TensorFlow`py_func`的输出具有未知的等级/形状

时间:2017-03-04 00:02:10

标签: tensorflow

我正在尝试在TensorFlow中创建一个简单的神经网络。唯一棘手的部分是我使用py_func实现了自定义操作。当我将输出从py_func传递到Dense图层时,TensorFlow会抱怨应该知道等级。具体错误是:

ValueError: Inputs to `Dense` should have known rank.

当我通过py_func传递数据时,我不知道如何保留数据的形状。我的问题是如何获得正确的形状?我在下面有一个简单的例子来说明问题。

def my_func(x):
    return np.sinh(x).astype('float32')

inp = tf.convert_to_tensor(np.arange(5))
y = tf.py_func(my_func, [inp], tf.float32, False)

with tf.Session() as sess:
    with sess.as_default():
        print(inp.shape)
        print(inp.eval())
        print(y.shape)
        print(y.eval())

此代码段的输出为:

(5,)
[0 1 2 3 4]
<unknown>
[  0.       
1.17520118   3.62686038  10.01787472  27.28991699]

为什么y.shape <unknown>?我希望形状(5,)inp相同。谢谢!

1 个答案:

答案 0 :(得分:28)

由于py_func可以执行任意Python代码并输出任何内容,因此TensorFlow无法找出形状(它需要分析函数体的Python代码)您可以手动赋予形状

y.set_shape(inp.get_shape())