我知道在LSTM Followed by Mean Pooling有一个类似的话题,但那是关于Keras的,我在纯TensorFlow中工作。
我有一个LSTM网络,其中重现由:
处理outputs, final_state = tf.nn.dynamic_rnn(cell,
embed,
sequence_length=seq_lengths,
initial_state=initial_state)
我为每个样本传递正确的序列长度(用零填充)。在任何情况下,输出都包含不相关的输出,因为根据序列长度,某些样本产生的输出比其他样本长。
现在我正通过以下方法提取最后的相关输出:
def extract_axis_1(data, ind):
"""
Get specified elements along the first axis of tensor.
:param data: Tensorflow tensor that will be subsetted.
:param ind: Indices to take (one for each element along axis 0 of data).
:return: Subsetted tensor.
"""
batch_range = tf.range(tf.shape(data)[0])
indices = tf.stack([batch_range, ind], axis=1)
res = tf.reduce_mean(tf.gather_nd(data, indices), axis=0)
我将sequence_length - 1
作为索引传递。在参考最后一个主题时,我想选择所有相关的输出,然后选择平均合并,而不是最后一个。
现在,我尝试将嵌套列表作为extract_axis_1
的替代值传递,但tf.stack
不接受此操作。
有解决方案吗?
答案 0 :(得分:0)
您可以利用weight
功能的tf.contrib.seq2seq.sequence_loss
参数。
来自文档:
weights
:形状张力[batch_size
,sequence_length
]和dtypefloat
。权重构成序列中每个预测的权重。使用weights
作为屏蔽时,将所有有效时间步长设置为1,将所有填充时间步长设置为0,例如由tf.sequence_mask
返回的掩码。
您需要计算二进制掩码,以区分有效输出和无效输出。然后你可以将这个掩码提供给损失函数的weights
参数(可能,你会想要使用像这样的丢失);在计算损失时,函数不会考虑具有0权重的输出。
如果您不能使用序列丢失,您可以手动执行完全相同的操作。您计算二进制掩码,然后将输出乘以此掩码,并将其作为输入提供给完全连接的层。