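I want to use OpenAI's Universe / gym to train multiple agents, which may have very different graphs, variables, and so on.
I started from the universe-starter-agent code and adapted the saver so that it also dumps the .meta file.
Restoring a trained agent and using it for inference seems quite tricky. Here is what I am currently doing: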
class LSTMPolicy(object):
    def __init__(self, ob_space, ac_space):
        ##ADDED NAME:##
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space), name = "input_pixels")
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # introduce a "fake" batch dimension of 1 after flatten so that we can do LSTM over time dim
        x = tf.expand_dims(flatten(x), [0])
        size = 256
        lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        step_size = tf.shape(self.x)[:1]
        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]
        ##ADDED NAME:##
        state_in_0 = tf.identity(c_in,name = "LSTM_state_in_0")
        state_in_1 = tf.identity(h_in,name = "LSTM_state_in_1")
        state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_outputs, [-1, size])
        self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
        ##ADDED NAME:##
        inference_logits = tf.identity(self.logits ,name = "inference_logits")
        self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)

import numpy as np
import tensorflow as tf
save_path = "universe_dumps/logs/pong_model_test/train/model.ckpt-0"
input_frame = [np.random.rand(42,42,1).astype(float)]
initial_state = np.zeros([1,256])
tf.reset_default_graph()
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess, sess.as_default():
    new_saver = tf.train.import_meta_graph(save_path+".meta",clear_devices=True)
    new_saver.restore(sess, save_path)
    variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope='global')
    W = variables[0]
    graph = variables[0].graph #Is this the way to get a pointer to the graph??
    inference_op = graph.get_operation_by_name("global/inference_logits")
    state_in_0 = graph.get_tensor_by_name("global/LSTM_state_in_0")
    state_in_1 = graph.get_tensor_by_name("global/LSTM_state_in_1")
    feed_dict = {"global/input_pixels:0": input_frame, state_in_0: initial_state, state_in_1: initial_state}
    logits, W = sess.run([inference_op, W], feed_dict = feed_dict)
    print(logits)
    print(W.shape)

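Everything I have tried so far gives me an empty list back for logits, and the current snippet (above) gives me the error: "The name 'global/LSTM_state_in_0' refers to an Operation, not a Tensor."
Even if I get this working, I would have to adapt the code every time I change the architecture of the agent's graph... so I assume there must be a simpler way? Can anyone help me here?
Ideally, I would like a function load_model(path_to_model) that starts a session, loads everything needed, and returns an object with some kind of .predict method, so that I can feed it a numpy array of shape (42,42,1) and get the logits of the trained agent back...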
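A minimal sketch of the load_model(path_to_model) / .predict interface described above might look like the following. It assumes the tf.compat.v1 API (TensorFlow 1.13+ or 2.x; on older versions tf1 can simply be replaced by tf), and RestoredPolicy, load_model and their parameters are hypothetical names chosen for illustration. The tensor names are passed in rather than hard-coded, so a change of architecture only changes the names, provided every graph gives its input, logits and state tensors stable names. One caveat for the graph above: dynamic_rnn reads the c_in / h_in placeholders directly, not the tf.identity copies named LSTM_state_in_0 / LSTM_state_in_1, so feeding those copies leaves c_in / h_in unfed. Giving the placeholders themselves the names (for example tf.placeholder(..., name="LSTM_state_in_0")) avoids that.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; on TF < 1.13 use `tf` directly


class RestoredPolicy:
    """Hypothetical wrapper around a checkpoint written by tf.train.Saver.

    Tensor names are constructor arguments, so only the names (not this
    class) need to change when the agent's architecture changes.
    """

    def __init__(self, ckpt_prefix, input_name, logits_name, state_names=()):
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Rebuild the graph structure from <ckpt_prefix>.meta.
            saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta",
                                                clear_devices=True)
            # Tensor names need the ":<output_index>" suffix.
            self.x = self.graph.get_tensor_by_name(input_name)
            self.logits = self.graph.get_tensor_by_name(logits_name)
            self.states = [self.graph.get_tensor_by_name(n) for n in state_names]
            self.sess = tf1.Session(graph=self.graph)
            # Load the trained variable values into the session.
            saver.restore(self.sess, ckpt_prefix)

    def predict(self, frame, states=()):
        """Run a single frame (e.g. shape (42, 42, 1)) and return the logits."""
        feed = {self.x: np.asarray(frame, np.float32)[None]}  # add batch dim
        feed.update(zip(self.states, states))  # pair state tensors with values
        with self.graph.as_default():
            return self.sess.run(self.logits, feed_dict=feed)

    def close(self):
        self.sess.close()


def load_model(ckpt_prefix, **tensor_names):
    """Hypothetical helper matching the interface asked for above."""
    return RestoredPolicy(ckpt_prefix, **tensor_names)
```

For the model in the question, the call would be along the lines of load_model(save_path, input_name="global/input_pixels:0", logits_name="global/inference_logits:0", state_names=[...]) with the names of the (renamed) c_in / h_in placeholders, and predict(frame, [c, h]) with c and h of shape (1, 256), zeros for the first step. This sketch only returns the logits; carrying the LSTM state from one frame to the next would additionally require fetching the state_out tensors.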
Answer 0 (score: 0)
The following lines appear to be the problem:
state_in_0 = graph.get_tensor_by_name("global/LSTM_state_in_0")
state_in_1 = graph.get_tensor_by_name("global/LSTM_state_in_1")
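The problem is that "global/LSTM_state_in_0" is the name of an operation, not the name of a tensor (i.e. a specific output of an operation). Assuming you are interested in the 0th output of these operations, you should append ":0" to the name, so the correct code would be: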
state_in_0 = graph.get_tensor_by_name("global/LSTM_state_in_0:0")
state_in_1 = graph.get_tensor_by_name("global/LSTM_state_in_1:0")
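The distinction can be checked on a toy graph. The sketch below is a hypothetical stand-in, not the agent's real graph; it reuses the names from the question and assumes the tf.compat.v1 API (TensorFlow 1.13+ or 2.x). The same rule probably also explains the empty result for the logits: get_operation_by_name returns an Operation, and Session.run returns None for a fetched Operation, so the logits should be fetched as graph.get_tensor_by_name("global/inference_logits:0") instead.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API (TF 1.13+ / 2.x)

# A toy stand-in graph that reuses the names from the question.
graph = tf.Graph()
with graph.as_default(), tf1.name_scope("global"):
    c_in = tf1.placeholder(tf.float32, [1, 4], name="LSTM_state_in_0")
    tf.identity(2.0 * c_in, name="inference_logits")

op = graph.get_operation_by_name("global/inference_logits")     # an Operation
tensor = graph.get_tensor_by_name("global/inference_logits:0")  # its output 0

# Passing a bare operation name to get_tensor_by_name raises ValueError.
try:
    graph.get_tensor_by_name("global/LSTM_state_in_0")
    error = None
except ValueError as exc:
    error = str(exc)

state = np.ones((1, 4), np.float32)
with tf1.Session(graph=graph) as sess:
    feed = {"global/LSTM_state_in_0:0": state}        # feed keys may be tensor names
    op_result = sess.run(op, feed_dict=feed)          # fetching an Operation yields None
    tensor_result = sess.run(tensor, feed_dict=feed)  # fetching a Tensor yields its value

print(error)
print(op_result)
print(tensor_result)
```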