我必须在训练前先预先训练网络。我使用自己会话的单独文件中的代码执行此操作,但第一个会话中的变量仍然会被转移并导致问题(因为我在一个' main'文件中运行这两个文件)。
我可以通过简单地运行我的pretrain文件来解决这个问题,该文件保存训练好的图层,然后运行我的训练文件,加载保存的图层。但是能够一步完成这两件事情会很好。我怎样才能打破这个链接'并避免不必要的变量具有全局范围?
'主要'文件看起来像这样:
from util import pretrain_nn
from NN import Network
shape = [...]
layer_save_file = ''
data = get_data()
# Trains and saves layers
pretrain_nn(shape, data, layer_save_file)
# If I were to print all variables (using tf.all_variables)
# variables only used in pretrain_nn show up
# (the printing would be done inside `Network`)
NN = Network(shape, pretrain=True, layer_save_file)
NN.train(data)
# Doesn't work because apparently some variables haven't been initialized.
NN.save()
答案 0 :(得分:1)
变量的生命周期与TensorFlow图隐式关联,默认情况下,两个计算都将添加到同一(全局)图中。您可以使用每个子计算周围的with tf.Graph().as_default():
块来适当地调整它们的范围:
with tf.Graph().as_default():
# Trains and saves layers
pretrain_nn(shape, data, layer_save_file)
with tf.Graph().as_default():
NN = Network(shape, pretrain=True, layer_save_file)
NN.train(data)
NN.save()