tensorflow static_rnn:获取在单元格调用中创建的张量

时间:2017-10-25 12:48:04

标签: tensorflow rnn

目标

假设我想在可变长度数据(例如文本)上运行RNN,但我知道最大长度,我甚至知道在一个批次中每个样本具有相同的长度。因此,static_rnndynamic_rnn会成为候选人。

此外,我想使用一个自定义单元格来计算一些不重复的额外数量(即在后续步骤中不需要它们),但在以后的计算中需要这些数量。 虽然我可以简单地将它们作为单元格输出的一部分并从rnn调用的结果中获取它们,但我更倾向于保持输出不变并按名称访问它们。

守则

以下代码段举例说明了问题

  • MyGRUWrapper是一个简单的包装器,用于计算有趣的数量
  • 我使用static_rnn 和<{strong> sequence_length参数
  • 创建RNN后,我按名称
  • 处理有趣的张量
  • 在运行时,我获取结果以及其中一个有趣的数量(每个时间步有一个)

将tensorflow导入为tf     导入numpy为np     进口重新     来自tensorflow.contrib.rnn导入GRUCell,RNNCell

# A simple cell wrapper that computes some "interesting" quantity
# in addition to calling the wrapped cell
class MyGRUWrapper(RNNCell):
  def __init__(self, cell, reuse=None):
    super(MyGRUWrapper, self).__init__(_reuse=reuse)
    assert(isinstance(cell, GRUCell))
    self.cell = cell

  @property
  def state_size(self):
      return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def call(self, inputs, state):
    new_state, _ = self.cell(inputs, state)

    tf.identity(tf.nn.l2_loss(new_state, name="interesting_quantity"))

    return new_state, new_state

# (somewhat native) helper to fetch tensors by name
def simple_get_tensor_by_name(pattern):
  all_nodes = tf.get_default_graph().as_graph_def().node
  nodes = list(filter(lambda t : re.match(pattern, t.name) , all_nodes))
  tensors = [tf.get_default_graph().get_tensor_by_name("%s:0" % (t.name)) for t in nodes]
  return tensors

# Define a variable batch_size, variable length, but fixed maximum length RNN   
batch_size = 16
input_size = 100
max_length = 10

inputs = tf.placeholder(tf.float32, [None, max_length, input_size], name="input")
length = tf.placeholder(tf.int32, [None])

cell = GRUCell(100)
cell = MyGRUWrapper(cell)

states, _ = tf.nn.static_rnn(cell, tf.unstack(inputs, axis=1), dtype=tf.float32, sequence_length=length)

# Fetch our interesting quantities of which there will be max_length many
interesting_quantities = simple_get_tensor_by_name(".*interesting_quantity.*")
assert(len(interesting_quantities) == max_length)

# Run the graph with dummy input
with tf.Session() as session:
  session.run(tf.global_variables_initializer())

  batch_length = np.random.randint(1, max_length)
  inpt = np.random.rand(batch_size, max_length, input_size)

  out, q = session.run([states, interesting_quantities[0]], feed_dict = {inputs : inpt, length :[batch_length] * batch_size})

问题

我确实理解,对于长度为t<max_length的输入,在步骤t'>t中提取数量毫无意义。但是,对于实际计算的时间步长 - 例如示例中的步骤0 - 我希望我可以获取数量。

然而,我得到了一个 ValueError: Operation 'rnn/cond/rnn/my_gru_wrapper/interesting_quantity' has been marked as not fetchable 在运行命令中。

我想这是static_rnn在使用tf.cond的每个时间步使用sequence_length的结果。

有没有办法绕过这个而不放弃提供长度(因为这对性能至关重要)并且没有将所有这些都打包到单元格输出中?也欢迎使用dynamic_rnn的提示,但我想tf.while的内部使用并不会让事情变得更容易。

(在TF 1.2.0和1.4.0-rc4中测试)

0 个答案:

没有答案