在3D theano张量上广播linalg.pinv

时间:2015-12-04 14:42:26

标签: python linear-algebra theano matrix-inverse

在下面的例子中,有一个3d numpy矩阵的大小(4,3,3)+关于如何计算numpy中3个3 * 3矩阵中的4个的pinv的解决方案。我也尝试使用在numpy中工作的相同功能,在theano希望它实现相同,但它失败了。知道如何在theano中做到这一点吗?

dt = np.dtype(np.float32)

a=[[[12,3,1],
   [2,4,1],
   [2,4,2],],
   [[12,3,3],
   [2,4,4],
   [2,4,5],],
   [[12,3,6],
   [2,4,5],
   [2,4,4],],
   [[12,3,3],
   [2,4,5],
   [2,4,6]]]

a=np.asarray(a,dtype=dt)
print(a.shape)

apinv=np.zeros((4,3,3))
print(np.linalg.pinv(a[0,:,:]).shape)

#numpy solution
apinv = map(lambda n: np.linalg.pinv(n), a)
apinv = np.asarray(apinv,dtype=dt)

#theano solution (not working)
at=T.tensor3('a')
apinvt = map(lambda n: T.nlinalg.pinv(n), at)

错误是:

Original exception was:
Traceback (most recent call last):
  File "pydevd.py", line 2403, in <module>
    globals = debugger.run(setup['file'], None, None, is_module)
  File "pydevd.py", line 1794, in run
    launch(file, globals, locals)  # execute the script
  File "exp_thn_pinv_map.py", line 35, in <module>
    apinvt = map(lambda n: T.nlinalg.pinv(n), at)
  File "theano/tensor/var.py", line 549, in __iter__
    raise TypeError(('TensorType does not support iteration. '
TypeError: TensorType does not support iteration. Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)

1 个答案:

答案 0 :(得分:1)

错误消息是

Traceback (most recent call last):
  File "D:/Dropbox/source/intro_theano/pinv.py", line 32, in <module>
    apinvt = map(lambda n: T.nlinalg.pinv(n), at)
  File "d:\dropbox\source\theano\theano\tensor\var.py", line 549, in __iter__
    raise TypeError(('TensorType does not support iteration. '
TypeError: TensorType does not support iteration. Maybe you are using builtin.sum instead of theano.tensor.sum? (Maybe .max?)

这是因为,如错误消息所示,符号变量at不可迭代。

这里的根本问题是你错误地将立即执行的Python代码与延迟执行的Theano符号代码混合。

您需要使用符号循环,而不是Python循环。正确的解决方案是使用Theano的scan运算符:

at=T.tensor3('a')
apinvt, _ = theano.scan(lambda n: T.nlinalg.pinv(n), at, strict=True)
f = theano.function([at], apinvt)
print np.allclose(f(a), apinv)