如何从所有Tensorflow的while_loop()迭代中获得结果作为未指定长度的列表?

时间:2018-03-01 15:16:15

标签: python tensorflow recurrent-neural-network

我想从所有while_loop()次迭代中获得结果,以便tf.shape(result) == [batch_size, input_length]

batch_sizeinput_length可能因小批量批量不同而不同,因此我会动态评估其形状。

我创建了这段代码:

batch_size = tf.shape(x)[0]
input_length = tf.shape(x)[1]

result = []
lstm = BasicLSTMCell(hidden_units)
cell = MultiRNNCell([lstm])

init_state = cell.zero_state(batch_size, dtype=tf.float32)

def body(i, x, result, state):
    h, state = cell(x[:, i, :], state)
    result.append(h)
    i += 1

    return i, x, result, state

    def cond(i, *_):
        return tf.less(i, input_length - 1)

    _, _, result, state = tf.while_loop(cond, body, (i, x, result, init_state))

但Tensorflow隐式flatten()我传递给所有loop_vars,因此在2次迭代后,我的while_loop()输入的size_of为4,但输出的size_of为5(由于{{中的2个元素) 1}}变量)。

我尝试使用result,但它不起作用,可能是因为使用了常规的Python列表。但仍然 - 我无法提供该列表的确切大小(如果我愿意 - 这将是一个黑客攻击) - 因为它匹配迭代次数,这是使用{{1}后面的关键特征}根本 - 动态停止!

我不需要添加它,如果我只是重新分配这样的结果而不是将shape_invatiants附加到while_loop()列表,它可以正常工作:

h

所以在我看来有点不合理 - 在大多数情况下result只是一个累加器,而不是迭代器!但是如果你想用它作为迭代器 - 它几乎可以防止你做到这一点......

任何人都知道如何处理它吗?我将非常感激。

我正在使用Tensorflow 1.5

修改

从TF github库中rnn.py的第749行,我了解到至少h, state = cell(x[:, i, :], state) result = h while_loop()必须保持不变,但两者都不能batch

所以 - 如果你想动态计算RNN - 你应该为结果提供至少input_length参数,但是你的循环将在None之后停止(当你重塑输入max_input_length时到batch_size。这是对的吗?

0 个答案:

没有答案