Keras model: Incompatible shapes [custom layer]

Time: 2021-05-03 04:17:26

Tags: tensorflow machine-learning keras

I don't understand why my model can print a model summary, yet when I call model.predict to check that it actually works, it raises an Incompatible shapes error (shown below).

My model

def PVT():

  # Inputs
  input = layers.Input(shape=input_shape)
  augment = data_augmentation(input)

  # Stage 1
  patches_1 = Patch(patch_size_1)(augment)
  patches_1 = PatchEncoder(num_patches=(image_size // patch_size_1) ** 2, projection_dim=projection_dim)(patches_1)
  for _ in range(transformer_layers):
      x1 = layers.LayerNormalization(epsilon=1e-6)(patches_1)
      attention_output = layers.MultiHeadAttention(
          num_heads = num_heads, key_dim = projection_dim, dropout = 0.1 
      )(x1,x1)
      x2 = attention_output + patches_1
      x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
      x3 = mlp(x3, transformer_units, 0.2)
      encoded_patches = x3 + x2

  representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
  input_2 = layers.Reshape([image_size // patch_size_1, image_size // patch_size_1, 64])(representation)

  # Stage 2
  patches_2 = Patch(patch_size_2)(input_2)
  patches_2 = PatchEncoder(num_patches=(image_size // patch_size_2) ** 2, projection_dim=projection_dim)(patches_2)


  for _ in range(transformer_layers):
      x1 = layers.LayerNormalization(epsilon=1e-6)(patches_2)
      attention_output = layers.MultiHeadAttention(
          num_heads = num_heads, key_dim = projection_dim, dropout = 0.1 
      )(x1,x1)
      x2 = attention_output + patches_2
      x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
      x3 = mlp(x3, transformer_units, 0.2)
      encoded_patches = x3 + x2

  representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
  input_3 = layers.Reshape([image_size // patch_size_2, image_size // patch_size_2, 64])(representation)

  # Stage 3
  patches_3 = Patch(patch_size_3)(input_3)
  patches_3 = PatchEncoder(num_patches=(image_size // patch_size_3) ** 2, projection_dim=projection_dim)(patches_3)
  for _ in range(transformer_layers):
      x1 = layers.LayerNormalization(epsilon=1e-6)(patches_3)
      attention_output = layers.MultiHeadAttention(
          num_heads = num_heads, key_dim = projection_dim, dropout = 0.1 
      )(x1,x1)
      x2 = attention_output + patches_3
      x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
      x3 = mlp(x3, transformer_units, 0.2)
      encoded_patches = x3 + x2

  representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
  input_4 = layers.Reshape([image_size // patch_size_3, image_size // patch_size_3, 64])(representation)

  # Stage 4
  patches_4 = Patch(patch_size_4)(input_4)
  patches_4 = PatchEncoder(num_patches=(image_size // patch_size_4) ** 2, projection_dim=projection_dim)(patches_4)
  for _ in range(transformer_layers):
      x1 = layers.LayerNormalization(epsilon=1e-6)(patches_4)
      attention_output = layers.MultiHeadAttention(
          num_heads = num_heads, key_dim = projection_dim, dropout = 0.1 
      )(x1,x1)
      x2 = attention_output + patches_4
      x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
      x3 = mlp(x3, transformer_units, 0.2)
      encoded_patches = x3 + x2

  representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
  input_5 = layers.Reshape([image_size // patch_size_4, image_size // patch_size_4, 64])(representation)

  representation = layers.Flatten()(input_5)
  representation = layers.Dropout(0.5)(representation)
  # Classify outputs.
  logits = layers.Dense(num_classes)(representation)
  # Create the Keras model.
  model = keras.Model(inputs=input, outputs=logits)

  return model

The problem

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-20-3e4aa14aa594> in <module>()
----> 1 model.predict(xtrain)

5 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Incompatible shapes: [32,4,64] vs. [81,64]
     [[node model_2/patch_encoder_11/add (defined at <ipython-input-6-93bf719690a9>:12) ]] [Op:__inference_predict_function_48830]

Function call stack:
predict_function
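
The two shapes in the error message can be reproduced with plain arithmetic. The sketch below assumes that Patch wraps tf.image.extract_patches with strides equal to the patch size and padding="VALID" (as in the Keras ViT example; the layer's code is not shown in this post), and that image_size = 72, patch_size_1 = 4 and patch_size_2 = 8. Those values are inferred from the model summary, not stated by the author. The helper name valid_patch_count is hypothetical.

```python
def valid_patch_count(side, patch_size):
    """Patches per image for non-overlapping VALID extraction on a square input."""
    if side < patch_size:
        return 0
    per_side = (side - patch_size) // patch_size + 1
    return per_side ** 2


image_size = 72    # assumed: output size of data_augmentation in the summary
patch_size_1 = 4   # assumed: 4 * 4 * 3 = 48 matches patch_10's last dimension
patch_size_2 = 8   # assumed: 8 * 8 * 64 = 4096 matches patch_11's last dimension

# Stage 1 output is reshaped to an 18 x 18 feature map before stage 2.
stage_2_side = image_size // patch_size_1

# What PatchEncoder is told to expect (rows in its position embedding).
declared = (image_size // patch_size_2) ** 2
# What the Patch layer can actually cut out of the 18 x 18 feature map.
actual = valid_patch_count(stage_2_side, patch_size_2)

print(declared, actual)  # 81 4
```

Under these assumptions, 4 and 81 are the same numbers that appear in [32,4,64] vs. [81,64] (32 is the default batch size of model.predict). The patch count for stage 2 is derived from image_size, while the layer actually receives the 18 x 18 feature map. For comparison, a patch size of 2 applied to that 18 x 18 map would give 9 x 9 = 81 patches, though whether that is the intended design is for the author to decide.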

Model summary (for reference)

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
data_augmentation (Sequential)  (None, 72, 72, 3)    7           input_4[0][0]                    
__________________________________________________________________________________________________
patch_10 (Patch)                (None, None, 48)     0           data_augmentation[3][0]          
__________________________________________________________________________________________________
patch_encoder_10 (PatchEncoder) (None, 324, 64)      23872       patch_10[0][0]                   
__________________________________________________________________________________________________
layer_normalization_124 (LayerN (None, 324, 64)      128         patch_encoder_10[0][0]           
__________________________________________________________________________________________________
multi_head_attention_60 (MultiH (None, 324, 64)      82944       layer_normalization_124[0][0]    
                                                                 layer_normalization_124[0][0]    
__________________________________________________________________________________________________
tf.__operators__.add_119 (TFOpL (None, 324, 64)      0           multi_head_attention_60[0][0]    
                                                                 patch_encoder_10[0][0]           
__________________________________________________________________________________________________
layer_normalization_125 (LayerN (None, 324, 64)      128         tf.__operators__.add_119[0][0]   
__________________________________________________________________________________________________
dense_131 (Dense)               (None, 324, 128)     8320        layer_normalization_125[0][0]    
__________________________________________________________________________________________________
dropout_120 (Dropout)           (None, 324, 128)     0           dense_131[0][0]                  
__________________________________________________________________________________________________
dense_132 (Dense)               (None, 324, 64)      8256        dropout_120[0][0]                
__________________________________________________________________________________________________
dropout_121 (Dropout)           (None, 324, 64)      0           dense_132[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.add_120 (TFOpL (None, 324, 64)      0           dropout_121[0][0]                
                                                                 tf.__operators__.add_119[0][0]   
__________________________________________________________________________________________________
layer_normalization_126 (LayerN (None, 324, 64)      128         tf.__operators__.add_120[0][0]   
__________________________________________________________________________________________________
reshape_5 (Reshape)             (None, 18, 18, 64)   0           layer_normalization_126[0][0]    
__________________________________________________________________________________________________
patch_11 (Patch)                (None, None, 4096)   0           reshape_5[0][0]                  
__________________________________________________________________________________________________
patch_encoder_11 (PatchEncoder) (None, 81, 64)       267392      patch_11[0][0]                   
__________________________________________________________________________________________________
layer_normalization_145 (LayerN (None, 81, 64)       128         patch_encoder_11[0][0]           
__________________________________________________________________________________________________
multi_head_attention_70 (MultiH (None, 81, 64)       82944       layer_normalization_145[0][0]    
                                                                 layer_normalization_145[0][0]    
__________________________________________________________________________________________________
tf.__operators__.add_139 (TFOpL (None, 81, 64)       0           multi_head_attention_70[0][0]    
                                                                 patch_encoder_11[0][0]           
__________________________________________________________________________________________________
layer_normalization_146 (LayerN (None, 81, 64)       128         tf.__operators__.add_139[0][0]   
__________________________________________________________________________________________________
dense_152 (Dense)               (None, 81, 128)      8320        layer_normalization_146[0][0]    
__________________________________________________________________________________________________
dropout_140 (Dropout)           (None, 81, 128)      0           dense_152[0][0]                  
__________________________________________________________________________________________________
dense_153 (Dense)               (None, 81, 64)       8256        dropout_140[0][0]                
__________________________________________________________________________________________________
dropout_141 (Dropout)           (None, 81, 64)       0           dense_153[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.add_140 (TFOpL (None, 81, 64)       0           dropout_141[0][0]                
                                                                 tf.__operators__.add_139[0][0]   
__________________________________________________________________________________________________
layer_normalization_147 (LayerN (None, 81, 64)       128         tf.__operators__.add_140[0][0]   
__________________________________________________________________________________________________
reshape_6 (Reshape)             (None, 9, 9, 64)     0           layer_normalization_147[0][0]    
__________________________________________________________________________________________________
patch_12 (Patch)                (None, None, 16384)  0           reshape_6[0][0]                  
__________________________________________________________________________________________________
patch_encoder_12 (PatchEncoder) (None, 16, 64)       1049664     patch_12[0][0]                   
__________________________________________________________________________________________________
layer_normalization_166 (LayerN (None, 16, 64)       128         patch_encoder_12[0][0]           
__________________________________________________________________________________________________
multi_head_attention_80 (MultiH (None, 16, 64)       82944       layer_normalization_166[0][0]    
                                                                 layer_normalization_166[0][0]    
__________________________________________________________________________________________________
tf.__operators__.add_159 (TFOpL (None, 16, 64)       0           multi_head_attention_80[0][0]    
                                                                 patch_encoder_12[0][0]           
__________________________________________________________________________________________________
layer_normalization_167 (LayerN (None, 16, 64)       128         tf.__operators__.add_159[0][0]   
__________________________________________________________________________________________________
dense_173 (Dense)               (None, 16, 128)      8320        layer_normalization_167[0][0]    
__________________________________________________________________________________________________
dropout_160 (Dropout)           (None, 16, 128)      0           dense_173[0][0]                  
__________________________________________________________________________________________________
dense_174 (Dense)               (None, 16, 64)       8256        dropout_160[0][0]                
__________________________________________________________________________________________________
dropout_161 (Dropout)           (None, 16, 64)       0           dense_174[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.add_160 (TFOpL (None, 16, 64)       0           dropout_161[0][0]                
                                                                 tf.__operators__.add_159[0][0]   
__________________________________________________________________________________________________
layer_normalization_168 (LayerN (None, 16, 64)       128         tf.__operators__.add_160[0][0]   
__________________________________________________________________________________________________
reshape_7 (Reshape)             (None, 4, 4, 64)     0           layer_normalization_168[0][0]    
__________________________________________________________________________________________________
patch_13 (Patch)                (None, None, 65536)  0           reshape_7[0][0]                  
__________________________________________________________________________________________________
patch_encoder_13 (PatchEncoder) (None, 4, 64)        4194624     patch_13[0][0]                   
__________________________________________________________________________________________________
layer_normalization_187 (LayerN (None, 4, 64)        128         patch_encoder_13[0][0]           
__________________________________________________________________________________________________
multi_head_attention_90 (MultiH (None, 4, 64)        82944       layer_normalization_187[0][0]    
                                                                 layer_normalization_187[0][0]    
__________________________________________________________________________________________________
tf.__operators__.add_179 (TFOpL (None, 4, 64)        0           multi_head_attention_90[0][0]    
                                                                 patch_encoder_13[0][0]           
__________________________________________________________________________________________________
layer_normalization_188 (LayerN (None, 4, 64)        128         tf.__operators__.add_179[0][0]   
__________________________________________________________________________________________________
dense_194 (Dense)               (None, 4, 128)       8320        layer_normalization_188[0][0]    
__________________________________________________________________________________________________
dropout_180 (Dropout)           (None, 4, 128)       0           dense_194[0][0]                  
__________________________________________________________________________________________________
dense_195 (Dense)               (None, 4, 64)        8256        dropout_180[0][0]                
__________________________________________________________________________________________________
dropout_181 (Dropout)           (None, 4, 64)        0           dense_195[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.add_180 (TFOpL (None, 4, 64)        0           dropout_181[0][0]                
                                                                 tf.__operators__.add_179[0][0]   
__________________________________________________________________________________________________
layer_normalization_189 (LayerN (None, 4, 64)        128         tf.__operators__.add_180[0][0]   
__________________________________________________________________________________________________
reshape_8 (Reshape)             (None, 2, 2, 64)     0           layer_normalization_189[0][0]    
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 256)          0           reshape_8[0][0]                  
__________________________________________________________________________________________________
dropout_182 (Dropout)           (None, 256)          0           flatten_2[0][0]                  
__________________________________________________________________________________________________
dense_196 (Dense)               (None, 100)          25700       dropout_182[0][0]                
==================================================================================================
Total params: 5,960,875
Trainable params: 5,960,868
Non-trainable params: 7
___________________________
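
One reading of the summary that may explain why it builds without complaint: every Patch layer reports its patch count as None, so a mismatch in that dimension may not be detectable until real data flows through the model. The NumPy sketch below shows that the two concrete shapes from the error cannot be added, while a matching patch count works. NumPy is used here only as a stand-in for TensorFlow's broadcasting rules, and the array names are hypothetical.

```python
import numpy as np

# Shapes taken from the error message: batch of 32, 4 patches, 64 features ...
projected_patches = np.zeros((32, 4, 64))
# ... versus a position embedding table with 81 rows of 64 features.
position_embedding = np.zeros((81, 64))

try:
    projected_patches + position_embedding
    broadcast_ok = True
except ValueError as err:
    # Broadcasting aligns trailing axes: 64 matches 64, but 4 vs. 81 does not.
    broadcast_ok = False
    print("cannot broadcast:", err)

# With 81 patches per image the same addition is well defined.
matching = np.zeros((32, 81, 64)) + position_embedding
print(matching.shape)  # (32, 81, 64)
```

Printing the runtime shape of a Patch layer's output for a single batch is one way to see the real patch count that the static summary hides.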

0 answers:

No answers