在某些步骤中暂停LSTM的方法或从张量弹出某些矢量的快速方法

时间:2019-05-30 12:00:37

标签: python tensorflow

我有一批不连续的序列,其中零个占位符。如何将批处理文件输入到bi-LSTM(或CudnnLSTM)中,从而跳过那些占位符以保持其状态为静态?

inputs = [[r r - - r e e],
          [r - - r - r e]]
run =    [[1 1 0 0 1 1 1],
          [1 0 0 1 0 1 1]]
# rs are variants
# e is a constant vector other than 0s

或者,

如何从输入批次中弹出这些分散的占位符,填充它,然后将新的占位符插入类似的输出中?

input_ = [[r r r e e],
          [r r e e e]]
output_ = [[r r - - r e e],
           [r - - r - r e]]

在tensorflow官方网站上,CudnnLSTM为__call__函数提供了** kwargs,但未指定任何内容。我可以提供一个掩码来在这些步骤中暂停LSTM吗?

另一方面,我知道我可以弹出这些占位符。但这会使每个序列的长度不匹配,无法形成张量。此外,在使用tf.gather_nd和tf.scatter_nd时,运行时告诉我稀疏索引会占用过多的内存和时间。

我如何帮助这种情况?

rnn = LSTM()

out = rnn(input, suspend_at = run) # I want this ! XD

import tensorflow as tf
tf.enable_eager_execution()
import numpy as np

dim3 = 5

r = np.arange(dim3)
e = np.ones(dim3)
z = np.zeros(dim3) # placeholder

input = np.asarray([[r, z, z, r * 2, e], [r * 3, r * 4, z, r * 5, e]])
input = tf.constant(input)
# expected *fast* solution:
# [r 2r e e][3r 4r 5r e]
existance = tf.reduce_sum(input, 2) > 0

ex_lens = tf.cast(existance, tf.int32)
ex_lens = tf.reduce_sum(ex_lens, 1)
ex_len_max = tf.reduce_max(ex_lens)

def shrink_pad(args):
    seq_tensor, exist_vector = args
    ex_idx = tf.where(exist_vector)
    ex_tensor = tf.gather_nd(seq_tensor, ex_idx)
    e_pad = ex_len_max - tf.shape(ex_idx)[0]
    e_pad = tf.tile(seq_tensor[-1][None], [e_pad, 1])
    return tf.concat([ex_tensor, e_pad], 0)

def recover(args):
    s_tensor, exist_vector = args
    ori_len = tf.shape(exist_vector)[0]
    ex_idx = tf.where(exist_vector)
    ex_len = tf.reduce_sum(tf.cast(exist_vector, tf.int32))
    seq_tensor = tf.scatter_nd(ex_idx, s_tensor[:ex_len], [ori_len, dim3])
    return seq_tensor

# tf.map_fn with .scatter_nd is very slow for large tensors!

s_tensor = tf.map_fn(shrink_pad, (input, existance), input.dtype)
print(s_tensor)

r_tensor = tf.map_fn(recover, (s_tensor, existance), input.dtype)
result = tf.equal(input, r_tensor)
print(tf.reduce_all(result)) # True

请注意,e是序列填充!

我在张量和切片张量流中总是遇到麻烦:(

谢谢!

0 个答案:

没有答案