How to pass validation_data to the Keras Model.fit method

Posted: 2019-08-07 05:32:24

Tags: python keras lstm

I get the following exception when I call model.fit:

ValueError                                Traceback (most recent call last)
<ipython-input-30-a7c25bd01b61> in <module>()
      8            cv_x.school_state.values,cv_x.teacher_prefix.values,cv_x.project_grade_category.values,
      9            cv_x.clean_categories.values,cv_x.clean_subcategories.values],
---> 10            cv_y))

6 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    100                 'Expected to see ' + str(len(names)) + ' array(s), '
    101                 'but instead got the following list of ' +
--> 102                 str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
    103         elif len(names) > 1:
    104             raise ValueError(

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 8 array(s), but instead got the following list of 1 arrays: [array([[  48,   24,    2, ...,    0,    0,    0],
       [  40,  787,  310, ...,    0,    0,    0],
       [   4,    5, 1474, ...,    0,    0,    0],
       ...,
       [1725, 2095,  716, ...,    0, ...

model = Model(inputs=input_a, outputs=[output])
model.compile(optimizer='rmsprop', loss='binary_crossentropy')
model.fit([padded_docs_train, train_x.price, train_x.teacher_number_of_previously_posted_projects,
           train_x.school_state, train_x.teacher_prefix, train_x.project_grade_category,
           train_x.clean_categories, train_x.clean_subcategories],
          [train_y],
          epochs=12, batch_size=1000,
          callbacks=[print_auc, tensorboard],
          validation_data=([padded_docs_cv, cv_x.price, cv_x.teacher_number_of_previously_posted_projects,
                            cv_x.school_state, cv_x.teacher_prefix, cv_x.project_grade_category,
                            cv_x.clean_categories, cv_x.clean_subcategories],
                           [cv_y]))

If the same format works for the training inputs, why does validation_data raise this exception? And how do I pass the arguments in validation_data correctly?
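Keras documents validation_data as a tuple (x_val, y_val) in which x_val has the same structure as the x passed to fit; for this model that means a list of 8 arrays in the order of input_a. The sketch below is one way to guarantee the training and validation lists are built identically. It is not a confirmed fix for this particular error: the helper assemble_inputs and the constant TABULAR_COLUMNS are hypothetical names, and the column order is copied from the training call above.

```python
import numpy as np

# Hypothetical constant: tabular columns in the same order as the Input
# layers in input_a (the text input comes first and is handled separately).
TABULAR_COLUMNS = [
    "price",
    "teacher_number_of_previously_posted_projects",
    "school_state",
    "teacher_prefix",
    "project_grade_category",
    "clean_categories",
    "clean_subcategories",
]


def assemble_inputs(padded_docs, frame):
    """Return the 8 model inputs as a list of NumPy arrays.

    frame is anything indexable by column name, e.g. a pandas DataFrame
    such as train_x or cv_x, or a plain dict of arrays.
    """
    arrays = [np.asarray(padded_docs)]  # shape (n, 500) for the text input
    for col in TABULAR_COLUMNS:
        # Reshape each 1-D column to (n, 1) to match the (?, 1) Input shapes.
        arrays.append(np.asarray(frame[col]).reshape(-1, 1))
    row_counts = {a.shape[0] for a in arrays}
    if len(row_counts) != 1:
        raise ValueError(f"inputs disagree on row count: {sorted(row_counts)}")
    return arrays


# Hypothetical usage with the question's variables:
# model.fit(assemble_inputs(padded_docs_train, train_x), train_y,
#           epochs=12, batch_size=1000, callbacks=[print_auc, tensorboard],
#           validation_data=(assemble_inputs(padded_docs_cv, cv_x), cv_y))
```

Routing both calls through one function means a missing, extra, or reordered validation array cannot slip in unnoticed, and a length mismatch between columns is reported before Keras sees the data.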

The input layers in input_a are:

[<tf.Tensor 'text_input:0' shape=(?, 500) dtype=float32>,
 <tf.Tensor 'price_input:0' shape=(?, 1) dtype=int64>,
 <tf.Tensor 'no_projects:0' shape=(?, 1) dtype=int64>,
 <tf.Tensor 'input_school_state:0' shape=(?, 1) dtype=int32>,
 <tf.Tensor 'input_teacher_prefix:0' shape=(?, 1) dtype=int32>,
 <tf.Tensor 'input_project_grade_category:0' shape=(?, 1) dtype=int32>,
 <tf.Tensor 'input_clean_categories:0' shape=(?, 1) dtype=int32>,
 <tf.Tensor 'input_clean_subcategories:0' shape=(?, 1) dtype=int32>]
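Because every Input layer here has a name, Keras also accepts x (and the x part of validation_data) as a dict that maps input names to arrays, which removes any dependence on list order. A minimal sketch, assuming the layer names are the tensor names listed above without the ":0" suffix (compare them with model.input_names before relying on this); inputs_by_name and INPUT_TO_COLUMN are hypothetical names, and the pairing of inputs to columns is inferred from the order of the training call.

```python
import numpy as np

# Assumption: Input layer names = tensor names above minus ":0".
# Pairing of each input with a DataFrame column is inferred from the fit call.
INPUT_TO_COLUMN = {
    "price_input": "price",
    "no_projects": "teacher_number_of_previously_posted_projects",
    "input_school_state": "school_state",
    "input_teacher_prefix": "teacher_prefix",
    "input_project_grade_category": "project_grade_category",
    "input_clean_categories": "clean_categories",
    "input_clean_subcategories": "clean_subcategories",
}


def inputs_by_name(padded_docs, frame):
    """Map every Input layer name to its array so argument order no longer matters."""
    feed = {"text_input": np.asarray(padded_docs)}
    for input_name, col in INPUT_TO_COLUMN.items():
        # (n,) column -> (n, 1) to match the (?, 1) Input shapes.
        feed[input_name] = np.asarray(frame[col]).reshape(-1, 1)
    return feed


# Hypothetical usage with the question's variables:
# model.fit(inputs_by_name(padded_docs_train, train_x), train_y,
#           epochs=12, batch_size=1000, callbacks=[print_auc, tensorboard],
#           validation_data=(inputs_by_name(padded_docs_cv, cv_x), cv_y))
```

With a dict, Keras matches arrays to inputs by name, so a validation feed built this way cannot be silently misaligned with the model's inputs; a missing key is reported as an error instead.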
