我想对循环神经网络权重矩阵使用numpy where
来更新大于阈值的值。
更新pre-> post重量还可以。
但是更新帖子->重量不起作用。
擅长numpy的人,请帮助我!
>>> import numpy as np
>>> v = np.arange(3*2*2).reshape((3,2,2))
"""
array([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]]])
"""
>>> w = np.arange(3*2*2*3*2*2).reshape((3,2,2,3,2,2))
"""
array([[[[[[ 0, 1],
[ 3, 3]],
[[ 6, 5],
[ 9, 7]],
[[ 12, 9],
[ 15, 11]]],
[[[ 18, 13],
[ 21, 15]],
[[ 24, 17],
[ 27, 19]],
[[ 30, 21],
[ 33, 23]]]],
~~~~~~~~~~~~~~~~~~~~~~~~~~
[[144, 137],
[147, 139]],
[[150, 141],
[153, 143]]]]]])
"""
>>> w[np.where(v>3)] += 1 # pre -> post is OK.
>>> w[:,:,:, np.where(v>3)] += 1 # post -> pre is not working!! I can't understand this result.
# incremented all elements!!