How to pass multiple inputs to train_on_batch in Keras

Time: 2019-04-25 08:38:02

Tags: python reinforcement-learning keras-rl

 ValueError: could not broadcast input array from shape (60,60,2) into shape (1)

I tried modifying the code in the following ways, but the same error persists.

  1. At mark 1 and mark 2: state.append(np.array(s)) and target_f_list.append(np.array(target_f))
  2. At mark 3: self.model.train_on_batch([state], [target_f_list])
  3. At mark 3: self.model.train_on_batch(np.array(state), np.array(target_f_list))

This is my Keras network:

    input_1 = Input(shape=(60, 60, 2))
    input_2 = Input(shape=(self.action_size, self.action_size))
    x1 = Conv2D(32, (4, 4), strides=(2, 2), padding='Same', activation=LeakyReLU(alpha=self.Beta))(input_1)
    x1 = Conv2D(64, (2, 2), strides=(2, 2), padding='Same', activation=LeakyReLU(alpha=self.Beta))(x1)
    x1 = Conv2D(128, (2, 2), strides=(1, 1), padding='Same', activation=LeakyReLU(alpha=self.Beta))(x1)
    x1 = Flatten()(x1)
    x1 = Dense(128, activation=LeakyReLU(alpha=self.Beta))(x1)
    x1_value = Dense(64, activation=LeakyReLU(alpha=self.Beta))(x1)
    value = Dense(1, activation=LeakyReLU(alpha=self.Beta))(x1_value)
    x1_advantage = Dense(64, activation=LeakyReLU(alpha=self.Beta))(x1)
    advantage = Dense(self.action_size, activation=LeakyReLU(alpha=self.Beta))(x1_advantage)

    A = Dot(axes=1)([input_2, advantage])
    A_subtract = Subtract()([advantage, A])

    Q_value = Add()([value, A_subtract])

    model = Model(inputs=[input_1, input_2], outputs=[Q_value])
    model.compile(optimizer=Adam(lr=self.epsilon_r), loss='mse')

This is my training function:

    state = []
    target_f_list = []
    for s, a, r, next_s, done in minibatch:
        if not done:

            ... do calculate target_f ...

            state.append(s)                   # mark 1
            target_f_list.append(target_f)    # mark 2

            # This is the fit call I used before, and it worked fine. But I want to train on the whole minibatch at the same time.
            # self.model.fit(s, target_f, epochs=1, verbose=0, batch_size=self.minibatch_size)

    # This is the line that raises the error
    self.model.train_on_batch(state,target_f_list)  # mark 3
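
A possible cause, offered as an assumption rather than a confirmed diagnosis: if each s is itself a pair [image, action_input] (which would explain why fit(s, target_f) worked per sample), then state is a list of pairs with mismatched shapes, and NumPy cannot pack it into one rectangular array. A model built with inputs=[input_1, input_2] expects one array per input, each with the batch axis first. The sketch below uses NumPy only. The helper build_batch, the value of action_size, and the per-sample shapes (1, 60, 60, 2), (1, action_size, action_size) and (1, action_size) are hypothetical and not taken from the question.

```python
import numpy as np


def build_batch(samples):
    """Stack (state, target) samples into one array per model input.

    Assumption: each state is a pair [image, action_input], and every array
    already has a leading batch axis of size 1.
    """
    images, action_inputs, targets = [], [], []
    for (image, action_input), target in samples:
        images.append(image)
        action_inputs.append(action_input)
        targets.append(target)
    # Concatenate along axis 0 so every array has shape (batch, ...).
    inputs = [
        np.concatenate(images, axis=0),
        np.concatenate(action_inputs, axis=0),
    ]
    return inputs, np.concatenate(targets, axis=0)


action_size = 4  # hypothetical value, for illustration only
samples = [
    (
        [np.zeros((1, 60, 60, 2)), np.zeros((1, action_size, action_size))],
        np.zeros((1, action_size)),
    )
    for _ in range(8)
]

inputs, targets = build_batch(samples)
print([x.shape for x in inputs], targets.shape)

# With the model from the question, the call would then look like:
# self.model.train_on_batch(inputs, targets)
```

The list order of inputs has to match the order given in Model(inputs=[input_1, input_2]). If the samples do not carry a leading batch axis, np.stack would be the analogous call instead of np.concatenate.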

Thank you for reading my question.

0 answers:

No answers yet