我是pytorch的初学者,并且我有一些在网络中实现所需的功能。
我的问题是:是否可以使用 tf.function 之类的方法,还是应该对变量使用“ class(nn.Module)”?
例如,让X为10x2矩阵。用伪代码:
a = Variable(1.0)
b = Variable(1.0)
Y = a*X[:,0]**2 + b*X[:,1]
答案 0 :(得分:0)
在PyTorch中,您不需要tf.function
之类的东西,只需要使用普通的Python代码(由于动态图)即可。
如果上面的内容不能回答您的问题,请给出更详细的示例(包含代码)。