如何使用mxnet.jl和Julia获取渐变节点?

时间:2016-03-07 09:28:19

标签: julia gradient-descent mxnet mxnet.jl

我正试图在Julia中使用mxnet.jl从mxnet主文档中复制以下示例:

A = Variable('A')
B = Variable('B')
C = B * A
D = C + Constant(1)
# get gradient node.
gA, gB = D.grad(wrt=[A, B])
# compiles the gradient function.
f = compile([gA, gB])
grad_a, grad_b = f(A=np.ones(10), B=np.ones(10)*2)

该示例显示如何自动提取symoblic表达式并获取其渐变。

mxnet.jl(最新版​​本2016-03-07)中的等效内容是什么?

1 个答案:

答案 0 :(得分:1)

MXNet.jl/src/symbolic-node.jl中的代码可能有助于您找到答案。

我对这个包不熟悉。 这是我的猜测: A = mx.Variable("A") B = mx.Variable("B") C = B .* A D = C + 1 如果存在,mx.normalized_gradient可能是其余部分的解决方案。