神经网络输入形状误差

时间:2017-06-12 22:00:13

标签: python-2.7 neural-network keras

我是keras的初学者,我正在尝试使用神经网络对数据进行分类。

#include <string>
using std::string;
#include <stdexcept>
using std:: invalid_argument;

struct Pet {
  const string name;
  long age = 0;
  const string species;

  //Pet()= default;
  Pet(): name("CrashDown"),age(0),species("ferret") {};
  Pet(const string & the_name, const string & the_species): name(the_name),   age(0),
  species(the_species) {};
 };

我这个脚本总是有这个错误,我不明白为什么:

   x_train = x_train.reshape(1,x_train.shape[0],window,5)
   x_val = x_val.reshape(1,x_val.shape[0],window,5)

   x_train = x_train.astype('float32')
   x_val = x_val.astype('float32')

   model = Sequential()

   model.add(Dense(64,activation='relu',input_shape= (data_dim,window,5)))
   model.add(Dropout(0.5))

   model.add(Dense(64,activation='relu'))
   model.add(Dropout(0.5))
   model.add(Dense(2,activation='softmax'))

   model.compile(loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])

   weights = model.get_weights()


   model_info = model.fit(x_train, y_train,batch_size=batchsize, nb_epoch=15,verbose=1,validation_data=(x_val, y_val))

  print x_train.shape
  #(1,1600,45,5)

  print y_train.shape
  #(1600,2)

1 个答案:

答案 0 :(得分:2)

您的模型的输出(dense_3,因为它是第三个Dense图层而命名)具有四个维度。但是,您尝试将其与(y_train)进行比较的标签只有两个维度。您需要更改网络的体系结构,以便模型重塑数据以匹配标签。

刚开始时难以跟踪张量形状,因此我建议在致电plot_model(model, to_file='model.png', show_shapes=True)之前致电model.fit。您可以查看生成的PNG,以了解图层对数据形状的影响。