tensorflow conv2d 不兼容的形状

时间:2021-04-17 07:17:22

标签: keras tensorflow2.0

class ResnetIdentityBlock(tf.keras.Model):
  def __init__(self, kernel_size, filters):
    super(ResnetIdentityBlock, self).__init__(name='')
    filters1, filters2, filters3 = filters
    print("filters1=", filters1)

    self.conv2a = tf.keras.layers.Conv2D(filters1, kernel_size=kernel_size, use_bias=False)
    self.bn2a = tf.keras.layers.BatchNormalization()

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size=kernel_size, padding='same',use_bias=False)
    self.bn2b = tf.keras.layers.BatchNormalization()

    print("filters3=", filters3)
    self.conv2c = tf.keras.layers.Conv2D(filters3, kernel_size=kernel_size,use_bias=False )
    self.bn2c = tf.keras.layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    print("input_tensor shape", input_tensor.shape)
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    x += input_tensor
    return tf.nn.relu(x)


block = ResnetIdentityBlock(1, [1, 2, 3])

_ = block(tf.zeros([1, 5, 5, 4]))

输出为:

input_tensor shape (1, 5, 5, 4)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-200-df6af2fb50f9> in <module>
----> 1 _ = block(tf.zeros([1, 5, 5, 4]))

/usr/local/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
   1010         with autocast_variable.enable_auto_cast_variables(
   1011             self._compute_dtype_object):
-> 1012           outputs = call_fn(inputs, *args, **kwargs)
   1013 
   1014         if self._activity_regularizer:

<ipython-input-199-94943f82e58b> in call(self, input_tensor, training)
     28     x = self.bn2c(x, training=training)
     29 
---> 30     x += input_tensor
     31     return tf.nn.relu(x)
     32 

/usr/local/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
   1162     with ops.name_scope(None, op_name, [x, y]) as name:
   1163       try:
-> 1164         return func(x, y, name=name)
   1165       except (TypeError, ValueError) as e:
   1166         # Even if dispatching the op failed, the RHS may be a tensor aware

/usr/local/anaconda3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/usr/local/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/math_ops.py in _add_dispatch(x, y, name)
   1484     return gen_math_ops.add(x, y, name=name)
   1485   else:
-> 1486     return gen_math_ops.add_v2(x, y, name=name)
   1487 
   1488 

/usr/local/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/gen_math_ops.py in add_v2(x, y, name)
    470       return _result
    471     except _core._NotOkStatusException as e:
--> 472       _ops.raise_from_not_ok_status(e, name)
    473     except _core._FallbackException:
    474       pass

/usr/local/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6860   message = e.message + (" name: " + name if name is not None else "")
   6861   # pylint: disable=protected-access
-> 6862   six.raise_from(core._status_to_exception(e.code, message), None)
   6863   # pylint: enable=protected-access
   6864 

/usr/local/anaconda3/lib/python3.8/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Incompatible shapes: [1,5,5,3] vs. [1,5,5,4] [Op:AddV2]

为什么会出现这样的错误:

Incompatible shapes: [1,5,5,3] vs. [1,5,5,4] 

我正在运行 tensorflow 2.0

0 个答案:

没有答案