使用==比较numpy数组的规则是什么?

时间:2016-02-14 20:47:57

标签: python numpy numpy-broadcasting

例如,尝试理解这些结果:

>>> x
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> (x == np.array([[1],[2]])).astype(np.float32)
array([[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)
>>> (x == np.array([1,2]))
   False
>>> (x == np.array([[1]])).astype(np.float32)
array([[ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], dtype=float32)
>>> (x == np.array([1])).astype(np.float32)
array([ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.], dtype=float32)

>>> (x == np.array([[1,3],[2]]))
False
>>> 

这里发生了什么?在[1]的情况下,它将1与x的每个元素进行比较,并将结果聚合在一个数组中。在[[1]]的情况下,同样的事情。只需在repl上进行试验,就可以很容易地弄清楚特定阵列形状会发生什么。但是双方可以拥有任意形状的基本规则是什么?

3 个答案:

答案 0 :(得分:4)

NumPy尝试在比较之前将两个数组广播为兼容的形状。 如果广播失败,则返回False。 In the future

  

等于运算符#include <stdio.h> #include <stdlib.h> int main(int argc, char *argv[]) { int num, pos, count; FILE *fp; char (*array)[4096]; /* pointer to an array of buffers */ if (argc < 2) { printf("Usage headtail filename [number]\n"); return 1; } fp = fopen(argv[1], "r"); if (fp == NULL) { printf("Cannot open file %s\n", argv[1]); return 1; } if (argc > 2) { /* get the number from the command line if 2 args were given */ if (sscanf(argv[2], "%d", &num) != 1) { num = -1; } } else { /* otherwise read from standard input */ if (scanf("%d", &num) != 1) { num = -1; } } if (num < 0) { printf("Invalid number\n"); /* negative or non numeric */ return 1; } /* allocate space for num+1 buffers */ array = malloc(4096 * (num + 1)); for (count = pos = 0; fgets(array[pos], 4096, fp) != NULL; count++) { /* printing the first num lines */ if (count < num) fputs(array[pos], stdout); /* cycle buffers for num lines + 1 extra buffer */ if (++pos >= num + 1) pos = 0; } if (count > num) { /* more lines to print */ pos = count - num; if (pos > num) { /* print place holder for missing lines */ printf("...\n"); } else { /* print from the last line printed */ pos = num; } for (; pos < count; pos++) { fputs(array[pos % (num + 1)], stdout); } } fclose(fp); return 0; } 将来会引发错误   np.equal如果广播或元素比较等失败。

否则,返回由逐个元素比较产生的布尔数组。例如,由于==x是可广播的,因此返回一个形状数组(10,):

np.array([1])

由于In [49]: np.broadcast(x, np.array([1])).shape Out[49]: (10,) x无法播放,np.array([[1,3],[2]])会返回False

x == np.array([[1,3],[2]])

答案 1 :(得分:3)

令你困惑的是:

  1. 正在进行一些broadcasting

  2. 您似乎有一个旧版本的numpy。

  3. x == np.array([[1],[2]])
    

    是广播。它将x与第一个和第二个数组中的每一个进行比较;因为它们是标量,广播暗示它将x的每个元素与每个标量进行比较。

    然而,每个

    x == np.array([1,2])
    

    x == np.array([[1,3],[2]])
    

    无法播放。通过我,numpy 1.10.4,这给出了

    /usr/local/bin/ipython:1: DeprecationWarning: elementwise == comparison failed; this will raise an error in the future.
    #!/usr/bin/python
    False
    

答案 2 :(得分:0)

添加到unutbu的答案中,数组不必具有相同的维数。例如,尺寸为1的尺寸会拉伸以匹配其他尺寸。

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5

A      (2d array):  5 x 4
B      (1d array):      1
Result (2d array):  5 x 4

A      (2d array):  5 x 4
B      (1d array):      4
Result (2d array):  5 x 4

A      (3d array):  15 x 3 x 5
B      (3d array):  15 x 1 x 5
Result (3d array):  15 x 3 x 5

A      (3d array):  15 x 3 x 5
B      (2d array):       3 x 5
Result (3d array):  15 x 3 x 5

A      (3d array):  15 x 3 x 5
B      (2d array):       3 x 1
Result (3d array):  15 x 3 x 5