Tensorflow:使用不同的“调用”函数

时间:2017-08-06 02:58:25

标签: python tensorflow subclass lstm

在输入序列上使用dynamic_rnn会返回一系列输出和最后一个单元格状态。

对于手头的任务(可以在序列中的任何索引处开始/结束的截断的反向支持)我不仅需要访问最后一个单元状态,而且还需要访问整个中间状态序列。快速在线查看我找到了以下帖子:

https://github.com/tensorflow/tensorflow/issues/5731

该线程建议最好的做法是通过将整个状态作为输出的一部分返回来扩展原始LSTM单元的功能,因此调用dynamic_rnn的第一个返回值也将包含输出序列。 / p>

编辑3:

经过4个小时的探索并在线寻找解决方案后,似乎我需要更新output_size属性。这是更新的代码:

class SSLSTMCell(tf.contrib.rnn.LSTMCell):
  @property
  def output_size(self):
    return self.state_size * 2

  def __call__(self, inputs, state):
    cell_out, cell_state = super(SSLSTMCell, self).__call__(inputs, state)
    together = tf.concat([cell_state.c, cell_state.h], axis=1)
    return together, cell_state

我为stackLSTM实现了类似的结果,我会做同样的技巧,修改__call__函数并修改output_size。

我有一个真正的问题,这样做,但是,现在“输出”是所有一起的连接,并失去了元组的结构,因为我看不到一种方法来让它工作,同时保留元组结构。我知道我可以执行一些重新整形以使它们重新进入元组形式,然后可以再次将其用作stackedLSTM的初始状态,但如果无论如何都要保留元组结构,那就太棒了。

可能提供上下文但可能不再与讨论相关的旧编辑

  

这是我到目前为止所做的:

class SSLSTMCell(tf.contrib.rnn.LSTMCell):

  def call(self, inputs, state):
    cell_out, cell_state = super(SSLSTMCell, self).call(inputs, state)
    return cell_state, cell_state
     

如您所见,而不是输出(输出,状态)元组   每个时间步骤我只是输出(状态,状态)输出,其中   会让我有能力进入所有中间状态。

     

然而,它似乎不是我的自定义的“调用”功能   在dynamic_rnn期间完全调用子类SSLSTMCell   函数调用,确实当我尝试在体内放置一个“断言0”时   调用函数我的程序不会崩溃。

     

那可能是什么错?我在dynamic_rnn上查找了实现   肯定会使用“调用”函数,但不知何故它没有   使用我的自定义子类

定义的那个      

提前致谢。

     

编辑1:

     

我似乎应该在“通话”功能上加分   不知道为什么,但这是我更新的代码:

class SSLSTMCell(tf.contrib.rnn.LSTMCell):
  def __call__(self, inputs, state):
    cell_out, cell_state = super(SSLSTMCell, self).__call__(inputs, state)
    return cell_state, cell_state
     

当调用“__call__”函数时,效果更好   这一次(在这个函数中放置一个断言0会崩溃python)。   但是,我收到了一个不同的错误:

  File "/home/evan/tensorflow/local/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py",
     

第655行,致电           cur_inp,new_state = cell(cur_inp,cur_state)         在调用中输入“try1.py”第10行           cell_out,cell_state = super(SSLSTMCell,self)。 call (输入,状态)         文件“/home/evan/tensorflow/local/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py”,   第327行,在调用中           input_size = inputs.get_shape()。with_rank(2)[1]       AttributeError:'LSTMStateTuple'对象没有属性'get_shape'

     

编辑2:

     

看来输出必须是下一步的常规张量   计算,LSTM-Tuple-State是张量的元组。所以我试过了   把它们放在一起

class SSLSTMCell(tf.contrib.rnn.LSTMCell):
  def __call__(self, inputs, state):
    cell_out, cell_state = super(SSLSTMCell, self).__call__(inputs, state)
    together = tf.concat([cell_state.c, cell_state.h], axis=1)
    return together, cell_state
     

但是现在我有一个不同的错误:

ValueError: Dimension 1 in both shapes must be equal, but are 10 and 20 for 'rnn/while/Select' (op: 'Select') with input shapes: [?],
     

显然,框架并不期望突然出现输出   更大...但我不明白为什么会这样,   不应该输出简单的输出,也不应该对输出产生任何影响   计算?这很令人困惑,为什么“输出”应该影响   计算。

     

尝试扩展LSTMCell类只是一个坏主意   我想做什么?我喜欢interfact到dynamic_rnn但是如果我   可以进入中间状态...

1 个答案:

答案 0 :(得分:1)

似乎没有一种简单的方法可以保留元组。看起来RNNCell子类的用户假设输出是大小为[batch x output_size]的2D张量。我认为他们不会使用元组。此外,恢复原始值应该很简单 c, h = tf.split(c, num_or_size_splits=2, axis=1)