How do I run a trained agent in a new TensorFlow session?

Asked: 2017-01-11 19:21:24

Tags: tensorflow tensorflow-serving openai-gym

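I want to train multiple agents (possibly with very different graphs, variables, ...) using OpenAI's Universe / gym.

I started from the universe-starter-agent code and adjusted the saver so that it dumps .meta files.

Restoring a trained agent and using it for inference turns out to be quite tricky. Here is what I am currently doing:

  • I added a few names to tensors in the LSTMPolicy class in model.py: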



class LSTMPolicy(object):
    def __init__(self, ob_space, ac_space):

        ##ADDED NAME:##
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space), name = "input_pixels")

        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # introduce a "fake" batch dimension of 1 after flatten so that we can do LSTM over time dim
        x = tf.expand_dims(flatten(x), [0])

        size = 256
        lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        step_size = tf.shape(self.x)[:1]

        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]
        ##ADDED NAME:##
        state_in_0 = tf.identity(c_in,name = "LSTM_state_in_0")
        state_in_1 = tf.identity(h_in,name = "LSTM_state_in_1")
        
        state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_outputs, [-1, size])
        self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
        ##ADDED NAME:##
        inference_logits = tf.identity(self.logits ,name = "inference_logits")
        self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)




  • Then I wrote this snippet to try to load the graph and use it for inference:



import numpy as np
import tensorflow as tf

save_path = "universe_dumps/logs/pong_model_test/train/model.ckpt-0"

input_frame = [np.random.rand(42,42,1).astype(float)]
initial_state = np.zeros([1,256])

tf.reset_default_graph()

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess, sess.as_default():
    new_saver = tf.train.import_meta_graph(save_path+".meta",clear_devices=True)
    new_saver.restore(sess, save_path)

    variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope='global')
    W = variables[0]

    graph = variables[0].graph #Is this the way to get a pointer to the graph??
    inference_op = graph.get_operation_by_name("global/inference_logits")
    state_in_0 = graph.get_tensor_by_name("global/LSTM_state_in_0")
    state_in_1 = graph.get_tensor_by_name("global/LSTM_state_in_1")

    feed_dict = {"global/input_pixels:0": input_frame, state_in_0: initial_state, state_in_1: initial_state}

    logits, W = sess.run([inference_op, W], feed_dict = feed_dict)
    print(logits)
    print(W.shape)




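Everything I have tried so far returned an empty list for logits, and the current snippet (above) gives me the error: "The name 'global/LSTM_state_in_0' refers to an Operation, not a Tensor."

Even if I get this working, I will have to adjust the code every time I change the architecture of an agent's graph... so I assume there must be a simpler way? Can anyone help me here?

Ideally, I would like a function load_model(path_to_model) that starts a session, loads everything it needs, and returns an object with some kind of .predict method, so that I can feed it a numpy array of shape (42,42,1) and get the logits from the trained agent...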

1 Answer:

Answer 0: (score: 0)

The following lines appear to be the problem:

state_in_0 = graph.get_tensor_by_name("global/LSTM_state_in_0")
state_in_1 = graph.get_tensor_by_name("global/LSTM_state_in_1")

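The problem is that "global/LSTM_state_in_0" is the name of an operation, not the name of a tensor (i.e. a specific output of an operation). Assuming you are interested in the 0th output of these operations, you should append ":0" to the names, so the correct code would be: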

state_in_0 = graph.get_tensor_by_name("global/LSTM_state_in_0:0")
state_in_1 = graph.get_tensor_by_name("global/LSTM_state_in_1:0")
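
To avoid repeating this fix for every name, the ":0" rule can be wrapped in a small helper. The following is only a sketch: `as_tensor_name` is a hypothetical function, not part of TensorFlow, and it assumes that a name ending in ":<digits>" is already a tensor name.

```python
def as_tensor_name(name, output_index=0):
    """Return `name` in '<op_name>:<output_index>' form.

    Hypothetical helper (not a TensorFlow API). Names that already end
    in ':<digits>' are assumed to be tensor names and returned unchanged.
    """
    _, sep, index = name.rpartition(":")
    if sep and index.isdigit():
        return name  # already looks like a tensor name
    # Otherwise treat `name` as an operation name and pick one of its outputs.
    return "{}:{}".format(name, output_index)


print(as_tensor_name("global/LSTM_state_in_0"))
print(as_tensor_name("global/input_pixels:0"))
```

With this, a lookup could be written as graph.get_tensor_by_name(as_tensor_name("global/LSTM_state_in_0")). The same distinction probably matters for the fetch as well: get_operation_by_name returns an Operation, and fetching an Operation with sess.run yields None rather than values, so fetching the tensor "global/inference_logits:0" is likely what is needed to obtain the logits.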