我有一个返回张量的检查点回调函数(即custom_dec)和一个字典。但是似乎此函数不返回字典(或其他数据类型),而仅返回张量。解决此问题的方法是什么,因为我要检查点的模块将返回张量以及数据类型作为字典:
def custom_dec(self, module):
def custom_forward(*inputs):
output = module(inputs[0], inputs[1],
encoder_attn_mask=inputs[2],
decoder_padding_mask=inputs[3],
layer_state=inputs[4],
causal_mask=inputs[5],
output_attentions=inputs[6],
)
# output[2] is a python dictionary
return output[0], output[2]
以下是检查点调用:
x, layer_past = \
checkpoint.checkpoint(
self.custom_dec(decoder_layer),
x,
encoder_hidden_states,
encoder_padding_mask,
decoder_padding_mask,
layer_state,
decoder_causal_mask,
output_attentions,
)
错误:
TypeError:CheckpointFunctionBackward.forward:预期变量(获取 字典)的返回值1
答案 0 :(得分:0)
讨论了类似的情况here。
你可以做的是将字典转换成某种张量形式。我遇到了一个错误,它是由 torch.utils.checkpoint
不接受的输入列表引起的。我的解决方案是将列表中的张量作为独立张量传递,并在 custom_forward
中形成一个列表。
我不知道你的字典的形式(例如,如果每个键总是有一个值),但你可以想出一个适用于你的字典的字典-张量互换方案。