如何在tensorflow中正确加载训练有素的网络?

时间:2018-06-22 18:44:26

标签: python-3.x tensorflow

我已经建立,训练并保存了以下网络:

import tensorflow as tf
import tensorflow.contrib as tc

obs = tf.placeholder(tf.float32, shape=(None,5), name='obs')

def network(obs)
  x = obs
  x = tf.layers.dense(x, 64)
  x = tc.layers.layer_norm(x, center=True, scale=True)
  x = tf.nn.relu(x)
  x = tf.layers.dense(x, 64)
  x = tc.layers.layer_norm(x, center=True, scale=True)
  x = tf.nn.relu(x)
  x = tf.layers.dense(x, 1., kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
  x = tf.nn.tanh(x, name = "net_op")
  return x

actor = network(obs)

现在,我想加载它并在各种输入上使用它。我知道,如何加载文件,图形和变量值。但是,我不知道如何还原和实际使用网络。我尝试通过执行以下操作来还原网络:

import tensorflow as tf

sess = tf.Session()
saver = tf.train.import_meta_graph('network_model.meta')
saver.restore(sess, 'network_model') # The syntax here might be not correct. It is not important here.

graph = tf.get_default_graph()
obs = graph.get_tensor_by_name("obs:0")
net_op = graph.get_tensor_by_name("net_op:0")
sess.run(net_op, feed_dict ={obs0: [obs]})

但是,net_op:0不会出现在图形中保存的变量中。我在做什么错了?

提前谢谢!

0 个答案:

没有答案