TensorFlow: how to scope variables

Date: 2016-06-25 17:32:37

Tags: python tensorflow

I have to pre-train a network before the main training. I do this with code in a separate file that uses its own session, but the variables from the first session still carry over and cause problems (because I run both files from a single 'main' file).

I can work around this by simply running my pretrain file on its own, which saves the trained layers, and then running my training file, which loads the saved layers. But it would be nice to do both in one step. How can I break this 'link' and keep the pretraining variables from having global scope?

The 'main' file looks like this:

from util import pretrain_nn
from NN import Network

shape = [...]
layer_save_file = ''
data = get_data()

# Trains and saves layers
pretrain_nn(shape, data, layer_save_file) 

# If I were to print all variables (using tf.all_variables) 
# variables only used in pretrain_nn show up 
# (the printing would be done inside `Network`)
NN = Network(shape, layer_save_file, pretrain=True)

NN.train(data)

# Doesn't work because apparently some variables haven't been initialized.
NN.save()
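The leakage described above can be reproduced with a minimal sketch. Assuming the essence of the problem, the constants below are hypothetical stand-ins for the real pretraining and training variables; the point is that two separate functions that build ops without their own graph both end up in the one shared graph:

```python
import tensorflow as tf

# One graph shared by both build steps, mirroring the default-graph
# behaviour the question runs into (names are hypothetical).
shared = tf.Graph()

def pretrain_nn_ops():
    with shared.as_default():
        tf.constant(0.0, name="pretrain_layer")

def network_ops():
    with shared.as_default():
        tf.constant(0.0, name="train_layer")

pretrain_nn_ops()
network_ops()

# Both sets of ops are now visible in the single shared graph.
print(sorted(op.name for op in shared.get_operations()))
```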

1 Answer:

Answer 0: (score: 1)

The lifetime of a variable is implicitly tied to a TensorFlow graph, and by default both computations will add their variables to the same (global default) graph. You can scope them appropriately by wrapping each sub-computation in a with tf.Graph().as_default(): block, which creates a fresh graph for everything defined inside it:

with tf.Graph().as_default():
  # Trains and saves layers
  pretrain_nn(shape, data, layer_save_file) 

with tf.Graph().as_default():
  NN = Network(shape, layer_save_file, pretrain=True)

  NN.train(data)

  NN.save()
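A minimal sketch of the isolation this gives (again with hypothetical constants standing in for the real variables): ops created under one graph's as_default() block are invisible to the other graph.

```python
import tensorflow as tf

# Build the pretraining ops in their own graph.
g_pretrain = tf.Graph()
with g_pretrain.as_default():
    tf.constant(1.0, name="pretrain_const")

# Build the training ops in a second, independent graph.
g_train = tf.Graph()
with g_train.as_default():
    tf.constant(2.0, name="train_const")

# Each graph only sees its own ops.
print([op.name for op in g_pretrain.get_operations()])  # ['pretrain_const']
print([op.name for op in g_train.get_operations()])     # ['train_const']
```

Note that any Session (or Saver) must be constructed inside the corresponding block, so that it is bound to the graph it is meant to run or save.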