Keras:合并两个具有不同输入的模型,并使用fit_generator来训练合并模型

时间:2017-11-22 07:49:02

标签: keras

我想将两个模型与不同模型合并,并使用fit_generator来训练合并模型。而发电机就是我自己。 这是发电机之一。

def image_generator(self, batch_size, train_test, data_type, concat=False):
    train, test = self.split_train_test()
    data = train if train_test == 'train' else test

    print("Creating %s generator with %d samples." % (train_test, len(data)))
    print ("image_generator")

    while 1:
        X, y = [], []

        # Generate batch_size samples.
        for _ in range(batch_size):
            # Reset to be safe.
            sequence = None

            # Get a random sample.
            sample = random.choice(data)

            # Check to see if we've already saved this sequence.
            if data_type is "images":
                # Get and resample frames.
                frames = self.get_frames_for_sample(sample)
                frames = self.rescale_list(frames, self.seq_length)

                # Build the image sequence
                sequence = self.build_image_sequence(frames)
            else:
                # Get the sequence from disk.
                sequence = self.get_image_sequence(data_type, sample, train_test)
            if sequence is None:
                print("Can't find sequence. Did you generate them?")
                sys.exit()  # TODO this should raise

            if concat:
                # We want to pass the sequence back as a single array. This
                # is used to pass into an MLP rather than an RNN.
                sequence = np.concatenate(sequence).ravel()

            X.append(sequence)
            y.append(self.get_class_one_hot(sample[1]))

        yield np.array(X), np.array(y)

这是get_image_sequences:

def get_image_sequence(self, data_type, sample, train_test):
    """get the images shaped with array."""
    # train,ApplyEyeMakeup,v_ApplyEyeMakeup_g10_c02,99
    num = random.randint(1, int(sample[3]))
    path = glob.glob('./data/' + train_test + '/' + sample[1] + '/' + sample[2] + '-' + '*' + num + '.jpg')
    if os.path.isfile(path):
        img = Image.open(path)
        if img.size != target_size:
            img = img.resize(target_size)
        img = img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img /= 255
        return img
    else:
        print ("path is error" + path)
        return None

现在,合并并适应它:

modeltmp = merge([model1.output, model2.output], mode='concat', concat_axis=1)
modeltmp = BatchNormalization()(modeltmp)
modeltmp = Dense(1024, activation='relu')(modeltmp)
modeltmp = Dense(len(classes), activation='softmax')(modeltmp)

model = Model(input=[model1.input, model2.input], outputs=modeltmp)

     # model1 --- generator
train_gen_1 = data.image_generator(batch_size, 'train', cnn_lstm_datatype, concat)
test_gen_1 = data.image_generator(batch_size, 'test', cnn_lstm_datatype, concat)

     # model2 ---- generator
train_gen_2 = data.frame_generator(batch_size=batch_size, train_test='train', data_type=cnn_lstm_datatype, concat=concat)
test_gen_2 = data.frame_generator(batch_size=batch_size, train_test='test', data_type=cnn_lstm_datatype, concat=concat)
    model.fit_generator([train_gen_1, train_gen_2],
                    verbose=1,
                    steps_per_epoch=batch_size,
                    validation_steps=10,
                    epochs=10000,
                    callbacks=[checkpointer, tb, early_stopper, csv_logger],
                    validation_data=[test_gen_1, test_gen_2]
                    )

然而,我收到错误: TypeError:检查模型输入时出错:数据应该是Numpy数组,或Numpy数组的列表/字典。找到:生成器对象image_generator在0x12205df00 ...
我该如何解决?谢谢!

0 个答案:

没有答案
相关问题