Training a multi-input CNN with images

Date: 2020-07-09 18:40:33

Tags: python api tensorflow keras cnn

I want to create a CNN with multiple inputs that recognizes 6 human activities, structured as follows:

[image]

Each activity has a training folder and a validation folder:

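from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator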

train_batches1 = ImageDataGenerator().flow_from_directory(directory=train_dir1, target_size=(224,224), classes=['applauding', 'blowing_bubbles'], batch_size=10, class_mode='categorical')

valid_batches1 = ImageDataGenerator().flow_from_directory(directory=validation_dir1, target_size=(224,224), classes=['applauding', 'blowing_bubbles'], batch_size=10, class_mode='categorical')

train_batches2 = ImageDataGenerator().flow_from_directory(directory=train_dir2, target_size=(224,224), classes=['brushing_teeth', 'cleaning_the_floor'], batch_size=10, class_mode='categorical')

valid_batches2 = ImageDataGenerator().flow_from_directory(directory=validation_dir2, target_size=(224,224), classes=['brushing_teeth', 'cleaning_the_floor'], batch_size=10, class_mode='categorical')

train_batches3 = ImageDataGenerator().flow_from_directory(directory=train_dir3, target_size=(224,224), classes=['climbing','writing'], batch_size=10, class_mode='categorical')

valid_batches3 = ImageDataGenerator().flow_from_directory(directory=validation_dir3, target_size=(224,224), classes=['climbing','writing'], batch_size=10, class_mode='categorical')

Inputs

Datos_red_1 = Input(shape=(224,224,3), name='Datos_red_1')

Datos_red_2 = Input(shape=(224,224,3), name='Datos_red_2')

Datos_red_3 = Input(shape=(224,224,3), name='Datos_red_3')

Network 1

conv1 = Conv2D(32, kernel_size=5, activation='relu')(Datos_red_1)

pool1 = MaxPooling2D(pool_size=(2,2))(conv1)

flat1 = Flatten()(pool1)

Network 2

conv2 = Conv2D(32, kernel_size=5, activation='relu')(Datos_red_2)

pool2 = MaxPooling2D(pool_size=(2,2))(conv2)

flat2 = Flatten()(pool2)

Network 3

conv3 = Conv2D(32, kernel_size=5, activation='relu')(Datos_red_3)

pool3 = MaxPooling2D(pool_size=(2,2))(conv3)

flat3 = Flatten()(pool3)

Concatenation

con = concatenate([flat1,flat2,flat3])

Salida = Dense(6, activation='softmax', name='Salida')(con)

Model

model = Model(inputs=[Datos_red_1, Datos_red_2, Datos_red_3], outputs=[Salida])

[image]
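As a sanity check on the architecture above, the layer sizes can be worked out by hand. This is only a sketch: it assumes the Keras defaults (unpadded 'valid' convolution with stride 1, and a pooling stride equal to the pool size), and the helper `conv_out` is a made-up name for illustration.

```python
def conv_out(size, kernel, stride=1):
    """Output width/height of an unpadded ('valid') convolution or pooling."""
    return (size - kernel) // stride + 1

side = conv_out(224, kernel=5)              # Conv2D(32, kernel_size=5): 224 -> 220
side = conv_out(side, kernel=2, stride=2)   # MaxPooling2D((2, 2)): 220 -> 110
flat_per_branch = side * side * 32          # Flatten: 110 * 110 * 32 features
concat_features = 3 * flat_per_branch       # three branches concatenated

conv_params = 3 * (5 * 5 * 3 * 32 + 32)     # weights + biases of the three Conv2D layers
dense_params = concat_features * 6 + 6      # weights + biases of Dense(6)

print(flat_per_branch, concat_features, conv_params + dense_params)
# prints: 387200 1161600 6976902
```

Under those assumptions almost all of the roughly 7 million parameters sit in the final Dense layer, because each branch is flattened after a single pooling step.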

Compile

model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
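One thing that may be worth checking here: the generators above use class_mode='categorical', which yields one-hot labels, while SparseCategoricalCrossentropy expects integer class indices. The sketch below is plain Python, not Keras code; it only illustrates the two label formats and that they describe the same loss. The probabilities are made up for illustration.

```python
import math

def categorical_crossentropy(one_hot, probs):
    """Loss for a one-hot label, e.g. [0, 1, 0, 0, 0, 0]."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))

def sparse_categorical_crossentropy(class_index, probs):
    """Loss for an integer label, e.g. 1."""
    return -math.log(probs[class_index])

# Made-up softmax output for 6 classes (illustrative values only).
probs = [0.05, 0.70, 0.05, 0.10, 0.05, 0.05]

one_hot = [0, 1, 0, 0, 0, 0]        # format produced by class_mode='categorical'
class_index = one_hot.index(1)      # format expected by the sparse loss

loss_a = categorical_crossentropy(one_hot, probs)
loss_b = sparse_categorical_crossentropy(class_index, probs)
print(class_index, math.isclose(loss_a, loss_b))
```

In Keras terms, the two consistent pairings would be class_mode='categorical' with keras.losses.CategoricalCrossentropy() and keras.metrics.CategoricalAccuracy(), or class_mode='sparse' with the sparse loss and metric used above. Whether this mismatch is the cause of the reported error cannot be confirmed from the post.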

Train

inputs_train = [train_batches1, train_batches2, train_batches3]

history = model.fit(
    inputs_train,
    epochs=5,
    validation_data=(valid_batches1, valid_batches2, valid_batches3)
)
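As far as I know, model.fit does not accept a list of generators; it expects a single data source that yields (inputs, targets) pairs. A possible workaround is to wrap the three generators in one generator that yields a dict keyed by the Input layer names together with one target. The sketch below is an assumption-laden illustration: `combine_batches` is a hypothetical helper, and taking the labels from the first generator is an arbitrary choice made only to keep the example short. In the setup above each generator covers only two of the six classes, so how to build a 6-class target is a design decision this sketch does not settle.

```python
def combine_batches(gen1, gen2, gen3):
    """Yield (inputs, target) from three iterators that each yield (x, y).

    inputs is a dict keyed by the Input layer names used in the model.
    Assumption: the labels of the first generator are used as the target.
    """
    for (x1, y1), (x2, _), (x3, _) in zip(gen1, gen2, gen3):
        inputs = {"Datos_red_1": x1, "Datos_red_2": x2, "Datos_red_3": x3}
        yield inputs, y1


# Stand-in data so the sketch runs without TensorFlow: each "batch" is (x, y).
fake1 = [("a1", "ya1"), ("a2", "ya2")]
fake2 = [("b1", "yb1"), ("b2", "yb2")]
fake3 = [("c1", "yc1"), ("c2", "yc2")]

combined = list(combine_batches(iter(fake1), iter(fake2), iter(fake3)))
print(combined[0])
```

With the real generators this might be called as model.fit(combine_batches(train_batches1, train_batches2, train_batches3), steps_per_epoch=len(train_batches1), epochs=5). The steps_per_epoch argument matters because the directory iterators loop indefinitely. The three generators would also need to produce batches of equal size at every step, which is not guaranteed if the folders hold different numbers of images.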

However, the following error occurs during training:

[image: error message, not reproduced here]

Please help me.

0 answers:

No answers yet