假设我想在可变长度数据(例如文本)上运行RNN,但我知道最大长度,我甚至知道在一个批次中每个样本具有相同的长度。因此,static_rnn
和dynamic_rnn
会成为候选人。
此外,我想使用一个自定义单元格来计算一些不重复的额外数量(即在后续步骤中不需要它们),但在以后的计算中需要这些数量。
虽然我可以简单地将它们作为单元格输出的一部分并从rnn
调用的结果中获取它们,但我更倾向于保持输出不变并按名称访问它们。
以下代码段举例说明了问题
MyGRUWrapper
是一个简单的包装器,用于计算有趣的数量static_rnn
和<{strong> sequence_length
参数将tensorflow导入为tf 导入numpy为np 进口重新 来自tensorflow.contrib.rnn导入GRUCell,RNNCell
# A simple cell wrapper that computes some "interesting" quantity
# in addition to calling the wrapped cell
class MyGRUWrapper(RNNCell):
def __init__(self, cell, reuse=None):
super(MyGRUWrapper, self).__init__(_reuse=reuse)
assert(isinstance(cell, GRUCell))
self.cell = cell
@property
def state_size(self):
return self.cell.state_size
@property
def output_size(self):
return self.cell.output_size
def call(self, inputs, state):
new_state, _ = self.cell(inputs, state)
tf.identity(tf.nn.l2_loss(new_state, name="interesting_quantity"))
return new_state, new_state
# (somewhat native) helper to fetch tensors by name
def simple_get_tensor_by_name(pattern):
all_nodes = tf.get_default_graph().as_graph_def().node
nodes = list(filter(lambda t : re.match(pattern, t.name) , all_nodes))
tensors = [tf.get_default_graph().get_tensor_by_name("%s:0" % (t.name)) for t in nodes]
return tensors
# Define a variable batch_size, variable length, but fixed maximum length RNN
batch_size = 16
input_size = 100
max_length = 10
inputs = tf.placeholder(tf.float32, [None, max_length, input_size], name="input")
length = tf.placeholder(tf.int32, [None])
cell = GRUCell(100)
cell = MyGRUWrapper(cell)
states, _ = tf.nn.static_rnn(cell, tf.unstack(inputs, axis=1), dtype=tf.float32, sequence_length=length)
# Fetch our interesting quantities of which there will be max_length many
interesting_quantities = simple_get_tensor_by_name(".*interesting_quantity.*")
assert(len(interesting_quantities) == max_length)
# Run the graph with dummy input
with tf.Session() as session:
session.run(tf.global_variables_initializer())
batch_length = np.random.randint(1, max_length)
inpt = np.random.rand(batch_size, max_length, input_size)
out, q = session.run([states, interesting_quantities[0]], feed_dict = {inputs : inpt, length :[batch_length] * batch_size})
我确实理解,对于长度为t<max_length
的输入,在步骤t'>t
中提取数量毫无意义。但是,对于实际计算的时间步长 - 例如示例中的步骤0 - 我希望我可以获取数量。
然而,我得到了一个
ValueError: Operation 'rnn/cond/rnn/my_gru_wrapper/interesting_quantity' has been marked as not fetchable
在运行命令中。
我想这是static_rnn
在使用tf.cond
的每个时间步使用sequence_length
的结果。
有没有办法绕过这个而不放弃提供长度(因为这对性能至关重要)并且没有将所有这些都打包到单元格输出中?也欢迎使用dynamic_rnn
的提示,但我想tf.while
的内部使用并不会让事情变得更容易。
(在TF 1.2.0和1.4.0-rc4中测试)