tensorflow static_rnn: fetching tensors created in a cell call

Date: 2017-10-25 12:48:04

Tags: tensorflow rnn



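I want to use a custom cell to compute some additional, non-recurrent quantities (that is, they are not needed in subsequent steps) that are needed in later computations. While I could simply make them part of the cell's output and retrieve them from the result of the rnn call, I would prefer to leave the output unchanged and access them by name.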



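  • MyGRUWrapper is a simple wrapper that computes the interesting quantity
  • I use static_rnn with the sequence_length parameter
  • After creating the RNN, I look up the interesting tensors by name
  • At run time, I fetch the result along with one of the interesting quantities (there is one per time step)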

import re

import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell, RNNCell

# A simple cell wrapper that computes some "interesting" quantity
# in addition to calling the wrapped cell
class MyGRUWrapper(RNNCell):
  def __init__(self, cell, reuse=None):
    super(MyGRUWrapper, self).__init__(_reuse=reuse)
    assert(isinstance(cell, GRUCell))
    self.cell = cell

  # RNNCell exposes state_size and output_size as properties
  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def call(self, inputs, state):
    new_state, _ = self.cell(inputs, state)

    tf.identity(tf.nn.l2_loss(new_state, name="interesting_quantity"))

    return new_state, new_state

# (somewhat naive) helper to fetch tensors by name
def simple_get_tensor_by_name(pattern):
  all_nodes = tf.get_default_graph().as_graph_def().node
  nodes = list(filter(lambda t : re.match(pattern, t.name) , all_nodes))
  tensors = [tf.get_default_graph().get_tensor_by_name("%s:0" % (t.name)) for t in nodes]
  return tensors

# Define a variable batch_size, variable length, but fixed maximum length RNN   
batch_size = 16
input_size = 100
max_length = 10

inputs = tf.placeholder(tf.float32, [None, max_length, input_size], name="input")
length = tf.placeholder(tf.int32, [None])

cell = GRUCell(100)
cell = MyGRUWrapper(cell)

states, _ = tf.nn.static_rnn(cell, tf.unstack(inputs, axis=1), dtype=tf.float32, sequence_length=length)

# Fetch our interesting quantities of which there will be max_length many
interesting_quantities = simple_get_tensor_by_name(".*interesting_quantity.*")
assert(len(interesting_quantities) == max_length)

# Run the graph with dummy input
with tf.Session() as session:

  batch_length = np.random.randint(1, max_length)
  inpt = np.random.rand(batch_size, max_length, input_size)

  out, q = session.run([states, interesting_quantities[0]], feed_dict = {inputs : inpt, length :[batch_length] * batch_size})


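I do understand that for an input of length t < max_length, fetching the quantity at a step t' > t makes no sense. However, for time steps that are actually computed, such as step 0 in the example, I would expect to be able to fetch the quantity.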

However, the run call raises a ValueError: Operation 'rnn/cond/rnn/my_gru_wrapper/interesting_quantity' has been marked as not fetchable.
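
The `rnn/cond/` prefix in the operation name suggests that, once sequence_length is given, the cell call for each time step is created inside a conditional branch. The following is a minimal pure-Python sketch of that situation for a single sequence, not TensorFlow's actual implementation: `toy_cell`, `unroll` and the example values are hypothetical, and the conditional is modelled as a plain `if` on the step index.

```python
def l2_loss(vec):
    """Half the sum of squares, the quantity tf.nn.l2_loss computes."""
    return sum(x * x for x in vec) / 2.0


def toy_cell(inp, state):
    """Hypothetical stand-in for the wrapped cell: averages input and state."""
    new_state = [(i + s) / 2.0 for i, s in zip(inp, state)]
    return new_state, l2_loss(new_state)


def unroll(inputs, length, state):
    """Unroll over all time steps, but call the cell only while t < length.

    Steps at or beyond `length` pass the state through untouched, so the
    per-step quantity is never computed for them (recorded as None).
    """
    quantities = []
    for t, inp in enumerate(inputs):
        if t < length:
            state, q = toy_cell(inp, state)
        else:
            q = None  # the branch that would create q was not taken
        quantities.append(q)
    return state, quantities


inputs = [[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]
final_state, quantities = unroll(inputs, length=2, state=[0.0, 0.0])
print(final_state)  # [1.75, 1.75]
print(quantities)   # [0.25, 3.0625, None]
```

In this model the quantity exists only on the branch that calls the cell, which is one way to read why a tensor created inside such a branch cannot be fetched directly, even for a step that is always computed.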



(Tested with TF 1.2.0 and 1.4.0-rc4)
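
For comparison, the alternative mentioned at the top of the question, making the extra quantity part of the cell output and reading it from the result of the rnn call, can be sketched in plain Python. This is only an illustration of the data flow under simplifying assumptions: `ToyCellWithExtras` and `static_unroll` are hypothetical names, not TensorFlow APIs, and no sequence_length handling is modelled.

```python
def l2_loss(vec):
    """Half the sum of squares, the quantity tf.nn.l2_loss computes."""
    return sum(x * x for x in vec) / 2.0


class ToyCellWithExtras:
    """Hypothetical cell whose output is a pair (new_state, interesting_quantity)."""

    def __call__(self, inp, state):
        new_state = [(i + s) / 2.0 for i, s in zip(inp, state)]
        output = (new_state, l2_loss(new_state))
        return output, new_state


def static_unroll(cell, inputs, state):
    """Call the cell once per time step and collect every output."""
    outputs = []
    for inp in inputs:
        output, state = cell(inp, state)
        outputs.append(output)
    return outputs, state


outputs, final_state = static_unroll(
    ToyCellWithExtras(), [[2.0, 4.0], [1.0, 0.0]], [0.0, 0.0])

# Split the per-step outputs back into the usual states and the extras
states = [o[0] for o in outputs]
quantities = [o[1] for o in outputs]
print(states)      # [[1.0, 2.0], [1.0, 1.0]]
print(quantities)  # [2.5, 1.0]
```

The cost of this route is the one the question already names: the output structure changes, so every consumer of the rnn result has to unpack it.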

0 answers:
