我刚刚开始玩Theano,我对这段代码的结果感到惊讶。
from theano import *
import theano.tensor as T
a = T.vector()
out = a + a ** 10
f = function([a], out)
print(f([0, 1, 2]))
使用python3我得到:
array([ 0., 2., 1026.])
数组本身是正确的,它包含正确的值,但打印输出是奇数。我希望这样的事情:
array([0, 2, 1026])
或
array([0.0, 2.0, 1026.0])
为什么会这样?什么是额外的白色空间?我要关注吗?
答案 0 :(得分:1)
您要打印的是numpy.ndarray
。默认情况下,它们在打印时会像这样格式化。
输出数组是一个浮点数组,因为默认情况下,Theano使用浮点数张量。
如果要使用整数张量,则需要指定dtype
:
a = T.vector(dtype='int64')
或者使用一些语法糖:
a = T.lvector()
将您的输出与以下输出进行比较:
print numpy.array([0, 2, 1026], dtype=numpy.float64)
print numpy.array([0, 2, 1026], dtype=numpy.int64)
您可以使用numpy.set_printoptions
更改numpy的默认打印选项。