MXNet打印中间符号值

时间:2017-03-24 22:03:37

标签: python machine-learning deep-learning mxnet

如何找到MXNet符号中保存的实际数值。

假设我有,

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y, 

如果x = [100,200]且y = [300,400], 我想打印:

z = [400,600]

有点像tensorflow的eval()方法

1 个答案:

答案 0 :(得分:8)

看了一下后,我发现你可以通过以下方式做到这一点:

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y
executor = z.bind(mx.cpu(), {'x': mx.nd.array([100,200]), 'y':mx.nd.array([300,400])})
output = executor.forward()

会给你'输出':

[<NDArray 2 @cpu(0)>]

打印实际数字输出:

print output[0].asnumpy()
array([ 400.,  600.], dtype=float32)