在许多点评估许多单项式

时间:2017-07-31 15:56:07

标签: python arrays numpy

以下问题涉及在许多点评估许多单项式(x**k * y**l * z**m)。

我想计算两个numpy数组的“内部能力”,即

import numpy

a = numpy.random.rand(10, 3)
b = numpy.random.rand(3, 5)

out = numpy.ones((10, 5))
for i in range(10):
    for j in range(5):
        for k in range(3):
            out[i, j] *= a[i, k]**b[k, j]

print(out.shape)

如果相反,该行将读取

out[i, j] += a[i, k]*b[j, k]

这可能是一些内部产品,可通过简单doteinsum计算。

是否可以在一条numpy线上执行上述循环?

4 个答案:

答案 0 :(得分:5)

将这些数组扩展到3D版本后,您可以使用broadcasting -

(a[:,:,None]**b[None,:,:]).prod(axis=1)

简单地说 -

(a[...,None]**b[None]).prod(1)

基本上,我们保持两个阵列的最后一个轴和第一个轴对齐,同时在两个输入的第一个轴和最后一个轴之间执行元素方向的功率。使用给定样本在形状上进行示意性放置 -

  10   x   3   x   1
   1   x   3   x   5

答案 1 :(得分:5)

用对数来思考它怎么样:

 Vector<String> attributeValues = new Vector();
    attributeValues.add("CN=nw-PPARead,OU=LDAP,OU=NSC Managed,OU=Global,OU=Groups,DC=NWIE,DC=NET");
    if (attributeValues != null) {
        System.out.println("attributeValues not null = " + attributeValues); // 4
        if (attributeValues.contains("CN=nw-PPARead,OU=LDAP,OU=NSC Managed,OU=Global,OU=Groups,DC=NWIE,DC=NET")
                || attributeValues
                        .contains("CN=nw-PPARead,OU=LDAP,OU=NSC Managed,OU=Global,OU=Groups,DC=NWIEPILOT,DC=NET")
                || attributeValues
                        .contains("CN=nw-PPAWrite,OU=LDAP,OU=NSC Managed,OU=Global,OU=Groups,DC=NWIE,DC=NET")
                || attributeValues.contains(
                        "CN=nw-PPAWrite,OU=LDAP,OU=NSC Managed,OU=Global,OU=Groups,DC=NWIEPILOT,DC=NET")) {
            // Not getting to below statement
            System.out.println("AttributeValues out of first if" + attributeValues); // 5
        }

import numpy a = numpy.random.rand(10, 3) b = numpy.random.rand(3, 5) out = np.exp(np.matmul(np.log(a), b)) 开始,然后是c_ij = prod(a_ik ** b_kj, k=1..K)

注意:log(c_ij) = sum(log(a_ik) * b_ik, k=1..K) 中的零可能搞乱了结果(也是负面的,但结果无论如何都不会很好地定义结果)。我试了一下它似乎并没有真正打破;我不知道这种行为是否由NumPy保证,但为了安全起见,你可以在最后添加一些内容,如:

a

答案 2 :(得分:2)

另外两个解决方案:

内联

numpy.array([
    numpy.prod([a[:, i]**bb[i] for i in range(len(bb))], axis=0)
    for bb in b.T
    ]).T

并使用power.outer

numpy.prod([numpy.power.outer(a[:, k], b[k]) for k in range(len(b))], axis=0)

两者都比广播解决方案慢一点。

即使有一些逻辑可以容纳零值和负值,exp - log解决方案也可以解决问题。

enter image description here

重现情节的代码:

import numpy
import perfplot


def loop(data):
    a, b = data
    m = a.shape[0]
    n = b.shape[1]
    out = numpy.ones((m, n))
    for i in range(m):
        for j in range(n):
            for k in range(3):
                out[i, j] *= a[i, k]**b[k, j]
    return out


def broadcasting(data):
    a, b = data
    return (a[..., None]**b[None]).prod(1)


def log_exp(data):
    a, b = data
    neg_a = numpy.zeros(a.shape, dtype=int)
    neg_a[a < 0.0] = 1
    odd_b = numpy.zeros(b.shape, dtype=int)
    odd_b[b % 2 == 1] = 1
    negative_count = numpy.dot(neg_a, odd_b)

    out = (-1)**negative_count * numpy.exp(
            numpy.matmul(
                numpy.log(abs(a), where=abs(a) > 0.0),
                b
                ))

    zero_a = numpy.zeros(a.shape, dtype=int)
    zero_a[a == 0.0] = 1
    pos_b = numpy.zeros(b.shape, dtype=int)
    pos_b[b > 0] = 1
    zero_count = numpy.dot(zero_a, pos_b)
    out[zero_count > 0] = 0.0
    return out


def inline(data):
    a, b = data
    return numpy.array([
        numpy.prod([a[:, i]**bb[i] for i in range(len(bb))], axis=0)
        for bb in b.T
        ]).T


def outer_power(data):
    a, b = data
    return numpy.prod([
        numpy.power.outer(a[:, k], b[k]) for k in range(len(b))
        ], axis=0)


perfplot.show(
    setup=lambda n: (
        numpy.random.rand(n, 3) - 0.5,
        numpy.random.randint(0, 10, (3, n))
        ),
    n_range=[2**k for k in range(11)],
    repeat=10,
    kernels=[
        loop,
        broadcasting,
        inline,
        log_exp,
        outer_power
        ],
    logx=True,
    logy=True,
    xlabel='len(a)',
    )

答案 3 :(得分:0)

import numpy

a = numpy.random.rand(10, 3)
b = numpy.random.rand(3, 5)

out = [[numpy.prod([a[i, k]**b[k, j] for k in range(3)]) for j in range(5)] for i in range(10)]