如何从共享变量获取csv格式的预测结果

时间:2015-08-11 07:06:23

标签: theano

数据是共享变量。我想以csv格式获得预测结果。下面是代码。 它抛出一个错误。怎么修?谢谢你的帮助!

TypeError: ('Bad input argument to theano function with name "4.py:305"  at index 
0(0-based)', 'Expected an array-like object, 
but found a Variable: maybe you are trying to call a function on a (possibly shared)  
variable instead of a numeric array?')

test_model = theano.function(
    inputs=[index],
    outputs=classifier.errors(y),
    givens={
        x: test_set_x[index * batch_size:(index + 1) * batch_size],
        y: test_set_y[index * batch_size:(index + 1) * batch_size]
    }
)

def make_submission_csv(predict, is_list=False):
    if is_list:
        df = pd.DataFrame({'Id': range(1, 101), 'Label': predict})
        df.to_csv("submit.csv", index=False)
        return
    pred = []
    for i in range(100):
        pred.append(test_model(test.values[i]))
    df = pd.DataFrame({'Id': range(1, 101), 'Label': pred})
    df.to_csv("submit.csv", index=False)
make_submission_csv(np.argmax(test_model(test_set_x), axis=1), is_list=True)

有关" index"。

的更多信息
index = T.iscalar()  
x = T.matrix('x')  
y = T.ivector('y')
输入时

test_set_x.get_value(borrow=True)

控制台显示:

array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32)

输入时:

test_model(test_set_x.get_value())

它会抛出错误:

TypeError: ('Bad input argument to theano function with name "4.py:311"  at index 0(0-based)', 'TensorType(int32, scalar) cannot store a value of dtype float32 without risking loss of precision. 

1 个答案:

答案 0 :(得分:0)

您的test_model函数只有一个输入值,

inputs=[index],

您的粘贴代码并未显示变量index的创建,但我的猜测是它是带有整数类型的Theano符号标量。如果是这样,您需要使用单个整数输入调用已编译的函数,例如

test_model(1)

您正在尝试调用test_model(test_set_x),因为test_set_x(也可能是)共享变量,而不是函数所期望的整数索引,因此test_losses = [test_model(i) for i in xrange(n_test_batches)] 无法正常工作。

请注意,tutorial code执行此操作:

{{1}}