如何从现有模型中获取一个张量并在另一个模型中使用它?

时间:2018-08-15 02:20:33

标签: python tensorflow

我想做的是从现有训练有素的模型中获取一些权重和偏差,然后在我的自定义op(模型或图形)中使用它们。

我可以使用以下方式恢复模型:

# Create context
with tf.Graph().as_default(), tf.Session() as sess:
    # Create model
    with tf.variable_scope('train'):
        train_model = MyModel(some_args)

然后获取张量:

latest_ckpt = tf.train.latest_checkpoint(path)
if latest_ckpt:
    saver.restore(sess, latest_ckpt)
weight = tf.get_default_graph().get_tensor_by_name("example:0")

我的问题是,如果我想在其他上下文(模型或图形)中使用该weight,如何安全地将其值复制到新图形中,例如:

with self.test_session(use_gpu=True, graph=ops.Graph()) as sess:
    with vs.variable_scope("test", initializer=initializer):
        # How can I make it possible?
        w = tf.get_variable('name', initializer=weight)

欢迎任何帮助,非常感谢。


感谢@Sorin的启发,我找到了一种简单明了的方法:

z = graph.get_tensor_by_name('prefix/NN/W1:0')

with tf.Session(graph=graph) as sess:
    z_value = sess.run(z)

with tf.Graph().as_default() as new_graph, tf.Session(graph=new_graph) as sess:
    w = tf.get_variable('w', initializer=z_value)

1 个答案:

答案 0 :(得分:0)

hacky的方法是使用tf.assign将权重分配给想要的变量(确保它在开始时只发生一次,而不是每次迭代都发生,否则模型将无法调整这些权重)

稍微不太聪明的方法是加载图形和训练模型的会话,并修改图形以添加所需的操作。由于您还拥有原始模型的整个图形,因此这会使图形更加混乱,但由于您可以直接依赖于运算而不是权重(也就是说,如果原始模型正在执行S型激活),因此它会更干净一些。 ,这也会复制激活信息)。图中未使用的部分将由tensorflow自动修剪。

干净的方法是使用www.tenforflow.com/hub。它是一个库,可让您将图的各个部分定义为模块,可以将其导出和导入到任何图中。这样可以处理所有依赖关系和配置,还可以很好地控制训练(即,如果您要冻结权重,或者将训练延迟一定的迭代次数等)