How to train on a single example in a Keras MLP

Time: 2018-10-19 12:16:06

Tags: python tensorflow neural-network keras

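I'm working on an agent-based model in which every agent contains an MLP-type network. The constraint is that an agent only ever has one example at a time, so it can only train on a single example.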

When I try to train with the Keras fit function, I get a series of errors.

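import numpy as np
import tensorflow as tf

# x_train and y_train are my dataset (loading code not shown)
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(1, activation=tf.nn.tanh))
# validation_split is an argument of fit(), not compile(), so it is dropped here
model.compile(optimizer='SGD',
              loss='mean_squared_error')

# Train on the first example only, as a batch of one
model.fit(x_train[0][np.newaxis,:,:],np.array([y_train[0]]),epochs=3,batch_size=1)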

I just want to train on one example at a time, but I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-81-688643bc66bc> in <module>()
----> 1 model.fit(x_train[0][np.newaxis,:,:],np.array([y_train[0]]),epochs=3,batch_size=1)

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1261         steps_name='steps_per_epoch',
   1262         steps=steps_per_epoch,
-> 1263         validation_split=validation_split)
   1264 
   1265     # Prepare validation data.

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
    905           feed_output_shapes,
    906           check_batch_axis=False,  # Don't enforce the batch size.
--> 907           exception_prefix='target')
    908 
    909       # Generate sample-wise weight values given the `sample_weight` and

~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    189                 'Error when checking ' + exception_prefix + ': expected ' +
    190                 names[i] + ' to have shape ' + str(shape) +
--> 191                 ' but got array with shape ' + str(data_shape))
    192   return data
    193 

ValueError: Error when checking target: expected dense_5 to have shape (10,) but got array with shape (1,)
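
Reading the last line of the traceback: Keras compares the per-sample shape of the target with the output shape of the final layer. Here `dense_5` produces 10 values per sample while the target supplies 1, so the model that actually ran seems to have ended in a 10-unit layer, unlike the `Dense(1)` in the posted code (possibly a model left over from an earlier notebook cell, though that is a guess). Below is a minimal sketch of single-example training with matching shapes. The random 28x28 data, the `x_one`/`y_one` names, and the use of `train_on_batch` are assumptions for illustration and are not taken from the question.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: five 28x28 inputs, one scalar target each.
rng = np.random.default_rng(0)
x_train = rng.random((5, 28, 28)).astype("float32")
y_train = rng.uniform(-1.0, 1.0, size=5).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="tanh"),
    tf.keras.layers.Dense(128, activation="tanh"),
    tf.keras.layers.Dense(1, activation="tanh"),  # 1 unit to match a scalar target
])
model.compile(optimizer="sgd", loss="mean_squared_error")

# One example as a batch of size 1: x is (1, 28, 28), y is (1, 1).
x_one = x_train[0][np.newaxis, ...]
y_one = y_train[0].reshape(1, 1)

# The per-sample output shape must equal the per-sample target shape;
# a mismatch here is what the ValueError above reports.
assert model.output_shape[1:] == y_one.shape[1:]

# train_on_batch performs a single gradient update on the given batch.
loss = model.train_on_batch(x_one, y_one)
loss_value = float(np.ravel(loss)[0])  # works whether loss is a scalar or an array
print("loss after one update:", loss_value)
```

Calling `model.fit(x_one, y_one, epochs=3, batch_size=1)` should behave similarly for a single example; `train_on_batch` just skips the epoch and progress-bar machinery, which may suit an agent that updates once per step.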

0 answers:

No answers