使用Keras VGG16的工作 - 故障输出形状

时间:2017-11-28 22:29:18

标签: python-3.x machine-learning computer-vision keras keras-layer

我正试图从头开始训练VGG16模型(来自keras.applications),我得到一个奇怪的错误

X_train形状是(73257,48,48,3)

Y_train形状是(73257,10)

我不知道发生了什么......我认为它与之前的转换层有关,但由于我直接从keras导入模型,因此我遇到了解决问题的方法。

我的数据集由73,257张图像组成,这些图像的形状为(48,48,3)。我本质上是在尝试进行字符识别(想想mnist风格),但是我正忙着通过模型(重量设置为0)来喂它。

model = keras.applications.vgg16.VGG16(include_top=False,
                                       weights=None,
                                       input_shape=(48, 48, 3),
                                       input_tensor=None, pooling='avg', classes=10)
sgd = SGD(lr=.1)

model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])
print('X_train shape is {}'.format(X_train.shape))
print('Y_train shape is {}'.format(y_train.shape))

model.fit(X_train, y_train,
          epochs=20,
          batch_size=128)

score = model.evaluate(X_test, y_test, batch_size=128)

这是我得到的错误

  

文件“/ home / codebrotherone / PycharmProjects / Computer Vision / deep_neural / dnn.py”,第169行,在VGG16中       batch_size = 128)

     

文件“/home/codebrotherone/anaconda2/envs/tensorflow/lib/python3.4/site-packages/keras/engine/training.py”,第1574行,in fit       batch_size = batch_size)

     

文件“/home/codebrotherone/anaconda2/envs/tensorflow/lib/python3.4/site-packages/keras/engine/training.py”,第1411行,_standardize_user_data       exception_prefix ='target')

     

文件“/home/codebrotherone/anaconda2/envs/tensorflow/lib/python3.4/site-packages/keras/engine/training.py”,第141行,_standardize_input_data       str(array.shape))

     

ValueError:检查目标时出错:预期block5_pool有4个维度,但是有阵列形状(73257,10)

1 个答案:

答案 0 :(得分:1)

理想情况下,对于分类问题,您应该<!-- End Head --> <!-- Begin `Daily' Graph (5 Minute --><div class="graph"> <h2>`Daily' Graph (5 Minute Average)</h2> <img src="aklsr2_gi0_1-day.png" title="day" alt="day" /> <table> <tr> <th></th> <th scope="col">Max</th> <th scope="col">Average</th> <th scope="col">Current</th> </tr> <tr class="in"> <th scope="row">In</th> <td>9939.4 kb/s (99.4%)</td> <td>1908.7 kb/s (19.1%) </td> <td>80.8 kb/s (0.8%) </td> </tr> <tr class="out"> <th scope="row">Out</th> <td>9682.3 kb/s (96.8%) </td> <td>344.1 kb/s (3.4%) </td> <td>83.8 kb/s (0.8%) </td> </tr> <tr> <td colspan="8"> Average max 5 min values for `Daily' Graph (5 Minute interval): <span class="in">In</span> 2264.1 kb/s (22.6%)/ <span class="out">Out</span> 451.0 kb/s (4.5%) </td> </tr> </table> </div> <!-- End `Daily' Graph (5 Minute --> <!-- Begin `Weekly' Graph (30 Minute --> <div class="graph"> <h2>`Weekly' Graph (30 Minute Average)</h2> <img src="aklsr2_gi0_1-week.png" title="week" alt="week" /> <table> <tr> <th></th> <th scope="col">Max</th> <th scope="col">Average</th> <th scope="col">Current</th> </tr> <tr class="in"> <th scope="row">In</th> <td>9939.4 kb/s (99.4%)</td> <td>1273.3 kb/s (12.7%) </td> <td>98.8 kb/s (1.0%) </td> </tr> <tr class="out"> <th scope="row">Out</th> <td>9775.1 kb/s (97.8%) </td> <td>249.9 kb/s (2.5%) </td> <td>61.6 kb/s (0.6%) </td> </tr> <tr> <td colspan="8"> Average max 5 min values for `Weekly' Graph (30 Minute interval): <span class="in">In</span> 2236.6 kb/s (22.4%)/ <span class="out">Out</span> 593.8 kb/s (5.9%) </td> </tr> </table> </div> <!-- End `Weekly' Graph (30 Minute --> include_top=True

这就够了。由于你没有包括top,并且正在使用全局池,你应该得到像(73257,512)这样的东西。但是你得到的消息表明你没有在这次尝试中使用池。有些东西不太匹配。

无论如何,请继续:

classes=10