map_fn的奇怪案例

时间:2019-07-26 20:16:35

标签: python tensorflow

当我打印“ a”和“ b”的形状时,它显示完全相同。但是,tf.map_fn传递“ b”时会引发错误。 有什么线索吗? 非常感谢

import tensorflow as tf


def distance(elem):
    # return tf.norm(elem[0] - elem[1], ord='euclidean')
    return elem


tf.enable_eager_execution()

a = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]])
b = [[1, 2, 3, 4], [1, 2, 3, 4]]

print(tf.shape(a)) # it prints tf.Tensor([2 4], shape=(2,), dtype=int32)
print(tf.shape(b)) # it print -> tf.Tensor([2 4], shape=(2,), dtype=int32)

print(tf.map_fn(distance, a, tf.int32))  # works

print(tf.map_fn(distance, b, tf.int32))  # ValueError: elements in elems must be 1+ dimensional Tensors, not scalars

0 个答案:

没有答案