梯度检查点返回值

时间:2020-07-26 16:11:51

标签: python pytorch torch

我有一个返回张量的检查点回调函数(即custom_dec)和一个字典。但是似乎此函数不返回字典(或其他数据类型),而仅返回张量。解决此问题的方法是什么,因为我要检查点的模块将返回张量以及数据类型作为字典:

def custom_dec(self, module):
        def custom_forward(*inputs):
            output = module(inputs[0], inputs[1],
                            encoder_attn_mask=inputs[2],
                            decoder_padding_mask=inputs[3],
                            layer_state=inputs[4],
                            causal_mask=inputs[5],
                            output_attentions=inputs[6],
                            )
            # output[2] is a python dictionary
            return output[0], output[2]

以下是检查点调用:

x, layer_past = \
                checkpoint.checkpoint(
                    self.custom_dec(decoder_layer),
                    x,
                    encoder_hidden_states,
                    encoder_padding_mask,
                    decoder_padding_mask,
                    layer_state,
                    decoder_causal_mask,
                    output_attentions,
                )

错误:

TypeError:CheckpointFunctionBackward.forward:预期变量(获取 字典)的返回值1

1 个答案:

答案 0 :(得分:0)

讨论了类似的情况here

你可以做的是将字典转换成某种张量形式。我遇到了一个错误,它是由 torch.utils.checkpoint 不接受的输入列表引起的。我的解决方案是将列表中的张量作为独立张量传递,并在 custom_forward 中形成一个列表。

我不知道你的字典的形式(例如,如果每个键总是有一个值),但你可以想出一个适用于你的字典的字典-张量互换方案。

相关问题