如何从RNN层返回参差不齐的张量?

时间:2020-07-01 14:04:17

标签: tensorflow recurrent-neural-network ragged

我想构建一个RNN层,该层接受一个参差不齐的张量并返回形状相同的参差不齐的张量。

这是我的尝试(这是最小的琐碎RNN层,仅返回输入):

from tensorflow.keras import Model, Input
from tensorflow.keras.layers import RNN, Layer

class TrivialRNNCell(Layer):

    def __init__(self, **kwargs):
        self.state_size = 0
        self.output_size = None
        super(TrivialRNNCell, self).__init__(**kwargs)

    def call(self, inputs, states):
        output = inputs
        return output, states

cell = TrivialRNNCell()
x = Input(shape=(None,None), ragged=True)
layer = RNN(cell, return_sequences=True)
y = layer(x)

trivial_model = Model(inputs=x, outputs=y, name="trivial_model")

此模型将精确返回密集张量的输入,但是它将参差不齐的张量的零值填充到未使用的输出中:

>>> print(tf.ragged.constant([[[1.,2.,3.,4.]],[[]],[[1.,2.,3.]]]))
<tf.RaggedTensor [[[1.0, 2.0, 3.0, 4.0]], [[]], [[1.0, 2.0, 3.0]]]>

>>> print(trivial_model(tf.ragged.constant([[[1.,2.,3.,4.]],[[]],[[1.,2.,3.]]])))
<tf.RaggedTensor [[[1.0, 2.0, 3.0, 4.0]], [[0.0, 0.0, 0.0, 0.0]], [[1.0, 2.0, 3.0, 0.0]]]>

我该如何获得所需的结果?

>>> print(trivial_model(tf.ragged.constant([[[1.,2.,3.,4.]],[[]],[[1.,2.,3.]]])))
<tf.RaggedTensor [[[1.0, 2.0, 3.0, 4.0]], [[]], [[1.0, 2.0, 3.0]]]>

0 个答案:

没有答案