A3C with an LSTM in Keras

Date: 2018-03-30 08:09:06

Tags: deep-learning lstm reinforcement-learning

I am trying to implement an A3C model with an LSTM using Keras. I started from this version of A3C without an LSTM: https://github.com/coreylynch/async-rl, and tried to modify only the network code, but I cannot get the whole model to compile:

Am I missing something?

Here is my model:

import tensorflow as tf
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Flatten, Conv2D, LSTM, TimeDistributed

state = tf.placeholder("float", [None, agent_history_length, resized_width, resized_height])

vision_model = Sequential()
vision_model.add(Conv2D(activation="relu", filters=16, kernel_size=(8, 8), name="conv1", padding="same", strides=(4, 4), input_shape=(agent_history_length, resized_width, resized_height)))
vision_model.add(Conv2D(activation="relu", filters=32, kernel_size=(4, 4), name="conv2", padding="same", strides=(2, 2)))
vision_model.add(Flatten())
vision_model.add(Dense(activation="relu", units=256, name="h1"))

# Now let's get a tensor with the output of our vision model:

state_input = Input(shape=(1,agent_history_length,resized_width,resized_height))

encoded_frame_sequence = TimeDistributed(vision_model)(state_input)
encoded_video = LSTM(256)(encoded_frame_sequence)  # the output will be a vector

action_probs = Dense(activation="softmax", units=4, name="p")(encoded_video)
state_value = Dense(activation="linear", units=1, name="v")(encoded_video)

policy_network = Model(inputs=state_input, outputs=action_probs)
value_network = Model(inputs=state_input, outputs=state_value)

p_params = policy_network.trainable_weights
v_params = value_network.trainable_weights

policy_network.summary()
value_network.summary()

p_out = policy_network(state_input)
v_out = value_network(state_input)

1 answer:

Answer 0 (score: 0)

The keras-rl example library does not support input shapes of more than 2D!
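For reference, a TimeDistributed CNN feeding an LSTM can be made to build cleanly when the per-frame input shape of the CNN matches the trailing dimensions of the sequence input. The sketch below is not the asker's exact setup: it assumes channels-last frames of an illustrative size (84×84×1), treats the frame history as the LSTM's time axis, and uses `tensorflow.keras` instead of standalone Keras; the layer sizes mirror the question's code.

```python
# Minimal sketch: per-frame CNN encoder wrapped in TimeDistributed,
# followed by an LSTM and separate policy/value heads.
# Shapes (84x84 frames, history of 4, 4 actions) are assumptions.
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (Input, Dense, Flatten, Conv2D,
                                     LSTM, TimeDistributed)

agent_history_length = 4              # assumed number of stacked frames
resized_width = resized_height = 84   # assumed frame size

# Per-frame CNN encoder; each frame is (width, height, 1 channel).
vision_model = Sequential()
vision_model.add(Conv2D(16, (8, 8), strides=(4, 4), activation="relu",
                        padding="same",
                        input_shape=(resized_width, resized_height, 1)))
vision_model.add(Conv2D(32, (4, 4), strides=(2, 2), activation="relu",
                        padding="same"))
vision_model.add(Flatten())
vision_model.add(Dense(256, activation="relu"))

# The frame history is the time axis: TimeDistributed applies the CNN
# to each frame, yielding a sequence of 256-d vectors for the LSTM.
state_input = Input(shape=(agent_history_length,
                           resized_width, resized_height, 1))
encoded_frames = TimeDistributed(vision_model)(state_input)
encoded_video = LSTM(256)(encoded_frames)

action_probs = Dense(4, activation="softmax", name="p")(encoded_video)
state_value = Dense(1, activation="linear", name="v")(encoded_video)

model = Model(inputs=state_input, outputs=[action_probs, state_value])
```

With this arrangement the two heads share the CNN and LSTM trunk in a single multi-output `Model`, rather than wrapping the same input in two separate models as in the question.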