Theano中的符号矩阵幂系列

时间:2016-10-31 15:18:53

标签: python linear-algebra theano

假设我有一个2d张量A。我想象征性地计算ApowA的幂级数,这是一个3d张量定义如下:

Apow = [I, A, A^2, A^3, ..., A^k]

其中A^2表示A.dot(A)(即幂系列是针对点积而不是元素定义的)。 k是一个符号标量,用于指定系列的长度。

我如何在Theano中实现这一点?似乎解决方案将基于scan,但我无法让它工作。

有什么想法吗?

1 个答案:

答案 0 :(得分:1)

让我将答案分成 numpy 实现和 theano 实现:

使用numpy:

def kpow(A, k):
     if k == 0:
         return np.identity(A.shape[0])
     if k == 1:
        return A
     else:
        return np.dot(A, kpow(A, k-1))

然后你就可以得到你的Apow

k = 5
A = np.array([[1, 1, 1],[2, 2, 2],[3, 3, 3]])
Apow = [kpow(A,i) for i in range(k)]

当然,您可以通过实际累积列表来提高效率。需要注意的重要事项是重复,我们如何使用先前的结果来计算下一个结果。

使用theano:

首先,让我们为k和矩阵M定义两个符号变量:

k = T.iscalar('k')
M = T.dmatrix('M')

接下来,让我们定义一个递归函数:

def recurrence(M, prev_result):
    return prev_result * M

最后,是扫描功能的时间:

result, updates = theano.scan(fn=recurrence,
                              outputs_info=T.identity_like(M),
                              non_sequences=M,
                              n_steps=k)

现在让我们得到一些结果:

A = np.array([[1, 1, 1],[2, 2, 2],[3, 3, 3]], dtype='int32')
kpow_theano = theano.function(inputs=[M,k], outputs=result)
Apow = [kpow_theano(A,10)[i] for i in range(10)]

我不确定如何使用theano在前面获得单位矩阵。我想你可以将它添加到列表中。