3D CNN 模型抛出负维度错误——维度问题

时间:2021-06-21 12:41:44

标签: tensorflow keras deep-learning conv-neural-network tensorflow2.0

我正在创建一个高度 = 128、宽度 = 128、通道数 = 3 的 3D CNN 模型。3D CNN 的代码-

localStorage

因此,在我尝试构建模型时创建模型函数后,它会引发值错误

def get_model(width=128, height=128, depth=3):
  """
  Build a 3D convolutional neural network
  """
  inputs = tf.keras.Input((width, height, depth, 1))

  x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
  x = layers.MaxPool3D(pool_size=2)(x)
  x = layers.BatchNormalization()(x)

  x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
  x = layers.MaxPool3D(pool_size=2)(x)
  x = layers.BatchNormalization()(x)

  x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
  x = layers.MaxPool3D(pool_size=2)(x)
  x = layers.BatchNormalization()(x)

  x = layers.GlobalAveragePooling3D()(x)
  x = layers.Dense(units=512, activation="relu")(x)

  x = layers.Dropout(0.3)(x)

  outputs = layers.Dense(units=4, activation='softmax')(x)

  model= keras.Model(inputs, outputs, name="3DCNN")
  return model

构建模型的代码:- #构建模型

-ValueError: Negative dimension size caused by subtracting 2 from 1 for '{{node max_pooling3d_5/MaxPool3D}} = MaxPool3D[T=DT_FLOAT, data_format="NDHWC", ksize=[1, 2, 2, 2, 1], padding="VALID", strides=[1, 2, 2, 2, 1]](Placeholder)' with input shapes: [?,126,126,1,64].

完全错误-

model = get_model(width=128, height=128, depth=3)
model.summary()

这个错误是什么意思??我的维度有问题吗??

提前致谢!!!!!!

1 个答案:

答案 0 :(得分:1)

不指定 data_format 参数,Conv3D 层将输入形状视为:

batch_shape + (conv_dim1, conv_dim2, conv_dim3, channels)

您指定为:

batch_shape + (width=128, height=128, depth=3, channels=1)

因此,您有一个数据,其形状为 (128,128,3) 并且有 1 个通道。

由于卷积操作适用于前 3 个维度,即 (128,128,3),在第一次卷积 kernel_size=3 后,第 3 个维度(您指定为 depth=3 的维度)收缩为 1。然后在下一层 (MaxPooling3D) 它不能得到 2 的池化,因为形状不适合。因此,考虑通过更大的数字更改深度维度或更改 kernel_size 参数。例如,输入形状可以是 (128,128,128,1)kernel_size 应该更改为其他类似 (3,3,1) 的内容。

P.S:如果你有一张 RGB 图像,那么通道数是 3,最后一个维度应该设置为 3。在 3D 图像中,还有另一个名为深度(另一个维度)的概念,它不同于通道。所以:

  • 3D 图像 RGB:(width, height, depth, 3)
  • 3D 图像灰度:(width, height, depth, 1)
  • 2D 图像 RGB:(width, height, 3)
  • 2D 图像灰度:(width, height, 1)
相关问题