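I'm trying to train a VGG16 model (from keras.applications) from scratch, and I'm getting a strange error.
X_train shape is (73257, 48, 48, 3)
Y_train shape is (73257, 10)
I don't know what's going on... I think it has something to do with the last conv layer, but since I'm importing the model directly from Keras, I'm stuck on how to fix it.
My dataset consists of 73,257 images of shape (48, 48, 3). I'm essentially trying to do character recognition (think MNIST style), but I'm having trouble feeding the data through the model (with weights set to None).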
import keras
from keras.optimizers import SGD

# VGG16 without pretrained weights; include_top=False drops the classifier head
model = keras.applications.vgg16.VGG16(include_top=False,
                                       weights=None,
                                       input_shape=(48, 48, 3),
                                       input_tensor=None, pooling='avg', classes=10)
sgd = SGD(lr=.1)
model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])
print('X_train shape is {}'.format(X_train.shape))
print('Y_train shape is {}'.format(y_train.shape))
model.fit(X_train, y_train,
          epochs=20,
          batch_size=128)
score = model.evaluate(X_test, y_test, batch_size=128)
This is the error I get:
  File "/home/codebrotherone/PycharmProjects/Computer Vision/deep_neural/dnn.py", line 169, in VGG16
    batch_size=128)
  File "/home/codebrotherone/anaconda2/envs/tensorflow/lib/python3.4/site-packages/keras/engine/training.py", line 1574, in fit
    batch_size=batch_size)
  File "/home/codebrotherone/anaconda2/envs/tensorflow/lib/python3.4/site-packages/keras/engine/training.py", line 1411, in _standardize_user_data
    exception_prefix='target')
  File "/home/codebrotherone/anaconda2/envs/tensorflow/lib/python3.4/site-packages/keras/engine/training.py", line 141, in _standardize_input_data
    str(array.shape))
ValueError: Error when checking target: expected block5_pool to have 4 dimensions, but got array with shape (73257, 10)
Answer 0 (score: 1)
Ideally, for a classification problem, you should use classes=10
and include_top=True.
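That would be enough. Since you are not including the top and are using global pooling, you should get an output shape like (73257, 512). But the message you are getting suggests that the pooling is not being applied in this attempt. Something doesn't quite match.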
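To make the shape mismatch concrete, here is a small sketch that traces the output shape of VGG16's convolutional base by hand. It assumes the standard VGG16 layout (five blocks, each ending in a 2x2 max-pool with stride 2, with 512 channels in the last block); the helper name `vgg16_feature_shape` is made up for illustration and is not part of Keras.

```python
def vgg16_feature_shape(height, width, pooling=None):
    """Trace the output shape of VGG16 with include_top=False.

    Hypothetical helper for illustration only; it mirrors the standard
    VGG16 layout rather than calling Keras.
    """
    block_channels = [64, 128, 256, 512, 512]
    for _ in block_channels:
        # The 3x3 'same' convolutions keep height and width unchanged;
        # the 2x2 max-pool at the end of each block halves them (rounding down).
        height //= 2
        width //= 2
    channels = block_channels[-1]
    if pooling in ("avg", "max"):
        # Global pooling collapses the spatial dimensions.
        return (None, channels)
    return (None, height, width, channels)


print(vgg16_feature_shape(48, 48))                 # (None, 1, 1, 512)
print(vgg16_feature_shape(48, 48, pooling="avg"))  # (None, 512)
```

Neither shape matches targets of shape (73257, 10): without pooling the output is 4-dimensional, which is what the error about block5_pool complains about, and with global pooling it is (batch, 512) rather than (batch, 10).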
Either way, go ahead with:
classes=10
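Putting the suggestion together, a minimal sketch of the model definition might look like the following. It assumes a Keras version whose VGG16 accepts a 48x48 input when include_top=True and weights=None, and it passes the optimizer as the string 'sgd' (default learning rate) instead of the question's SGD(lr=.1) to avoid depending on a version-specific keyword; both choices are assumptions, not part of the original answer.

```python
from keras.applications.vgg16 import VGG16

# include_top=True keeps the fully connected head, so the final layer
# is a softmax over `classes` outputs instead of a pooled feature map.
model = VGG16(include_top=True,
              weights=None,            # train from scratch
              input_shape=(48, 48, 3),
              classes=10)

# Assumption: plain 'sgd' with its default learning rate.
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

print(model.output_shape)  # (None, 10)
```

With this output shape, one-hot targets of shape (73257, 10) line up with the model's predictions, so model.fit(X_train, y_train, ...) should no longer fail the target shape check.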