当我打印“ a”和“ b”的形状时,它显示完全相同。但是,tf.map_fn传递“ b”时会引发错误。 有什么线索吗? 非常感谢
import tensorflow as tf
def distance(elem):
# return tf.norm(elem[0] - elem[1], ord='euclidean')
return elem
tf.enable_eager_execution()
a = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]])
b = [[1, 2, 3, 4], [1, 2, 3, 4]]
print(tf.shape(a)) # it prints tf.Tensor([2 4], shape=(2,), dtype=int32)
print(tf.shape(b)) # it print -> tf.Tensor([2 4], shape=(2,), dtype=int32)
print(tf.map_fn(distance, a, tf.int32)) # works
print(tf.map_fn(distance, b, tf.int32)) # ValueError: elements in elems must be 1+ dimensional Tensors, not scalars