在张量流

时间:2017-02-06 06:01:19

标签: tensorflow

我问了一个类似的问题here

然而,答案并不能满足我的需要,也没有人回复我的评论,所以我必须重新发布这个问题并使其更清楚。

我有2个网络,名为Target&资源。简单来说,网络定义如下:

# definition for Source
s_input = tf.placeholder(tf.float32, [None, 1], name = 'input_layer')
s_output = tf.contrib.layers.fully_connected(input = s_input, num_outputs=1)
#structure of target is the same as Source's with t_input & t_output
#loss
loss = (alpha*t_output-s_input+beta*label)**2
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

现在,一些训练步骤后,我想层s_ouput的内容(参数)复制到层t_ouput使得t_ouput不作为s_ouput改变由于梯度下降和安培变化; t_ouput仍然从t_input获得输入。

我已经尝试过Yaroslav Bulatov建议的解决方案,但它不起作用。

如果我使用简单的tf.Variable来定义我的网络,我可以通过tf.assign轻松复制变量,但是现在我想使用更简单和更简单的tf.contrib.layers。足够灵活地定义我自己的网络。

如果有人不理解我的问题,请通知我,以便我可以解决。

2 个答案:

答案 0 :(得分:2)

您可以使用assign来创建复制操作,例如

s_output = tf.contrib.layers.fully_connected(input = s_input, num_outputs=1, weights_initializer=tf.contrib.layers.xavier_initializer())
t_output = tf.contrib.layers.fully_connected(input = s_input, num_outputs=1)

现在您可以访问可训练的变量

vars = tf.trainable_variables()

并复制它们(前半部分是来自s_output的变量,后半部分是来自t_output的变量):

copy_ops = [vars[ix+len(vars)//2].assign(var.value()) for ix, var in enumerate(vars[0:len(vars)//2])]

现在您可以使用以下方式复制数据:

init = tf.global_variables_initializer()
sess = tf.Session() 
sess.run(init)
map(lambda x: sess.run(x), copy_ops)
print(sess.run(vars[2]))

希望这是你想要的。

答案 1 :(得分:0)

最简单的方法是使用tf进行数学运算。首先,获取每一层的重量。要从任何层获得重量: 我的fc1层如下所示:-

with tf.variable_scope("fc1"):
    fc1 = tf.contrib.layers.fully_connected(inputs = input_,
                                            num_outputs = 10,
                                            activation_fn=tf.nn.relu,
                                            weights_initializer=tf.contrib.layers.xavier_initializer())

要获取图层的权重和偏差,可以执行以下操作:var_fc1 = tf.trainable_variables('fc1') fc1_w_ = np.array(sess.run(var_fc1)[0]) #get weight fc1_b_ = np.array(sess.run(var_fc1)[1]) # get biasess 模仿fc1在做什么:

fc1_old = tf.add(tf.matmul(input_ , fc1_w) , fc1_b)
fc1_old = tf.nn.relu(fc1_old)