Error when building a network with batch normalization and training it with K.function

Time: 2019-07-10 06:55:07

Tags: python tensorflow keras

I have a network that contains a BatchNormalization layer, and I update (train) it with a function created by `keras.backend.function`. But it does not work.

I build the network like this:

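    from keras.layers import Input, Flatten, Dense, Activation, BatchNormalization, Concatenate
    from keras.models import Model

    # `env` is the reinforcement-learning environment, defined elsewhere
    observation_input = Input(shape=(1,) + env.observation_space.shape, name='observation_input')
    flattened_observation = Flatten()(observation_input)
    x = Dense(64)(flattened_observation)
    x = Activation('relu')(x)
    x = BatchNormalization()(x)
    x = Dense(32)(x)
    x = Activation('relu')(x)
    xa = Dense(2)(x)
    x_a = Activation('tanh')(xa)      # 2 outputs in [-1, 1]
    xp = Dense(1)(x)
    x_p = Activation('sigmoid')(xp)   # 1 output in (0, 1)
    x_out = Concatenate()([x_a, x_p])
    actor = Model(inputs=[observation_input], outputs=[x_out])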
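The placeholder shape `[?,1,19]` in the error suggests that `env.observation_space.shape` is `(19,)`, because the input is declared with shape `(1,) + env.observation_space.shape`. Assuming that, the following pure-Python sketch traces the actor's per-layer output shapes (batch dimension omitted). The helper name `actor_output_shapes` and the dictionary keys are hypothetical labels, not Keras layer names.

```python
import math


def actor_output_shapes(obs_shape):
    """Per-layer output shapes of the actor above (batch dimension omitted)."""
    input_shape = (1,) + tuple(obs_shape)  # Input(shape=(1,) + obs_shape)
    return {
        "observation_input": input_shape,
        "flatten": (math.prod(input_shape),),  # Flatten multiplies all dims
        "dense_64": (64,),  # Activation and BatchNormalization keep this shape
        "dense_32": (32,),
        "x_a": (2,),  # tanh head
        "x_p": (1,),  # sigmoid head
        "concatenate": (2 + 1,),  # Concatenate joins along the last axis
    }


shapes = actor_output_shapes((19,))
print(shapes["observation_input"], shapes["concatenate"])  # (1, 19) (3,)
```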

I train the network like this:

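    # K.function(inputs list, outputs list, updates=ops run on every call); `updates` is built elsewhere
    self.actor_train_fn = K.function(actor.inputs + [K.learning_phase()], [self.actor(actor.inputs)], updates=updates)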
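The `K.learning_phase()` input in this call is what switches BatchNormalization between its training and inference behavior, and in graph-mode Keras the layer's moving-average updates only run when they are among the ops passed via `updates=`. The pure-Python sketch below illustrates both behaviors for a single feature. It assumes the Keras default hyperparameters `momentum=0.99` and `epsilon=1e-3`, and, as a simplification, uses the plain batch variance for both normalization and the running average, so numeric details differ from the real layer. The function name `batch_norm` is hypothetical.

```python
import math


def batch_norm(batch, moving_mean, moving_var, training,
               momentum=0.99, epsilon=1e-3, gamma=1.0, beta=0.0):
    """Simplified BN for one feature; returns (outputs, new_mean, new_var)."""
    if training:
        # Training phase: normalize with the current batch statistics...
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
        # ...and move the running averages toward them. In Keras these
        # are separate update ops that only run if they are executed.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference phase: use the running averages and update nothing.
        mean, var = moving_mean, moving_var
    outputs = [gamma * (x - mean) / math.sqrt(var + epsilon) + beta
               for x in batch]
    return outputs, moving_mean, moving_var


batch = [1.0, 2.0, 3.0]
# Learning phase = 1: output centered on the batch mean, averages move.
train_out, new_mean, new_var = batch_norm(batch, 0.0, 1.0, training=True)
# Learning phase = 0: output uses the stored averages, which stay unchanged.
infer_out, same_mean, same_var = batch_norm(batch, 0.0, 1.0, training=False)
```

Note that with a training-phase batch of a single sample, the batch variance is zero and this sketch outputs only `beta`, which is one reason the learning-phase value fed to the function matters.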

The error message is:

     File "H:/UAV/UAV/ddpf—with-warm-start.py", line 121, in <module>
        history = agent.fit(env, nb_steps=5e6, visualize=False, log_interval=1000, verbose=2, nb_max_episode_steps=2000)
      File "H:\UAV\UAV\rl\core.py", line 201, in fit
        metrics = self.backward(reward, terminal=done)
      File "H:\UAV\UAV\rl\agents\ddpg.py", line 327, in backward
        action_values = self.actor_train_fn(inputs)[0]
      File "E:\anaconda\envs\deeplearning\lib\site-packages\keras\backend\tensorflow_backend.py", line 2715, in __call__
        return self._call(inputs)
      File "E:\anaconda\envs\deeplearning\lib\site-packages\keras\backend\tensorflow_backend.py", line 2675, in _call
        fetched = self._callable_fn(*array_vals)
      File "E:\anaconda\envs\deeplearning\lib\site-packages\tensorflow\python\client\session.py", line 1439, in __call__
        run_metadata_ptr)
      File "E:\anaconda\envs\deeplearning\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 528, in __exit__
        c_api.TF_GetCode(self.status.status))
    tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'observation_input' with dtype float and shape [?,1,19]
         [[{{node observation_input}} = Placeholder[dtype=DT_FLOAT, shape=[?,1,19], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
         [[{{node model_1/concatenate_1/concat/_491}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1216_model_1/concatenate_1/concat", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
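The message itself reflects standard TensorFlow 1.x behavior: when a session evaluates a set of fetches (for `K.function`, the outputs plus every op passed via `updates=`), each placeholder those fetches transitively depend on must be fed. The toy pure-Python model below illustrates that rule. It is not TensorFlow code: apart from `observation_input`, every node name is hypothetical, and the assumption that a BatchNormalization moving-average update is wired to the model's original `Input` placeholder (because it was created when the layer was first called) is made for illustration only, not confirmed by the question.

```python
# Toy model of TF1 session semantics. Node names other than
# "observation_input" are hypothetical.
GRAPH = {
    # name: (kind, dependencies)
    "observation_input": ("placeholder", []),        # the actor's own Input
    "other_observation_input": ("placeholder", []),  # hypothetical second input
    "learning_phase": ("placeholder", []),
    # Actor called on a different input tensor.
    "actor_on_other_input": ("op", ["other_observation_input", "learning_phase"]),
    # BN moving-mean update created when the layer was first called,
    # hence wired to the actor's own Input (assumption for illustration).
    "bn_moving_mean_update": ("op", ["observation_input"]),
}


def required_placeholders(graph, fetches):
    """Return the placeholders that the given fetches transitively depend on."""
    needed, seen, stack = set(), set(), list(fetches)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        kind, deps = graph[node]
        if kind == "placeholder":
            needed.add(node)
        stack.extend(deps)
    return needed


def run(graph, fetches, feed):
    """Mimic Session.run: refuse to run if any required placeholder is unfed."""
    missing = sorted(required_placeholders(graph, fetches) - set(feed))
    if missing:
        raise ValueError(
            f"You must feed a value for placeholder tensor '{missing[0]}'")
    return "ok"


# Outputs plus updates, fed without the actor's own Input: this fails.
try:
    run(GRAPH, ["actor_on_other_input", "bn_moving_mean_update"],
        feed={"other_observation_input": 0.0, "learning_phase": 1})
except ValueError as err:
    print(err)  # You must feed a value for placeholder tensor 'observation_input'
```

Applied to the real graph, the corresponding check would be whether every placeholder behind `self.actor(...)` and behind each op in `updates` (for example, BatchNormalization updates taken from `actor.updates`, if the agent adds them) appears in the list passed as the first argument to `K.function`.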
