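I'm working on agent-based modeling in which every agent contains an MLP-type network. The limitation is that an agent only ever has one example at a time, so each agent can only train on a single example.
When I try to train with the Keras fit function, I get a series of errors: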
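import numpy as np
import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(1, activation=tf.nn.tanh))  # single output unit
# validation_split is an argument of fit(), not compile(); with only one
# example there is nothing to hold out, so it is dropped here.
model.compile(optimizer='SGD',
              loss='mean_squared_error')
# Train on one example: np.newaxis adds the batch axis, so x is (1, ...) and y is (1,).
model.fit(x_train[0][np.newaxis, :, :], np.array([y_train[0]]), epochs=3, batch_size=1)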
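The shapes involved in fitting on a single sample can be checked without Keras at all. Below is a minimal NumPy sketch, assuming x_train holds 28x28 inputs and y_train holds one scalar target per sample (the question does not say which dataset is used, so the arrays here are random stand-ins, not real data). It shows that a single example still needs a leading batch axis.

```python
import numpy as np

# Hypothetical stand-ins for x_train / y_train: 5 samples of 28x28 inputs
# with one scalar target each (the real data is not shown in the question).
rng = np.random.default_rng(0)
x_train = rng.random((5, 28, 28))
y_train = rng.integers(0, 10, size=5)

# Keras always expects a leading batch axis, even for a single example.
x_one = x_train[0][np.newaxis, :, :]   # shape (1, 28, 28)
y_one = np.array([y_train[0]])         # shape (1,)

# For a 1-unit output layer each sample has one target value; reshaping to
# (batch, 1) makes that per-sample shape (1,) explicit.
y_one_col = y_one.reshape(-1, 1)       # shape (1, 1)

print(x_one.shape, y_one.shape, y_one_col.shape)
```

With a Dense(1) output, a target batch of shape (1,) or (1, 1) corresponds to a per-sample shape of (1,), which is what the error message below reports receiving.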
I only want to train on one example at a time, but I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-81-688643bc66bc> in <module>()
----> 1 model.fit(x_train[0][np.newaxis,:,:],np.array([y_train[0]]),epochs=3,batch_size=1)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1261 steps_name='steps_per_epoch',
1262 steps=steps_per_epoch,
-> 1263 validation_split=validation_split)
1264
1265 # Prepare validation data.
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
905 feed_output_shapes,
906 check_batch_axis=False, # Don't enforce the batch size.
--> 907 exception_prefix='target')
908
909 # Generate sample-wise weight values given the `sample_weight` and
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
189 'Error when checking ' + exception_prefix + ': expected ' +
190 names[i] + ' to have shape ' + str(shape) +
--> 191 ' but got array with shape ' + str(data_shape))
192 return data
193
ValueError: Error when checking target: expected dense_5 to have shape (10,) but got array with shape (1,)
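One possible reading of this message, offered as an assumption since only this snippet is shown: Keras is checking against an output layer of 10 units (dense_5, expected shape (10,)), while the code above ends in Dense(1). That suggests the model being fitted may not be the one defined above, for example a model from an earlier notebook cell with a 10-unit output layer that was not rebuilt. If a 10-unit output is actually intended, the target needs 10 values per sample, such as a one-hot encoding of the label. The sketch below shows that encoding in NumPy. The helper one_hot is hypothetical; tf.keras.utils.to_categorical performs the same conversion.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper: integer labels of shape (n,) -> one-hot rows of shape (n, num_classes)."""
    labels = np.asarray(labels, dtype=int)
    out = np.zeros((labels.size, num_classes), dtype=np.float32)
    out[np.arange(labels.size), labels] = 1.0
    return out

# Hypothetical single label, shaped like np.array([y_train[0]]) in the question.
y_one = np.array([3])
y_one_hot = one_hot(y_one, num_classes=10)
print(y_one_hot.shape)  # (1, 10): one sample, 10 targets, matching a 10-unit output
```

Which fix applies depends on the intended network: a freshly built 1-unit model matches the scalar target as-is, while a 10-unit model needs the one-hot target.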