numpy - 矩阵多个3x3和100x100x3阵列?

时间:2017-12-05 05:41:58

标签: python numpy matrix-multiplication

我有以下内容:

import numpy as np

XYZ_to_sRGB_mat_D50 = np.asarray([
    [3.1338561, -1.6168667, -0.4906146],
    [-0.9787684, 1.9161415, 0.0334540],
    [0.0719453, -0.2289914, 1.4052427],
])

XYZ_1 = np.asarray([0.25, 0.4, 0.1])
XYZ_2 = np.random.rand(100,100,3)

np.matmul(XYZ_to_sRGB_mat_D50, XYZ_1) # valid operation
np.matmul(XYZ_to_sRGB_mat_D50, XYZ_2) # makes no sense mathematically

如何在XYZ_2上执行与在XYZ_2上相同的操作?我是否先以某种方式重塑阵列?

2 个答案:

答案 0 :(得分:2)

您似乎正在尝试sum-reduce XYZ_to_sRGB_mat_D50 (axis=1)的最后一个轴XYZ_2 (axis=2)的最后一个轴np.tensordot(XYZ_2, XYZ_to_sRGB_mat_D50, axes=((2),(1))) 。所以,你可以这样使用np.tensordot -

np.matmul

Related post to understand tensordot

为了完整性,我们在转换XYZ_2的最后两个轴之后肯定也会使用np.matmul(XYZ_to_sRGB_mat_D50, XYZ_2.swapaxes(1,2)).swapaxes(1,2) ,就像这样 -

tensordot

这不会像In [158]: XYZ_to_sRGB_mat_D50 = np.asarray([ ...: [3.1338561, -1.6168667, -0.4906146], ...: [-0.9787684, 1.9161415, 0.0334540], ...: [0.0719453, -0.2289914, 1.4052427], ...: ]) ...: ...: XYZ_1 = np.asarray([0.25, 0.4, 0.1]) ...: XYZ_2 = np.random.rand(100,100,3) # @Julien's soln In [159]: %timeit XYZ_2.dot(XYZ_to_sRGB_mat_D50.T) 1000 loops, best of 3: 450 µs per loop In [160]: %timeit np.tensordot(XYZ_2, XYZ_to_sRGB_mat_D50, axes=((2),(1))) 10000 loops, best of 3: 73.1 µs per loop 一样高效。

运行时测试 -

sum-reductions

一般来说,在张量上tensordot时,sum-reduction效率更高。由于2D的轴只有一个,我们可以通过重新整形使张量成为np.dot数组,使用3D,得到结果并重新形成回package com.testing; import java.io.BufferedReader; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; public class JunkEx { public static void main(String[] args) { String filePath = ".\\src\\test\\resources\\text-files\\orders\\orders-2017.txt"; String contents = fileToString(filePath); System.out.println(contents); } private static String fileToString(String filePath) { StringBuilder stringBuilder = null; BufferedReader br = null; try { br = new BufferedReader(new FileReader(filePath)); stringBuilder = new StringBuilder(); String currentLine; while ((currentLine = br.readLine()) != null) { stringBuilder.append(currentLine); stringBuilder.append("\n"); } }catch (FileNotFoundException ex1) { ex1.printStackTrace(); }catch (IOException ex2) { ex2.printStackTrace(); }finally { try { br.close(); } catch (IOException e) { e.printStackTrace(); } } return stringBuilder.toString(); } } 。< / p>

答案 1 :(得分:1)

你可能只是想要这个:

XYZ_2.dot(XYZ_to_sRGB_mat_D50.T)

np.matmul(XYZ_to_sRGB_mat_D50, XYZ_1)相当于XYZ_1.dot(XYZ_to_sRGB_mat_D50.T),您可以简单地将操作广播到100 x 100 x 3矩阵。