使用Lasagne / Theano获取输出分类
我正在将我的代码从纯Theano迁移到Lasagne。 我从教程中获得了这些特定代码,以获得具有特定数据的预测结果,并且我将生成一个csv文件以发送给kaggle。 但是对于千层面,它不起作用。 我尝试过几件事,但都会出错。
如果有人能帮我弄清楚什么是错的,我很乐意!
我在这里粘贴了整个代码: http://pastebin.com/e7ry3280
test_data = np.loadtxt("../inputData/test.csv", dtype=np.uint8, delimiter=',', skiprows=1)
# The inputs are vectors now, we reshape them to monochrome 2D images,
# following the shape convention: (examples, channels, rows, columns)
data = data.reshape(-1, 1, 28, 28)
test_data = test_data.reshape(-1, 1, 28, 28)
index = T.lscalar() # index to a [mini]batch
preds = []
for it in range(len(test_data)):
test_data = test_data[it]
N = len(test_data)
# print "N : ", N
test_data = theano.shared(np.asarray(test_data, dtype=theano.config.floatX))
test_labels = T.cast(theano.shared(np.asarray(np.zeros(batch_size), dtype=theano.config.floatX)),'uint8')
###target_var
#y = T.ivector('y') # the labels are presented as 1D vector of [int] labels
#index = T.lscalar() # index to a [mini]batch
ppm = theano.function([index],lasagne.layers.get_output(network, deterministic=True),
givens={
input_var: test_data[index * batch_size: (index + 1) * batch_size],
target_var: test_labels
}, on_unused_input='warn')
p = [ppm(ii) for ii in range(N // batch_size)]
p = np.array(p).reshape((N, 10))
print (p)
p = np.argmax(p, axis=1)
p = p.astype(int)
preds.append(p)
subm = np.empty((len(preds), 2))
subm[:, 0] = np.arange(1, len(preds) + 1)
subm[:, 1] = preds
np.savetxt('submission.csv', subm, fmt='%d', delimiter=',',header='ImageId,Label', comments='')
return preds
代码在以ppm = theano.function...
开头的行上失败:
TypeError:无法将Type TensorType(float32,3D)(Variable Subtensor {int64:int64:}。0)转换为Type TensorType(float32,4D)。您可以尝试手动将Subtensor {int64:int64:}。0转换为TensorType(float32,4D)。
我只是想将测试数据输入CNN并将结果输入CSV文件。我该怎么做?我知道我必须使用minibatches,因为整个测试数据都不适合GPU。
答案 0 :(得分:2)
正如评论中的错误消息和Daniel Renshaw所指出的,问题是test_data
和input_var
之间的维度不匹配。在循环的第一行,你写:
test_data = test_data[it]
将4D数组test_data
转换为具有相同名称的3D数组(这就是为什么不建议对不同类型使用相同的变量名的原因:))。之后,您将其封装在一个不会更改维度的共享变量中,然后对其进行切片以将其分配给input_var
,这也不会更改维度。
如果我理解您的代码,我认为您应该删除第一行。那样test_data
仍然是一个示例列表,您可以对其进行切片以进行批处理。