TensorFlow: how to reuse a variable scope name

Time: 2017-08-10 12:41:34

Tags: tensorflow

I have defined a class as follows:

import tensorflow as tf
import tensorflow.contrib.slim as slim
# cfg is my config module (provides HIST_LEN)

class BasicNetwork(object):
    def __init__(self, scope, task_name, is_train=False, img_shape=(80, 80)):
        self.scope = scope
        self.is_train = is_train
        self.task_name = task_name
        self.__create_network(scope, img_shape=img_shape)

    def __create_network(self, scope, img_shape=(80, 80)):
        with tf.variable_scope(scope):
            with tf.variable_scope(self.task_name):
                with tf.variable_scope('input_data'):
                    self.inputs = tf.placeholder(shape=[None, *img_shape, cfg.HIST_LEN], dtype=tf.float32)
                with tf.variable_scope('networks'):
                    with tf.variable_scope('conv_1'):
                        self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
                                                  kernel_size=[8, 8], stride=4, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('conv_2'):
                        self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
                                                  kernel_size=[4, 4], stride=2, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('conv_3'):
                        self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
                                                  kernel_size=[3, 3], stride=1, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('f_c'):
                        self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512,
                                                       activation_fn=tf.nn.elu, trainable=self.is_train)

I want to define two BasicNetwork instances with different task names. The scope is 'global', but when I check the outputs I get:

ipdb> for i in net_1.layres: print(i)
Tensor("global/simple/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2)

ipdb> for i in net_2.layres: print(i)
Tensor("global_1/supreme/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2)

As you can see from the output, a new scope global_1 was created, but I want it to stay global. I tried setting reuse=True, but then I found that reuse=True cannot be used when no scope named global exists yet. What should I do?
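
For illustration, a toy sketch of the failure I mean (made-up variable name, not the class above):

import tensorflow as tf

with tf.variable_scope('global', reuse=True):
    # raises a ValueError because global/w was never created in the first place
    w = tf.get_variable('w', shape=[1])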

1 Answer:

Answer 0 (score: 0)

With reuse you can fetch existing variables. For a variable to be reused, it must already exist in the graph; once variables with the same names exist, they can be shared by other operations.
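
A minimal sketch of that mechanism (toy names, assuming TensorFlow 1.x):

import tensorflow as tf

with tf.variable_scope('demo'):
    a = tf.get_variable('w', shape=[2, 2])   # creates demo/w

with tf.variable_scope('demo', reuse=True):
    b = tf.get_variable('w', shape=[2, 2])   # fetches the existing demo/w

print(a is b)  # True: both handles refer to the same variable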

class BasicNetwork(object):
    def __init__(self, scope, task_name, reuse, is_train=False, img_shape=(80, 80)):
        self.scope = scope
        self.is_train = is_train
        self.reuse = reuse
        self.task_name = task_name
        self.__create_network(scope, reuse=self.reuse, img_shape=img_shape)

    def __create_network(self, scope, reuse=None, img_shape=(80, 80)):
        with tf.variable_scope(scope, reuse=reuse):
            ...
            # delete this line: with tf.variable_scope(self.task_name):
            # or replace it with: with tf.name_scope(self.task_name):

trainnet = BasicNetwork('global', taskname, None)
# reuse the created variables
valnet = BasicNetwork('global', taskname, True)
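
As a quick sanity check (a sketch, assuming the modified class above), you can list the variables under the scope and confirm that no global_1 copies were created:

shared_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='global/')
for v in shared_vars:
    print(v.name)  # every name starts with 'global/'; there is no 'global_1/' prefix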