Python keras神经网络(Theano)包返回有关数据维度的错误

时间:2015-05-21 21:45:02

标签: python numpy canopy theano keras

我有这段代码:

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from sklearn import datasets
import theano

iris = datasets.load_iris()
X = iris.data[:,0:3]  # we only take the first two features.
Y = iris.target

X = X.astype(theano.config.floatX)
Y = Y.astype(theano.config.floatX)


model = Sequential()
model.add(Dense(150, 1, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(150, 1, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(150, 1, init='uniform'))
model.add(Activation('softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

model.fit(X, Y, nb_epoch=20, batch_size=150)


score = model.evaluate(X_train, y_train, batch_size=16)

返回此错误:

ValueError: Shape mismatch: x has 3 cols (and 150 rows) but y has 150 rows (and 1 cols)
Apply node that caused the error: Dot22(<TensorType(float64, matrix)>, <TensorType(float64, matrix)>)
Inputs types: [TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(150L, 3L), (150L, 1L)]
Inputs strides: [(24L, 8L), (8L, 8L)]
Inputs values: ['not shown', 'not shown']

有什么问题?

1 个答案:

答案 0 :(得分:7)

您为内部图层指定了错误的输出尺寸。例如,请参阅Keras文档中的此示例:

model = Sequential()
model.add(Dense(20, 64, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 64, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 2, init='uniform'))
model.add(Activation('softmax'))

注意一层的输出大小如何与下一层的输入大小匹配:

20x64 -> 64x64 -> 64x2

第一个数字始终是输入大小(前一层上的神经元数量),第二个数字是输出大小(下一层上的神经元数量)。因此,在此示例中,您有四个层:

  • 具有20个神经元的输入层
  • 具有64个神经元的隐藏层
  • 具有64个神经元的隐藏层
  • 具有2个神经元的输出层

你唯一的硬限制是第一个(输入)层需要拥有与你有特征一样多的神经元,而最后一个(输出)层需要拥有你需要的任意数量的神经元。

对于您的示例,由于您有三个功能,您需要将输入图层大小更改为3,并且您可以保持此示例中的两个输出神经元进行二进制分类(或使用一个,如您所做的那样,使用逻辑损失)。