我做了一个非常简单的神经网络,旨在进行强化学习。但是,我无法预测任何事情,因为尝试进行预测时会出错。
出现的错误:
检查输入时出错:预期 dense_203_input 的形状为 (1202,),但得到的数组形状为 (1,)
相关的模型:
def _build_compile_model(self):
    """Build and compile the feed-forward network used by the agent.

    The network accepts 1202 input features per sample, applies three ReLU
    hidden layers (300, 300 and 200 units) and ends in a softmax layer with
    ``self._action_size`` units. It is compiled with mean-squared-error
    loss and ``self._optimizer``.

    Returns:
        The compiled ``Sequential`` model.
    """
    # input_dim=1202 fixes the per-sample feature count, so arrays handed to
    # predict() must have shape (batch_size, 1202).
    # NOTE(review): a softmax output combined with MSE loss looks unusual if
    # the outputs are meant to be unbounded Q-values -- confirm it is intended.
    layers = [
        Dense(300, activation='relu', input_dim=1202),
        Dense(300, activation='relu'),
        Dense(200, activation='relu'),
        Dense(self._action_size, activation='softmax'),
    ]
    network = Sequential()
    for layer in layers:
        network.add(layer)
    network.compile(loss='mse', optimizer=self._optimizer)
    return network
调用 model.predict(state) 时会发生错误,其中 state 是形状为 (1202, 1) 的数组。
完整的错误消息是
ValueError Traceback (most recent call last)
<ipython-input-148-06b7a01facef> in <module>
18 new_state, reward = env.step(action, new_demand_a, new_demand_b) # Take action, get new state and reward
19 new_state = np.reshape(new_state, [1202, -1])
---> 20 agent.update(old_state, new_state, action, reward) # Let the agent update internal
21 average_reward.append(reward) # Keep score
22 if i % 100 == 0 and i != 0: # Print out metadata every 100th iteration
<ipython-input-145-142ae54ce43f> in update(self, old_state, new_state, action, reward)
49 def update(self, old_state, new_state, action, reward):
50 print(old_state.shape)
---> 51 target = self.q_network.predict(old_state)
52 t = self.target_network.predict(new_state)
53 target[0][action] = reward + self.gamma * np.amax(t)
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
1011 max_queue_size=max_queue_size,
1012 workers=workers,
-> 1013 use_multiprocessing=use_multiprocessing)
1014
1015 def reset_metrics(self):
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in predict(self, model, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
496 model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose,
497 steps=steps, callbacks=callbacks, max_queue_size=max_queue_size,
--> 498 workers=workers, use_multiprocessing=use_multiprocessing, **kwargs)
499
500
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _model_iteration(self, model, mode, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
424 max_queue_size=max_queue_size,
425 workers=workers,
--> 426 use_multiprocessing=use_multiprocessing)
427 total_samples = _get_total_number_of_samples(adapter)
428 use_sample = total_samples is not None
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
644 standardize_function = None
645 x, y, sample_weights = standardize(
--> 646 x, y, sample_weight=sample_weights)
647 elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
648 standardize_function = standardize
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
2381 is_dataset=is_dataset,
2382 class_weight=class_weight,
-> 2383 batch_size=batch_size)
2384
2385 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
2408 feed_input_shapes,
2409 check_batch_axis=False, # Don't enforce the batch size.
-> 2410 exception_prefix='input')
2411
2412 # Get typespecs for the input data and sanitize it if necessary.
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
580 ': expected ' + names[i] + ' to have shape ' +
581 str(shape) + ' but got array with shape ' +
--> 582 str(data_shape))
583 return data
584
ValueError: Error when checking input: expected dense_211_input to have shape (1202,) but got array with shape (1,)
答案 0 :(得分:1)
为模型指定输入时,有两种方法:
第一个选项:使用 input_shape
state = np.array(np.ones((BATCH_SIZE,1202,1)))
print("Input Rank: {}".format(tf.rank(state))) # Check for the Rank of Input
此处 input_shape 指定的单个样本形状为 2D,但您应为网络提供 3D 输入(秩为 3),因为还需要在最前面加上 batch_size 维度,例如上面的 (BATCH_SIZE, 1202, 1)。
示例输入:
model_dim.add(Dense(300, activation='relu', input_dim=1202))
第二个选项:使用 input_dim
state = np.array(np.ones((1,1202,)))
print("Input Rank: {}".format(tf.rank(state))) # Check for the Rank of Input
此处 input_dim 指定的单个样本形状为 1D,但您应为网络提供 2D 输入(秩为 2),因为还需要在最前面加上 batch_size 维度,例如上面的 (1, 1202)。
示例输入:
state = np.reshape(state, (1, 1202))  # (batch_size, input_dim)
prediction = model_dim.predict(state)