Python中对象数组的矩阵乘法

时间:2018-03-20 14:23:41

标签: python arrays numpy multiplication

我想知道如何使用dtype=object数组在numpy中支持矩阵乘法。我有同形加密的数字封装在类Ciphertext中,我已经覆盖了__add____mul__等基本数学运算符。

我创建了numpy数组,其中每个条目都是我的类Ciphertext的一个实例,numpy理解如何广播加法和乘法运算。

    encryptedInput = builder.encrypt_as_array(np.array([6,7])) # type(encryptedInput) is <class 'numpy.ndarray'>
    encryptedOutput = encryptedInput + encryptedInput
    builder.decrypt(encryptedOutput)                           # Result: np.array([12,14])

然而,numpy不让我做矩阵乘法

out = encryptedInput @ encryptedInput # TypeError: Object arrays are not currently supported

我不太明白为什么会出现这种情况,因为加法和乘法有效。我想这与numpy无法知道对象的形状有关,因为它可能是一个列表或某种东西。

天真的解决方案:我可以编写自己的扩展ndarray的类并覆盖__matmul__操作,但我可能会失去性能,而且这种方法需要实现广播等,所以我基本上会重新发明轮子,因为它应该像现在一样工作。

问题:如何在dtype=objects的数组上使用numpy提供的标准矩阵乘法,其中对象的行为与数字完全相同?

提前谢谢!

3 个答案:

答案 0 :(得分:1)

无论出于什么原因,matmul都不起作用,但是tensordot功能按预期工作。

encryptedInput = builder.encrypt_as_array(np.array([6,7]))
out = np.tensordot(encryptedInput, encryptedInput, axes=([1,0])) 
    # Correct Result: [[ 92. 105.]
    #                  [120. 137.]]

现在调整轴只是一件麻烦事。我仍然想知道这是否比使用for循环的天真实现更快。

答案 1 :(得分:1)

tensordot有一个使用object dtype和字符串连接的扩展示例。它实际上是使用np.dot

In [89]: np.dot(np.array([['a'],['b']],object),np.array([[2,3]]))
Out[89]: 
array([['aa', 'aaa'],
       ['bb', 'bbb']], dtype=object)

这个例子很小,但确实表明object版本的路线较慢(比同等数字版本):

In [98]: timeit np.dot(np.array([[1],[2]]),np.array([[2,3]]))
7.3 µs ± 20.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [99]: timeit np.dot(np.array([[1],[2]],object),np.array([[2,3]]))
12 µs ± 121 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
编译了

np.dot代码,因此验证差异需要做更多的工作。

对于1和2d数组,np.dotnp.matmul一样好。引入matmul以方便@运算符,并将其扩展为3d及更高版本。以前只能使用einsum或在上部维度上进行迭代来实现3d +行为。

对于2个3d阵列,

matmul是有效的:

 for i in range(a.shape[0]):
     data[i,:,:] = a[i,:,:].dot(b[i,:,:])

答案 2 :(得分:0)

您可以使用ndarray.dot方法,即使np.object运算符失败,该方法显然也适用于@ dtypes:

out = encryptedInput.dot(encryptedInput)