I have a network that contains a BatchNormalization layer, and I update it with keras.backend.function (K.function). But it does not work.
I build the network like this:
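from keras.layers import Input, Flatten, Dense, Activation, BatchNormalization, Concatenate
from keras.models import Model

# env is the environment object created elsewhere (not shown here)
observation_input = Input(shape=(1,) + env.observation_space.shape, name='observation_input')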
flattened_observation = Flatten()(observation_input)
x = Dense(64)(flattened_observation)
x = Activation('relu')(x)
x = BatchNormalization()(x)
x = Dense(32)(x)
x = Activation('relu')(x)
xa = Dense(2)(x)
x_a = Activation('tanh')(xa)
xp = Dense(1)(x)
x_p = Activation('sigmoid')(xp)
x_out = Concatenate()([x_a, x_p])
actor = Model(inputs=[observation_input], outputs=[x_out])
I train the network like this:
self.actor_train_fn = K.function(actor.inputs + [K.learning_phase()], [self.actor(actor.inputs)], updates=updates)
The error message is:
  File "H:/UAV/UAV/ddpf—with-warm-start.py", line 121, in <module>
    history = agent.fit(env, nb_steps=5e6, visualize=False, log_interval=1000, verbose=2, nb_max_episode_steps=2000)
  File "H:\UAV\UAV\rl\core.py", line 201, in fit
    metrics = self.backward(reward, terminal=done)
  File "H:\UAV\UAV\rl\agents\ddpg.py", line 327, in backward
    action_values = self.actor_train_fn(inputs)[0]
  File "E:\anaconda\envs\deeplearning\lib\site-packages\keras\backend\tensorflow_backend.py", line 2715, in __call__
    return self._call(inputs)
  File "E:\anaconda\envs\deeplearning\lib\site-packages\keras\backend\tensorflow_backend.py", line 2675, in _call
    fetched = self._callable_fn(*array_vals)
  File "E:\anaconda\envs\deeplearning\lib\site-packages\tensorflow\python\client\session.py", line 1439, in __call__
    run_metadata_ptr)
  File "E:\anaconda\envs\deeplearning\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'observation_input' with dtype float and shape [?,1,19]
[[{{node observation_input}} = Placeholder[dtype=DT_FLOAT, shape=[?,1,19], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
[[{{node model_1/concatenate_1/concat/_491}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1216_model_1/concatenate_1/concat", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]