我在这里定义了一个类
class BasicNetwork(object):
def __init__(self, scope, task_name, is_train=False, img_shape=(80, 80)):
self.scope = scope
self.is_train = is_train
self.task_name = task_name
self.__create_network(scope, img_shape=img_shape)
def __create_network(self, scope, img_shape=(80, 80)):
with tf.variable_scope(scope):
with tf.variable_scope(self.task_name):
with tf.variable_scope('input_data'):
self.inputs = tf.placeholder(shape=[None, *img_shape, cfg.HIST_LEN], dtype=tf.float32)
with tf.variable_scope('networks'):
with tf.variable_scope('conv_1'):
self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
kernel_size=[8, 8], stride=4, padding='SAME', trainable=self.is_train)
with tf.variable_scope('conv_2'):
self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
kernel_size=[4, 4], stride=2, padding='SAME', trainable=self.is_train)
with tf.variable_scope('conv_3'):
self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
kernel_size=[3, 3], stride=1, padding='SAME', trainable=self.is_train)
with tf.variable_scope('f_c'):
self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512,
activation_fn=tf.nn.elu, trainable=self.is_train)
我想定义两个具有不同任务名称的BasicNetwork实例。范围是全球性的#39;但是当我检查输出时,有
ipdb> for i in net_1.layres: print(i)
Tensor("global/simple/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2)
ipdb> for i in net_2.layres: print(i)
Tensor("global_1/supreme/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2)
正如您在输出中看到的,已创建新范围global_1
,但我想将其设为global
。我设置reuse=True
但后来我发现当没有名为global
的范围时,reuse=True
无法使用。我该怎么办?
答案 0 :(得分:0)
使用reuse
您可以获取现有变量。现在要重用变量软管应该存在于图中。如果存在具有相同名称的变量,则可以将这些变量用于其他操作。
class BasicNetwork(object):
def __init__(self, scope, task_name, reuse, is_train=False, img_shape=(80, 80)):
self.scope = scope
self.is_train = is_train
self.reuse = reuse
self.task_name = task_name
self.__create_network(scope, reuse=self.reuse, img_shape=img_shape)
def __create_network(self, scope, reuse=None, img_shape=(80, 80)):
with tf.variable_scope(scope, reuse=reuse):
...
# delete this line with tf.variable_scope(self.task_name):
# or replace with; with tf.name_scope(self.task_name):
trainnet = BasicNetwork('global', taskname, None)
# resue the created variables
valnet = BasicNetwork('global', taskname, True)