mxnet符号API是否支持条件控制流?

时间:2017-07-09 03:11:53

标签: mxnet

我想在符号中添加一些条件控件,似乎if-else在符号构造时间内进行评估。但是我希望它在符号运行时进行评估。

a = mx.symbol.Variable(name='a')
b = mx.symbol.Variable(name='b')

if a>b:
    c = a-b
else:
    c = a+b

TensorFlow提供了tf.cond()运算符来处理它,mxnet中有对应的吗?

1 个答案:

答案 0 :(得分:4)

您可以使用mx.symbol.where

您可以计算a_minus_ba_plus_b并返回一个数组,其中每个元素来自a_minus_ba_plus_b,具体取决于另一个condition中的相应值阵列。这是一个例子:

a = mx.symbol.Variable(name='a')
b = mx.symbol.Variable(name='b')

a_minus_b = a - b
a_plus_b  = a + b

# gt = a > b
gt = a.__gt__(b) 

result = mx.sym.where(condition=gt, x=a_minus_b, y=a_plus_b)

ex = result.bind(ctx=mx.cpu(), args={'a':mx.nd.array([1,2,3]), 'b':mx.nd.array([3,2,1])})
r = ex.forward()

print(r[0].asnumpy()) #result should be [1+3, 2+2, 3-1]