我有一个类似RNN的结构,它有一些由用户传入的构建块(组件神经网络)。这是一个最小的例子:
import tensorflow as tf
# Reset the default graph before building the example.
# NOTE(review): presumably so that re-running the snippet in the same
# interpreter does not collide with variables from an earlier run - confirm.
tf.reset_default_graph()
def initialize(shape):
    """Return a float32 random-normal tensor (mean 0, stddev 0.1) of `shape`."""
    return tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)
def test_rnn_with_external(input, hiddens, external_fct):
    """Scan a tanh recurrence over `input`, post-processing every state.

    Each step concatenates the current input slice with the previous
    state, applies an affine map followed by tanh, and then passes the
    result through `external_fct`.

    Args:
        input: tensor handed to `tf.scan`; axis 1 of its static shape is
            used as the batch size and its last axis as the input width.
        hiddens: width of the hidden state.
        external_fct: callable applied to each freshly computed state.

    Returns:
        Whatever `tf.scan` returns for the step function.
    """
    static_shape = input.get_shape().as_list()
    dim_in = static_shape[-1]
    btsz = static_shape[1]

    # Variables "rnn_w" and "rnn_b" are created here, before the scan.
    weight_init = initialize((dim_in + hiddens, hiddens))
    weights = tf.get_variable("rnn_w", initializer=weight_init)
    bias_init = tf.zeros([hiddens])
    bias = tf.get_variable("rnn_b", initializer=bias_init)

    def _advance(h_prev, x_t):
        """Compute the next state from the previous state and one input."""
        joined = tf.concat(1, [x_t, h_prev])
        pre_activation = tf.add(tf.matmul(joined, weights), bias)
        return external_fct(tf.tanh(pre_activation))

    start_state = tf.zeros([btsz, hiddens])
    return tf.scan(_advance, input, initializer=start_state, name="states")
# the external function, relying on the templating mechanism.
def ext_fct(hiddens):
    """Wrap a square linear map of width `hiddens` in a TF template.

    Returns:
        The object produced by `tf.make_template` (named "external_fct")
        around an inner function that obtains variable "ext_w" and
        returns `input @ ext_w + 0` as the op named "external".

    NOTE(review): the initializer tensor for "ext_w" is built when the
    inner function is called; when that call happens inside `tf.scan`
    this is presumably what triggers the "must be from the same frame"
    error - confirm.
    """
    def _linear(input):
        weight_init = initialize((hiddens, hiddens))
        weights = tf.get_variable("ext_w", initializer=weight_init)
        offset = 0
        product = tf.matmul(input, weights)
        return tf.add(product, offset, name="external")

    return tf.make_template(name_="external_fct", func_=_linear)
# run from here on
# Sizes used for the placeholder: x has shape (t, btsz, dim).
t, btsz, dim, hiddens = 5, 4, 2, 3

x = tf.placeholder(tf.float32, shape=(t, btsz, dim))

# Build the templated external function and plug it into the RNN.
ext = ext_fct(hiddens)
states = test_rnn_with_external(x, hiddens, external_fct=ext)

# Open an interactive session and run the variable initializers.
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
错误信息的结尾是:
InvalidArgumentError: All inputs to node external_fct/ext_w/Assign must be from the same frame.
说到 Frame,我会联想到栈中的某个区域。所以我认为 tf.make_template 可能做了一些非常奇怪的事情,因此在这里不可用。外部函数可以稍微重写,然后更直接地调用,如下所示:
import tensorflow as tf
# Reset the default graph before building the second variant.
# NOTE(review): presumably so that variables from an earlier run in the same
# interpreter do not collide with the ones created below - confirm.
tf.reset_default_graph()
def initialize(shape):
    """Return a float32 random-normal tensor (mean 0, stddev 0.1) of `shape`."""
    return tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)
def test_rnn_with_external(input, hiddens, external_fct):
    """Scan a tanh recurrence over `input`, post-processing every state.

    Each step concatenates the current input slice with the previous
    state, applies an affine map followed by tanh, and then calls
    `external_fct(state, hiddens)` on the result.

    Args:
        input: tensor handed to `tf.scan`; axis 1 of its static shape is
            used as the batch size and its last axis as the input width.
        hiddens: width of the hidden state; also forwarded to
            `external_fct`.
        external_fct: callable taking `(state, hiddens)`.

    Returns:
        Whatever `tf.scan` returns for the step function.
    """
    static_shape = input.get_shape().as_list()
    dim_in = static_shape[-1]
    btsz = static_shape[1]

    # Variables "rnn_w" and "rnn_b" are created here, before the scan.
    weight_init = initialize((dim_in + hiddens, hiddens))
    weights = tf.get_variable("rnn_w", initializer=weight_init)
    bias_init = tf.zeros([hiddens])
    bias = tf.get_variable("rnn_b", initializer=bias_init)

    def _advance(h_prev, x_t):
        """Compute the next state from the previous state and one input."""
        joined = tf.concat(1, [x_t, h_prev])
        pre_activation = tf.add(tf.matmul(joined, weights), bias)
        return external_fct(tf.tanh(pre_activation), hiddens)

    start_state = tf.zeros([btsz, hiddens])
    return tf.scan(_advance, input, initializer=start_state, name="states")
def ext_fct_new(input, hiddens):
    """Apply a square linear map of width `hiddens` to `input`.

    Obtains variable "ext_w_new" via `tf.get_variable` and returns
    `input @ ext_w_new + 0` as the op named "external_new".

    NOTE(review): the initializer tensor for "ext_w_new" is built on every
    call; when the call happens inside `tf.scan` this is presumably what
    triggers the "must be from the same frame" error - confirm.
    """
    weight_init = initialize((hiddens, hiddens))
    weights = tf.get_variable("ext_w_new", initializer=weight_init)
    offset = 0
    product = tf.matmul(input, weights)
    return tf.add(product, offset, name="external_new")
# Sizes used for the placeholder: x has shape (t, btsz, dim).
t, btsz, dim, hiddens = 5, 4, 2, 3

x = tf.placeholder(tf.float32, shape=(t, btsz, dim))

# Pass the plain function directly instead of a template.
states = test_rnn_with_external(x, hiddens, external_fct=ext_fct_new)

# Open an interactive session and run the variable initializers.
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
但是,仍然出现相同的错误:
InvalidArgumentError: All inputs to node ext_w_new/Assign must be from the same frame.
当然,把外部函数的内容移到 _step 里(并把 tf.get_variable 部分移到它之前)是可行的。但这样一来,原始代码所必需的灵活性就丢失了。
我做错了什么?非常感谢任何帮助/提示/指示。
(注意:在github上也问过这个问题:https://github.com/tensorflow/tensorflow/issues/4478)