Tensorflow:使用tf.Variable替换/馈送图形的占位符?

时间:2018-05-04 13:16:31

标签: tensorflow placeholder

我有一个模型M1,其数据输入是占位符M1.input,其权重已经过培训。 我的目标是构建一个新模型M2,以o的形式从输入M1计算w的输出tf.Variable(及其训练过的权重) }(而不是将实际值提供给M1.input)。换句话说,我使用训练有素的模型M1作为黑盒函数来构建新模型o = M1(w)(在我的新模型中,要学习w并且权重{ {1}}被固定为常量)。问题是M1仅接受我们需要提供实际值的输入M1,而不是像M1.input这样的变量。

作为构建w的天真解决方案,我可以在M2内手动构建M1,然后使用预先训练的值初始化M2的权重并保持它们在M1内无法训练。但是,在实践中,M2很复杂,我不想在M1内再次手动构建M1。我正在寻找更优雅的解决方案,例如变通方法或直接解决方案,用{tf.Variable M2替换M1.input的输入占位符M1

感谢您的时间。

1 个答案:

答案 0 :(得分:2)

这是可能的。怎么样:

import tensorflow as tf


def M1(input, reuse=False):
    with tf.variable_scope('model_1', reuse=reuse):
        param = tf.get_variable('param', [1])
        o = input + param
        return o


w = tf.get_variable('some_w', [1])
plhdr = tf.placeholder_with_default(w, [1])

output_m1 = M1(plhdr)

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    sess.run(w.assign([42]))

    print(sess.run(output_m1, {plhdr: [0]}))  # direct from placeholder
    print(sess.run(output_m1))                # direct from variable

因此,当feed_dict具有占位符的值时,将使用此值。否则,使用变量“w”的后备选项处于活动状态。