索引变量范围为numpy

时间:2016-09-09 22:21:42

标签: numpy tensorflow

我有一个形状为A的numpy零矩阵(2, 5)

A = [[ 0.,  0.,  0.,  0.,  0.],
    [ 0.,  0.,  0.,  0.,  0.]]

我有另一个大小为seq的数组2。这与A的第一个轴相同。

seq = [2, 3]

我想创建另一个矩阵B,如下所示:

B = [[ 1.,  1.,  0.,  0.,  0.],
    [ 1.,  1.,  1.,  0.,  0.]]

B是通过使用seq[i]更改ith A行中的第一个1元素构建的。

这是一个玩具的例子。 Aseq可能很大,因此需要效率。 如果有人知道如何在tensorflow中执行此操作,我会特别感激。

2 个答案:

答案 0 :(得分:2)

您可以在TensorFlow(以及NumPy中的一些类似代码)中执行此操作,如下所示:

seq = [2, 3]

b = tf.expand_dims(tf.range(5), 0)   # A 1 x 5 matrix.
seq_matrix = tf.expand_dims(seq, 1)  # A 2 x 1 matrix.
b_bool = tf.greater(seq_matrix, b)   # A 2 x 5 bool matrix.

B = tf.to_int32(b_bool)              # A 2 x 5 int matrix.

示例输出:

In [7]: b = tf.expand_dims(tf.range(5), 0)
        [[0 1 2 3 4]]

In [21]: b_bool = tf.greater(seq_matrix, b)
In [22]: op = sess.run(b_bool)
In [23]: print(op)
[[ True  True False False False]
 [ True  True  True False False]]

In [24]: bint = tf.to_int32(b_bool)
In [25]: op = sess.run(bint)
In [26]: print(op)
[[1 1 0 0 0]
 [1 1 1 0 0]]

答案 1 :(得分:1)

@mrry's解决方案,表达方式略有不同

In [667]: [[2],[3]]>np.arange(5)
Out[667]: 
array([[ True,  True, False, False, False],
       [ True,  True,  True, False, False]], dtype=bool)
In [668]: ([[2],[3]]>np.arange(5)).astype(int)
Out[668]: 
array([[1, 1, 0, 0, 0],
       [1, 1, 1, 0, 0]])

这个想法是在[外部]广播意义上比较[2,3]和[0,1,2,3,4]。结果是boolean,可以很容易地改为0/1整数。

另一种方法是使用cumsum(或另一个ufunc.accumulate函数):

In [669]: A=np.zeros((2,5))
In [670]: A[range(2),[2,3]]=1
In [671]: A
Out[671]: 
array([[ 0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.]])
In [672]: A.cumsum(axis=1)
Out[672]: 
array([[ 0.,  0.,  1.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.]])
In [673]: 1-A.cumsum(axis=1)
Out[673]: 
array([[ 1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  0.,  0.]])

或以1's开头的变体:

In [681]: A=np.ones((2,5))
In [682]: A[range(2),[2,3]]=0
In [683]: A
Out[683]: 
array([[ 1.,  1.,  0.,  1.,  1.],
       [ 1.,  1.,  1.,  0.,  1.]])
In [684]: np.minimum.accumulate(A,axis=1)
Out[684]: 
array([[ 1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  0.,  0.]])