Tensorflow: calling an externally defined function (e.g. one built with `tf.make_template`) inside `tf.scan` results in an error

Time: 2016-10-23 08:45:10

Tags: tensorflow

I have an RNN-like structure that takes some building blocks (component neural networks) passed in by the user. Here is a minimal example:

import tensorflow as tf
tf.reset_default_graph()

def initialize(shape):
    init = tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)
    return init

def test_rnn_with_external(input, hiddens, external_fct):
    """
    A simple rnn that makes the standard update, then
    feeds the new hidden state through some external
    function.
    """
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    shape = (dim_in + hiddens, hiddens)
    _init = initialize(shape)
    W = tf.get_variable("rnn_w", initializer=_init)
    _init = tf.zeros([hiddens])
    b = tf.get_variable("rnn_b", initializer=_init)

    def _step(previous, input):
        concat = tf.concat(1, [input, previous])     
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))

        h_t = external_fct(h_t)

        return h_t

    h_0 = tf.zeros([btsz, hiddens])
    states = tf.scan(_step,
                     input,
                     initializer=h_0,
                     name="states")
    return states

# the external function, relying on the templating mechanism.
def ext_fct(hiddens):
    """
    """
    def tmp(input):
        shape = (hiddens, hiddens)
        _init = initialize(shape)
        W = tf.get_variable("ext_w", initializer=_init)
        b = 0
        return tf.add(tf.matmul(input, W), b, name="external")
    return tf.make_template(name_="external_fct", func_=tmp)

# run from here on
t = 5
btsz = 4
dim = 2
hiddens = 3

x = tf.placeholder(tf.float32, shape=(t, btsz, dim))
ext = ext_fct(hiddens)

states = test_rnn_with_external(x, hiddens, external_fct=ext)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

The error ends with:

InvalidArgumentError: All inputs to node external_fct/ext_w/Assign must be from the same frame.

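With "frame" I would associate some region of the stack, so I suspected that tf.make_template does something rather unusual and is therefore not usable here. The external function can be rewritten slightly and then called more directly, like this: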

import tensorflow as tf
tf.reset_default_graph()

def initialize(shape):
    init = tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)
    return init

def test_rnn_with_external(input, hiddens, external_fct):
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    shape = (dim_in + hiddens, hiddens)
    _init = initialize(shape)
    W = tf.get_variable("rnn_w", initializer=_init)
    _init = tf.zeros([hiddens])
    b = tf.get_variable("rnn_b", initializer=_init)

    def _step(previous, input):
        """
        """
        concat = tf.concat(1, [input, previous])     
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))

        h_t = external_fct(h_t, hiddens)

        return h_t

    h_0 = tf.zeros([btsz, hiddens])
    states = tf.scan(_step,
                     input,
                     initializer=h_0,
                     name="states")
    return states

def ext_fct_new(input, hiddens):
    """
    """
    shape = (hiddens, hiddens)
    _init = initialize(shape)
    W = tf.get_variable("ext_w_new", initializer=_init)
    b = 0
    return tf.add(tf.matmul(input, W), b, name="external_new")

t = 5
btsz = 4
dim = 2
hiddens = 3
x = tf.placeholder(tf.float32, shape=(t, btsz, dim))

states = test_rnn_with_external(x, hiddens, external_fct=ext_fct_new)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

However, the same error still occurs: InvalidArgumentError: All inputs to node ext_w_new/Assign must be from the same frame.

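Of course, moving the contents of the external function into the _step part (with the tf.get_variable call placed before it) works. But then the flexibility, which the original code requires, is gone.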

What am I doing wrong? Any help, hints, or pointers are greatly appreciated.

(Note: I also asked this question on GitHub: https://github.com/tensorflow/tensorflow/issues/4478)

1 Answer:

Answer 0: (score: 0)

Using tf.constant_initializer solved the problem. This is described here.
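
For illustration, here is a minimal sketch of how that fix might look when applied to the second example from the question. It is a reconstruction under assumptions, not the answerer's code: it uses the tf.compat.v1 API of newer TensorFlow releases (so tf.concat takes the tensor list first and tf.global_variables_initializer replaces tf.initialize_all_variables), and the names ext_fct_fixed, ext_w_fixed and rnn_with_external are hypothetical. The idea, as far as I can tell, is that an initializer object plus an explicit shape lets tf.get_variable build the initial value itself, instead of receiving a tensor that was already created inside the loop body of tf.scan.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: a TensorFlow release that ships compat.v1

tf.disable_v2_behavior()  # graph mode with sessions, as in the question
tf.reset_default_graph()


def ext_fct_fixed(inp, hiddens):
    """Hypothetical variant of ext_fct_new: same computation, different initializer."""
    shape = (hiddens, hiddens)
    # Draw the initial values in NumPy, so no TensorFlow tensor is created here.
    init_values = np.random.normal(0.0, 0.1, size=shape).astype(np.float32)
    # AUTO_REUSE guards against "variable already exists" if the body is traced again.
    with tf.variable_scope("external", reuse=tf.AUTO_REUSE):
        W = tf.get_variable("ext_w_fixed", shape=shape, dtype=tf.float32,
                            initializer=tf.constant_initializer(init_values))
    return tf.matmul(inp, W, name="external_fixed")


def rnn_with_external(inputs, hiddens, external_fct):
    dim_in = inputs.get_shape().as_list()[-1]
    btsz = inputs.get_shape().as_list()[1]
    # These two variables are created outside the loop body.
    W = tf.get_variable("rnn_w", shape=(dim_in + hiddens, hiddens),
                        initializer=tf.random_normal_initializer(stddev=0.1))
    b = tf.get_variable("rnn_b", shape=[hiddens],
                        initializer=tf.zeros_initializer())

    def _step(previous, current):
        concat = tf.concat([current, previous], axis=1)
        h_t = tf.tanh(tf.matmul(concat, W) + b)
        # The external function creates its variable inside the scan body.
        return external_fct(h_t, hiddens)

    h_0 = tf.zeros([btsz, hiddens])
    return tf.scan(_step, inputs, initializer=h_0, name="states")


t, btsz, dim, hiddens = 5, 4, 2, 3
x = tf.placeholder(tf.float32, shape=(t, btsz, dim))
states = rnn_with_external(x, hiddens, external_fct=ext_fct_fixed)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(states, feed_dict={x: np.zeros((t, btsz, dim), np.float32)})

result_shape = out.shape  # one hidden state of shape (btsz, hiddens) per time step
```

The external function keeps its signature, so it can still be passed in by the user. A random initializer object such as tf.random_normal_initializer would presumably work the same way; tf.constant_initializer is used here only because it is what the answer names.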