我想检查两个张量是否具有相同的形状。
假设我有一些像这样的张量:
a = tf.placeholder(tf.float32, shape=[None, 3])
b = tf.placeholder(tf.float32, shape=[None, 3])
我添加了assert a.shape == b.shape
。但是,这可能是由于None而失败。确实a.shape
= (?, 1)
,b.shape
也是(?, 1)
。他们看起来和我一样。
如果没有None,则可以正常工作。
a = tf.placeholder(tf.float32, shape=[1, 3])
b = tf.placeholder(tf.float32, shape=[1, 3])
assert a.shape == b.shape # True
如何在形状检查中忽略无?
总结:
1: a = tf.placeholder(tf.float32, shape=[1, 3])
2: b = tf.placeholder(tf.float32, shape=[1, 3])
3: assert a.shape == b.shape # True
4:
5: a = tf.placeholder(tf.float32, shape=[None, 3])
6: b = tf.placeholder(tf.float32, shape=[None, 3])
7: assert a.shape == b.shape # False
我希望第7行的断言为True。
答案 0 :(得分:4)
您可以使用a.shape.as_list() == b.shape.as_list()
比较两个tf.TensorShape
对象的“相等”。但是,在执行此操作时应该小心,因为如果两个形状在同一位置包含None
,那么具有这些形状的张量不能保证在该维度中具有相同的大小。
(能够在batch_size
中表示tf.TensorShape
之类的“符号”维度会很有用,这会使等式测试更有用。我们正在考虑允许API扩展这在TensorFlow的未来版本中。)