带有Theano / pymc的TypeError(比较操作)

时间:2017-07-20 10:00:44

标签: numpy theano pymc

我运行的问题与最初在https://github.com/pymc-devs/pymc3/issues/1209报告的问题相同,即在Theano对象和numpy数组之间进行比较时出现TypeError。我的代码是使用运算符左侧的Theano对象编写的,并使用Numpy 1.13.1

调查我测试了一下

import pymc3
with pymc3.Model() as model:
    a = pymc3.Uniform("a", 1,2)
    print 1 < a

输出,没有抱怨:

Elemwise{gt,no_inplace}.0

正在运行

with pymc3.Model() as model:
    a = pymc3.Uniform("a", 1,2)
    if 1 < a:
        print "bingo"

产生TypeError

/usr/local/lib/python2.7/dist-packages/theano/tensor/var.pyc in __nonzero__(self)
     73     def __nonzero__(self):
     74         # Python 2.x
---> 75         return self.__bool__()
     76 
     77     def __bool__(self):
/usr/local/lib/python2.7/dist-packages/theano/tensor/var.pyc in __bool__(self)
     89         else:
     90             raise TypeError(
---> 91                 "Variables do not support boolean operations."
     92             )
     93 

TypeError: Variables do not support boolean operations.

所以我的问题是,我应该怎么做这种类型的测试?我想保留我的代码通用,因为在大多数情况下它不会在Theano对象上运行(当然我可以在pymc3 / Theano上下文中使用此函数的一个版本)。值得冒犯的代码是

......./refsans_tools/abeles/abeles.py in guess_optimal_x(self, thickness, roughness)
   1303                                                safety=self.safety
   1304                                               )
-> 1305         if this_xmin < self._xmin:
   1306             self._xmin = this_xmin
   1307             self._xmin = - self.shift_orig
/usr/local/lib/python2.7/dist-packages/theano/tensor/var.pyc in nonzero(self)
     73     def nonzero(self):
     74         # Python 2.x
---> 75         return self.bool()
     76 
     77     def bool(self):
/usr/local/lib/python2.7/dist-packages/theano/tensor/var.pyc in bool(self)
     89         else:
     90             raise TypeError(
---> 91                 "Variables do not support boolean operations."
     92             )
     93 
TypeError: Variables do not support boolean operations.

1 个答案:

答案 0 :(得分:0)

尝试一下:

with pymc3.Model() as model:
a = pymc3.Uniform("a", 1,2)
if tt.lt(1, a):
    print "bingo"