Keras input shape dimensions for CONV and LSTM

Time: 2019-07-16 22:13:06

Tags: python tensorflow keras conv-neural-network lstm

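I am using a CNN-LSTM for image classification, but I get a shape dimension error. The error concerns the input layer: it expects 5-dimensional input but receives 4-dimensional input. For each class X, I have sequences of images, each image being 224x224x3. I want to apply an LSTM on top of the CNN features to classify each image of each class group.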
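To make the mismatch concrete, here is a small sketch that only compares shapes. It assumes the batch size of 64 reported in the error message and the question's own `img_seq_length`, `img_size` and `channels`; no Keras is needed to see the difference.

```python
# Values from the question's configuration
batch_size = 64          # batch size reported in the error message
img_seq_length = 10
img_size = 224
channels = 3

# flow_from_directory yields one image per sample: a 4-D batch
frame_batch_shape = (batch_size, img_size, img_size, channels)

# Input(shape=(img_seq_length, img_size, img_size, channels)) gets a batch axis
# prepended, so the model expects one *sequence* of images per sample: a 5-D batch
sequence_batch_shape = (batch_size, img_seq_length, img_size, img_size, channels)

# Each 5-D batch needs img_seq_length times as many images as a 4-D batch
frames_per_batch = batch_size * img_seq_length  # 640 images per training step

print(len(frame_batch_shape), len(sequence_batch_shape))  # 4 5
```

In other words, the model definition is not the problem by itself; the data pipeline has to deliver an extra time axis.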

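from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.layers import Input, TimeDistributed, LSTM, Dropout, Dense

def generator(train_data_dir, img_size, batch_size, val_split):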

  train_datagen = ImageDataGenerator(validation_split=val_split,rescale=1./255) # set validation split

  train_generator = train_datagen.flow_from_directory(
      train_data_dir,
      target_size=(img_size, img_size),
      batch_size=batch_size,
      color_mode='rgb',
      shuffle=True,
      subset='training') # set as training data

  validation_generator = train_datagen.flow_from_directory(
      train_data_dir, # same directory as training data
      target_size=(img_size, img_size),
      batch_size=batch_size,  # keep validation batches the same size as training batches
      color_mode='rgb',
      shuffle=False,
      subset='validation') # set as validation data

  return train_generator, validation_generator


channels = 3
classes = 101
img_size = 224 
img_seq_length = 10


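# weight_file, mae, get_optimizer, lr_rate, train_gen, val_gen, EPOCHS,
# STEPS_PER_EPOCH, VALIDATION_STEPS and callbacks are defined elsewhere in the original code
base_model = load_model(weight_file, custom_objects={'mae':mae})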
last_conv_layer = base_model.get_layer("global_average_pooling2d_1")
cnn_model = Model(inputs=base_model.input, outputs=last_conv_layer.output)  # per-frame feature extractor
cnn_model.trainable = False

seq_input = Input(shape=(img_seq_length, img_size, img_size, channels), name='seq_input')

encoded_frame = TimeDistributed(cnn_model)(seq_input)
encoded_vid = LSTM(2048)(encoded_frame) 
lstm = Dropout(0.5)(encoded_vid)

# output
predictions = Dense(classes, activation='softmax', name='pred')(lstm)

model = Model(inputs=seq_input, outputs=predictions)

opt = get_optimizer('adam', lr_rate)
model.compile(loss=['categorical_crossentropy'],
    optimizer=opt,
    metrics=['accuracy',"mae"])

model.fit_generator(generator=train_gen,
                     epochs=EPOCHS,
                     steps_per_epoch = STEPS_PER_EPOCH,
                     validation_data=val_gen,
                     validation_steps = VALIDATION_STEPS,
                     verbose=1,
                     callbacks=callbacks)
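The reason `seq_input` must be 5-D is the `TimeDistributed(cnn_model)` wrapper: it runs the same CNN on every time step of a `(batch, time, H, W, C)` tensor. The NumPy sketch below illustrates that idea for a stateless per-frame layer (fold time into the batch axis, apply, unfold). It is not Keras' actual implementation, and `toy_encoder` is a hypothetical stand-in for `cnn_model` that simply average-pools over height and width.

```python
import numpy as np

def time_distributed(fn, x):
    """Apply fn to every time step of x, shaped (batch, time, ...)."""
    batch, time = x.shape[0], x.shape[1]
    folded = x.reshape((batch * time,) + x.shape[2:])  # (batch*time, H, W, C)
    out = fn(folded)                                    # (batch*time, features)
    return out.reshape((batch, time) + out.shape[1:])  # (batch, time, features)

def toy_encoder(images):
    # Hypothetical stand-in for cnn_model: global average pooling,
    # (n, H, W, C) -> (n, C)
    return images.mean(axis=(1, 2))

# Small toy sizes instead of 224x224 to keep memory low
x = np.random.rand(2, 10, 8, 8, 3).astype(np.float32)  # (batch, img_seq_length, H, W, C)
features = time_distributed(toy_encoder, x)
print(features.shape)  # (2, 10, 3)
```

The LSTM then consumes `features` as a sequence of per-frame vectors; a 4-D batch has no time axis to distribute over.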

I get the following error:

  

ValueError: Error when checking input: expected seq_input to have 5 dimensions, but got array with shape (64, 224, 224, 3)
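One way to produce 5-D batches is to group consecutive frames into clips before batching. The sketch below is a hypothetical helper (`sequence_batches` is not a Keras API) and rests on loudly stated assumptions: frames are loaded in clip order, so `frames[i*seq_length:(i+1)*seq_length]` is clip `i`, and every frame of a clip shares the same one-hot label. `flow_from_directory(shuffle=True)` does not preserve such an order, so the frames would have to come from some other loader.

```python
import numpy as np

def sequence_batches(frames, labels, seq_length, batch_size):
    """Yield (x, y) batches shaped (batch, seq_length, H, W, C) and (batch, classes).

    Assumption: frames holds consecutive frames of each clip and all frames
    of a clip share one label. frames: (N, H, W, C); labels: (N, classes).
    """
    num_clips = frames.shape[0] // seq_length
    # Drop trailing frames that do not fill a whole clip, then add the time axis
    clips = frames[:num_clips * seq_length].reshape(
        (num_clips, seq_length) + frames.shape[1:])           # (clips, T, H, W, C)
    clip_labels = labels[:num_clips * seq_length:seq_length]  # first frame's label per clip
    while True:  # fit_generator expects a generator that loops indefinitely
        order = np.random.permutation(num_clips)  # shuffle clips, never frames within a clip
        for start in range(0, num_clips, batch_size):
            idx = order[start:start + batch_size]
            yield clips[idx], clip_labels[idx]

# Toy usage: 2 full clips of 10 frames plus 5 leftover frames
frames = np.random.rand(25, 8, 8, 3).astype(np.float32)
labels = np.eye(3)[[0] * 10 + [2] * 10 + [1] * 5]  # one-hot label per frame
gen = sequence_batches(frames, labels, seq_length=10, batch_size=4)
x, y = next(gen)
print(x.shape, y.shape)  # (2, 10, 8, 8, 3) (2, 3)
```

Shuffling happens at the clip level so that the temporal order inside each sequence, which the LSTM relies on, is preserved.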
