检查两个Tensor形状是否相同的最佳方法是什么,包括无?

时间:2017-03-01 04:27:34

标签: numpy tensorflow

我想检查两个张量是否具有相同的形状。

假设我有一些像这样的张量:

a = tf.placeholder(tf.float32, shape=[None, 3])
b = tf.placeholder(tf.float32, shape=[None, 3])

我添加了assert a.shape == b.shape。但是,这可能是由于None而失败。确实a.shape = (?, 1)b.shape也是(?, 1)。他们看起来和我一样。

如果没有None,则可以正常工作。

a = tf.placeholder(tf.float32, shape=[1, 3])
b = tf.placeholder(tf.float32, shape=[1, 3])
assert a.shape == b.shape  # True

如何在形状检查中忽略无?

总结:

1: a = tf.placeholder(tf.float32, shape=[1, 3])
2: b = tf.placeholder(tf.float32, shape=[1, 3])
3: assert a.shape == b.shape  # True
4: 
5: a = tf.placeholder(tf.float32, shape=[None, 3])
6: b = tf.placeholder(tf.float32, shape=[None, 3])
7: assert a.shape == b.shape  # False

我希望第7行的断言为True。

1 个答案:

答案 0 :(得分:4)

您可以使用a.shape.as_list() == b.shape.as_list()比较两个tf.TensorShape对象的“相等”。但是,在执行此操作时应该小心,因为如果两个形状在同一位置包含None,那么具有这些形状的张量不能保证在该维度中具有相同的大小。

(能够在batch_size中表示tf.TensorShape之类的“符号”维度会很有用,这会使等式测试更有用。我们正在考虑允许API扩展这在TensorFlow的未来版本中。)