给定位置(X,Y)沿第三轴(Z)更新rank3张量流张量中的切片

时间:2019-10-10 16:53:21

标签: python python-2.7 tensorflow

我正在尝试使用Tensorflow 1.9.0重新实现以下函数(以numpy编写)。

def lateral_inhibition2(conv_spikes,SpikesPerNeuronAllowed):
    vbn = np.where(SpikesPerNeuronAllowed==0)
    conv_spikes[vbn[0],vbn[1],:]=0 
    return conv_spikes

conv_spikes是等级3的二进制张量,SpikesPerNeuronAllowed是等级2的张量。 conv_spikes是一个变量,用于指示该位置包含1的特定位置的神经元是否已加标,而0则指示该位置的神经元未加标。 SpikesPerNeuronAllowed变量指示是否允许沿X-Y轴的Z位置上的所有神经元都尖峰。 1中的SpikesPerNeuronAllowed表示允许X-Y中相应conv_spikes位置并沿Z轴的神经元尖峰。 0指示不允许在X-Y中的相应conv_spikes位置并沿Z轴的神经元加尖峰。

conv_spikes2 = (np.random.rand(5,5,3)>=0.5).astype(np.int16)
temp2 = np.random.choice([0, 1], size=(25,), p=[3./4, 1./4])
SpikesPerNeuronAllowed2 = temp2.reshape(5,5)
print(conv_spikes2[:,:,0])
print
print(conv_spikes2[:,:,1])
print
print(conv_spikes2[:,:,2])
print
print(SpikesPerNeuronAllowed2)

产生以下输出

##First slice of conv_spikes across Z-axis
[[0 0 1 1 1]
 [1 0 0 1 1]
 [1 0 1 1 0]
 [0 1 0 1 1]
 [0 1 0 0 0]]
##Second slice of conv_spikes across Z-axis
[[0 0 1 0 0]
 [0 0 1 0 1]
 [0 0 1 1 1]
 [0 0 0 1 0]
 [1 1 1 1 1]]
##Third slice of conv_spikes across Z-axis
[[0 1 1 0 0]
 [0 0 1 0 0]
 [0 1 1 0 0]
 [0 0 0 1 0]
 [1 0 1 1 1]]
##SpikesPerNeuronAllowed2
[[0 0 0 0 1]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [1 1 0 0 0]
 [0 0 0 1 0]]

现在,当调用函数时

conv_spikes2 = lateral_inhibition2(conv_spikes2,SpikesPerNeuronAllowed2)
print(conv_spikes2[:,:,0])
print
print(conv_spikes2[:,:,1])
print
print(conv_spikes2[:,:,2])

产生以下输出

##First slice of conv_spikes across Z-axis
[[0 0 0 0 1]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 1 0 0 0]
 [0 0 0 0 0]]
##Second slice of conv_spikes across Z-axis
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 1 0]]
##Third slice of conv_spikes across Z-axis
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 1 0]]

我试图在Tensorflow中重复如下内容

conv_spikes_tf = tf.Variable((np.random.rand(5,5,3)>=0.5).astype(np.int16))
a_placeholder = tf.placeholder(tf.float32,shape=(5,5))
b_placeholder = tf.placeholder(tf.float32)
inter2 = tf.where(tf.equal(a_placeholder,b_placeholder))
output= sess.run(inter2,feed_dict{a_placeholder:SpikesPerNeuronAllowed2,b_placeholder:0})
print(output)

产生以下输出

[[0 0]
 [0 1]
 [0 2]
 [0 3]
 [1 0]
 [1 1]
 [1 2]
 [1 3]
 [1 4]
 [2 0]
 [2 1]
 [2 2]
 [2 3]
 [2 4]
 [3 2]
 [3 3]
 [3 4]
 [4 0]
 [4 1]
 [4 2]
 [4 4]]

我尝试用以下代码更新conv_spikes_tf导致错误,我尝试浏览scatter_nd_update的手册,但我认为我不太了解。

update = tf.scatter_nd_update(conv_spikes_tf, output, np.zeros(output.shape[0]))
sess.run(update)

ValueError: The inner 1 dimensions of input.shape=[5,5,3] must match the inner 1 dimensions of updates.shape=[21,2]: Dimension 0 in both shapes must be equal, but are 3 and 2. Shapes are [3] and [2]. for 'ScatterNdUpdate_8' (op: 'ScatterNdUpdate') with input shapes: [5,5,3], [21,2], [21,2].

我不理解错误消息,特别是inner 1 dimensions的含义是什么,如何使用张量流实现上述numpy功能?

1 个答案:

答案 0 :(得分:1)

updatestf.scatter_nd_update的最后一个暗度应为3,等于ref的最后一个暗度。

update = tf.scatter_nd_update(conv_spikes_tf, output, np.zeros(output.shape[0], 3))

如果我的理解正确,您想将SpikesPerNeuronAllowed2(掩码)应用于conv_spikes。一种更简单的方法是将conv_spikes重塑为(3,5,5)并乘以SpikesPerNeuronAllowed2

我使用一个常量示例来显示结果。您也可以将其更改为tf.Variable

conv = (np.random.rand(3,5,5)>=0.5).astype(np.int32)
tmp = np.random.choice([0, 1], size=(25,), p=[3./4, 1./4])
mask = tmp.reshape(5,5)
# array([[[1, 1, 0, 0, 0],
#         [0, 1, 0, 0, 1],
#         [0, 1, 0, 0, 1],
#         [1, 0, 0, 0, 1],
#         [1, 0, 0, 1, 0]],

#        [[1, 0, 0, 0, 1],
#         [1, 0, 1, 1, 1],
#         [0, 0, 1, 0, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]],

#        [[0, 0, 0, 1, 0],
#         [0, 1, 1, 0, 1],
#         [0, 1, 1, 0, 1],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 0, 1]]], dtype=int32)

# array([[0, 0, 0, 1, 1],
#        [0, 0, 0, 1, 0],
#        [0, 0, 0, 0, 0],
#        [0, 1, 0, 1, 0],
#        [0, 0, 1, 0, 1]])
tf_conv = tf.constant(conv, dtype=tf.int32)
tf_mask = tf.constant(mask, dtype=tf.int32)
res = tf_conv * tf_mask
sess = tf.InteractiveSession()
sess.run(res)
# array([[[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]],

#        [[0, 0, 0, 0, 1],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 1]],

#        [[0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 1, 0, 1, 0],
#         [0, 0, 1, 0, 1]]], dtype=int32)