标签: tensorflow pytorch
我正在尝试用C ++写一个自定义的tensorflow操作。我想存储前向传递的一些中间结果,以帮助在后向传递中进行梯度计算。我知道在Pytorch中我可以做ctx.save_for_backward-对于自定义C ++操作,在tensorflow中是否有类似的东西?
谢谢!