我使用以下Network
类内的方法将预训练网络加载到Tensorflow中(因此调用self.xyz)。首先,调用define_network()
,然后我对其他变量和优化器进行初始化,然后调用load_model()
。
然而,尽管使用tf.variable_scope(self.name)
,图表中的变量仍会加载到变量的通用空间中。这是有问题的,因为我有这个类的两个实例,每个实例加载相同的网络,我想将out分成不同的范围。
如何将变量加载到特定范围?
P.S。如果我的代码中有任何错误,请随时纠正我!
def load_model(self):
with tf.variable_scope(self.name) as scope:
self.saver.restore(self.sess, self.model_path)
print("Loaded model from {}".format(self.model_path))
def define_model(self):
with tf.variable_scope(self.name) as scope:
self.saver = tf.train.import_meta_graph(self.model_path + '.meta')
print("Loaded model from {}".format(self.model_path + '.meta'))
graph = tf.get_default_graph()
self.inputs = []
inp_names = ['i_hand1:0', 'i_hand2:0', 'i_flop1:0', 'i_flop2:0', 'i_flop3:0',
'i_turn:0', 'i_river:0', 'i_other:0', 'i_allowed_mod:0', 'keras_learning_phase:0']
for inp in inp_names:
self.inputs.append(tf.get_default_graph().get_tensor_by_name(inp))
self.outputs = tf.get_default_graph().get_tensor_by_name("Tanh:0")
self.add_output_conversions()
all_vars = tf.trainable_variables()
for var in all_vars:
self.var[var.name] = var
答案 0 :(得分:1)
我认为你的问题可以通过在
中添加一个参数来解决self.saver = tf.train.import_meta_graph(self.model_path + '.meta', 'import_scope'=self.name)
这里是reference