我正在使用Arrayfire和Flashlight评估网络。
auto tmp = output(af::seq(2, 10), af::span, af::span, af::span);
auto softmax_tmp = fl::softmax(tmp, 0);
output(af::seq(2,10),af::span,af::span,af::span)=softmax_tmp;
output
是形状为(12,100,1,1)的张量。现在,我想拉出张量的(2,10)暗淡,并针对提取的100个9暗淡矢量进行softmax运算。然后放回去。上面的代码。
问题是第三行不起作用。 softmax_tmp
是正确的值,但是第三行中的赋值运算符刚刚失败。确实可以成功通过编译,但是output
仍然是第一行中的旧值。
谁可以帮助我?真的非常感谢。