每批只根据最后一帧计算Loss

时间:2021-04-19 10:11:38

标签: python deep-learning lstm recurrent-neural-network

我正在训练一个带有 ConvLSTM 模块的对象检测网络。我使用较早的帧来提高当前帧的网络性能。 在训练期间,我计算 LSTM 在每个时间步中的输出损失。


例如,如果每个输入序列的长度 T = 4:

Input  = (t0 t1 t2 t3)
Output = (t0 t1 t2 t3)

# Input is a 5D tensor of size (Batch=1, T=4, Channels, Height, Width)

现在计算所有时间步的损失。 (t0 t1 t2 t3)

我想尝试的是:

Input  = (t0 t1 t2 t3)
Output =          (t3)  # don't care about the earlier outputs

网络将根据最后一帧“t3”进行优化。


这种修改对提高最后一帧的网络性能是否有意义?

0 个答案:

没有答案