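I'm using an AlexNet-like CNN for an image-related regression task, and I defined RMSE as the loss function. However, during the first epoch of training the loss takes huge values, and only from the second epoch on does it drop to a meaningful value. Here is the output: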
1/51 [..............................] - ETA: 847s - loss: 104.1821 - acc: 0.2500 - root_mean_squared_error: 104.1821
2/51 [>.............................] - ETA: 470s - loss: 5277326.0910 - acc: 0.5938 - root_mean_squared_error: 5277326.0910
3/51 [>.............................] - ETA: 345s - loss: 3518246.7337 - acc: 0.5000 - root_mean_squared_error: 3518246.7337
4/51 [=>............................] - ETA: 281s - loss: 2640801.3379 - acc: 0.6094 - root_mean_squared_error: 2640801.3379
5/51 [=>............................] - ETA: 241s - loss: 2112661.3062 - acc: 0.5000 - root_mean_squared_error: 2112661.3062
6/51 [==>...........................] - ETA: 214s - loss: 1760566.4758 - acc: 0.4375 - root_mean_squared_error: 1760566.4758
7/51 [===>..........................] - ETA: 194s - loss: 1509067.6495 - acc: 0.4464 - root_mean_squared_error: 1509067.6495
8/51 [===>..........................] - ETA: 178s - loss: 1320442.6319 - acc: 0.4570 - root_mean_squared_error: 1320442.6319
9/51 [====>.........................] - ETA: 165s - loss: 1173734.9212 - acc: 0.4792 - root_mean_squared_error: 1173734.9212
10/51 [====>.........................] - ETA: 155s - loss: 1056369.3193 - acc: 0.4875 - root_mean_squared_error: 1056369.3193
11/51 [=====>........................] - ETA: 146s - loss: 960343.5998 - acc: 0.4943 - root_mean_squared_error: 960343.5998
12/51 [======>.......................] - ETA: 139s - loss: 880320.3762 - acc: 0.5052 - root_mean_squared_error: 880320.3762
13/51 [======>.......................] - ETA: 131s - loss: 812608.7112 - acc: 0.5216 - root_mean_squared_error: 812608.7112
14/51 [=======>......................] - ETA: 125s - loss: 754570.1939 - acc: 0.5402 - root_mean_squared_error: 754570.1939
15/51 [=======>......................] - ETA: 120s - loss: 704269.2443 - acc: 0.5479 - root_mean_squared_error: 704269.2443
16/51 [========>.....................] - ETA: 114s - loss: 660256.3035 - acc: 0.5508 - root_mean_squared_error: 660256.3035
17/51 [========>.....................] - ETA: 109s - loss: 621420.7248 - acc: 0.5607 - root_mean_squared_error: 621420.7248
18/51 [=========>....................] - ETA: 104s - loss: 586900.8398 - acc: 0.5712 - root_mean_squared_error: 586900.8398
19/51 [==========>...................] - ETA: 100s - loss: 556014.6719 - acc: 0.5806 - root_mean_squared_error: 556014.6719
20/51 [==========>...................] - ETA: 95s - loss: 528216.9077 - acc: 0.5875 - root_mean_squared_error: 528216.9077
21/51 [===========>..................] - ETA: 91s - loss: 503065.7743 - acc: 0.5967 - root_mean_squared_error: 503065.7743
22/51 [===========>..................] - ETA: 87s - loss: 480206.3521 - acc: 0.6094 - root_mean_squared_error: 480206.3521
23/51 [============>.................] - ETA: 83s - loss: 459331.8636 - acc: 0.6114 - root_mean_squared_error: 459331.8636
24/51 [=============>................] - ETA: 80s - loss: 440196.2991 - acc: 0.6159 - root_mean_squared_error: 440196.2991
25/51 [=============>................] - ETA: 76s - loss: 422590.8381 - acc: 0.6162 - root_mean_squared_error: 422590.8381
26/51 [==============>...............] - ETA: 73s - loss: 406339.5179 - acc: 0.6178 - root_mean_squared_error: 406339.5179
27/51 [==============>...............] - ETA: 69s - loss: 391292.6992 - acc: 0.6238 - root_mean_squared_error: 391292.6992
28/51 [===============>..............] - ETA: 66s - loss: 377319.9851 - acc: 0.6306 - root_mean_squared_error: 377319.9851
29/51 [===============>..............] - ETA: 63s - loss: 364310.7557 - acc: 0.6336 - root_mean_squared_error: 364310.7557
30/51 [================>.............] - ETA: 60s - loss: 352169.1059 - acc: 0.6385 - root_mean_squared_error: 352169.1059
31/51 [=================>............] - ETA: 57s - loss: 340810.8854 - acc: 0.6401 - root_mean_squared_error: 340810.8854
32/51 [=================>............] - ETA: 53s - loss: 330162.1334 - acc: 0.6455 - root_mean_squared_error: 330162.1334
33/51 [==================>...........] - ETA: 50s - loss: 320158.7622 - acc: 0.6553 - root_mean_squared_error: 320158.7622
34/51 [==================>...........] - ETA: 47s - loss: 310744.0080 - acc: 0.6645 - root_mean_squared_error: 310744.0080
35/51 [===================>..........] - ETA: 44s - loss: 301866.8259 - acc: 0.6714 - root_mean_squared_error: 301866.8259
36/51 [====================>.........] - ETA: 41s - loss: 293483.0129 - acc: 0.6762 - root_mean_squared_error: 293483.0129
37/51 [====================>.........] - ETA: 39s - loss: 285552.8197 - acc: 0.6757 - root_mean_squared_error: 285552.8197
38/51 [=====================>........] - ETA: 36s - loss: 278039.4488 - acc: 0.6752 - root_mean_squared_error: 278039.4488
39/51 [=====================>........] - ETA: 33s - loss: 270911.4670 - acc: 0.6795 - root_mean_squared_error: 270911.4670
40/51 [======================>.......] - ETA: 30s - loss: 264140.2391 - acc: 0.6820 - root_mean_squared_error: 264140.2391
41/51 [=======================>......] - ETA: 27s - loss: 257699.1895 - acc: 0.6852 - root_mean_squared_error: 257699.1895
42/51 [=======================>......] - ETA: 25s - loss: 251564.6846 - acc: 0.6890 - root_mean_squared_error: 251564.6846
43/51 [========================>.....] - ETA: 22s - loss: 245715.4124 - acc: 0.6933 - root_mean_squared_error: 245715.4124
44/51 [========================>.....] - ETA: 19s - loss: 240131.9916 - acc: 0.6960 - root_mean_squared_error: 240131.9916
45/51 [=========================>....] - ETA: 16s - loss: 234796.6948 - acc: 0.7007 - root_mean_squared_error: 234796.6948
46/51 [=========================>....] - ETA: 14s - loss: 229693.3717 - acc: 0.7045 - root_mean_squared_error: 229693.3717
47/51 [==========================>...] - ETA: 11s - loss: 224807.2748 - acc: 0.7055 - root_mean_squared_error: 224807.2748
48/51 [===========================>..] - ETA: 8s - loss: 220125.0731 - acc: 0.7077 - root_mean_squared_error: 220125.0731
49/51 [===========================>..] - ETA: 5s - loss: 215634.5638 - acc: 0.7117 - root_mean_squared_error: 215634.5638
50/51 [============================>.] - ETA: 3s - loss: 211323.1692 - acc: 0.7144 - root_mean_squared_error: 211323.1692
51/51 [============================>.] - ETA: 0s - loss: 207180.6328 - acc: 0.7151 - root_mean_squared_error: 207180.6328
52/51 [==============================] - 143s - loss: 203253.6237 - acc: 0.7157 - root_mean_squared_error: 203253.6237 - val_loss: 44.4203 - val_acc: 0.9878 - val_root_mean_squared_error: 44.4203
Epoch 2/128
1/51 [..............................] - ETA: 117s - loss: 52.6087 - acc: 0.7188 - root_mean_squared_error: 52.6087
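One way to read this log: as far as I know, Keras' progress bar shows the mean of each metric over the steps seen so far in the epoch, not the value of the current batch. Assuming that running mean weights every step equally, the value of step k can be recovered as k·mean_k − (k−1)·mean_(k−1). The helper `per_step_values` is a hypothetical name used only for this sketch; the numbers are the first five epoch-1 losses copied from the log above.

```python
# Running means of the loss for the first five steps of epoch 1 (copied from the log).
running = [104.1821, 5277326.0910, 3518246.7337, 2640801.3379, 2112661.3062]

def per_step_values(running_means):
    """Undo an equally weighted running mean: value_k = k*mean_k - (k-1)*mean_(k-1)."""
    values = []
    prev_total = 0.0
    for k, mean_k in enumerate(running_means, start=1):
        total = k * mean_k          # sum of the first k per-step values
        values.append(total - prev_total)
        prev_total = total
    return values

for step, value in enumerate(per_step_values(running), start=1):
    print(f"step {step}: {value:.4f}")
```

Under that assumption the individual steps come out at roughly 104, 10.55 million, 88, 8465 and 101, so the huge epoch-1 figure is driven mostly by a single batch (step 2) whose loss is then carried along in the running average for the rest of the epoch.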
How should I understand this behavior? Here is my implementation. First, the RMSE function:
from keras import backend as K

def root_mean_squared_error(y_true, y_pred):
    # Square root of the mean squared error over the last (output) axis, per sample
    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
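With `axis=-1`, this function returns one value per sample: the mean and square root are taken over the output dimensions only, and Keras then averages those per-sample values over the batch. The NumPy sketch below mirrors that reduction next to a whole-batch RMSE; the function names and the three-sample batch are hypothetical. One consequence, assuming the model has a single output unit: the square root undoes the square for each sample, so the reported "RMSE" is really the batch's mean absolute error.

```python
import numpy as np

def keras_style_loss(y_true, y_pred):
    # Same reduction as K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1)):
    # mean and square root are taken per sample, over the last axis only.
    per_sample = np.sqrt(np.mean(np.square(y_pred - y_true), axis=-1))
    # Keras then averages the per-sample losses over the batch (no sample weights assumed).
    return per_sample.mean()

def batch_rmse(y_true, y_pred):
    # RMSE of the whole batch: average all squared errors first, then take the root.
    return np.sqrt(np.mean(np.square(y_pred - y_true)))

# Hypothetical batch: three samples, one regression target each.
y_true = np.array([[10.0], [0.0], [4.0]])
y_pred = np.array([[7.0], [1.0], [4.0]])

print(keras_style_loss(y_true, y_pred))   # (3 + 1 + 0) / 3, i.e. the mean absolute error
print(batch_rmse(y_true, y_pred))         # sqrt((9 + 1 + 0) / 3)
print(np.mean(np.abs(y_pred - y_true)))   # equals the first value
```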
Then the model:
model.compile(optimizer="rmsprop", loss=root_mean_squared_error, metrics=['accuracy', root_mean_squared_error])
Then fit the model:
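import math
import time
from keras.preprocessing.image import ImageDataGenerator

estimator = alexmodel()
datagen = ImageDataGenerator()
datagen.fit(x_train)
start = time.time()
# Feed the images together with their regression targets
history = estimator.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
                                  epochs=epochs,
                                  # steps_per_epoch must be an integer; ceil keeps the last partial batch
                                  steps_per_epoch=math.ceil(x_train.shape[0] / batch_size),
                                  validation_data=(x_test, y_test))
end = time.time()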
Can anyone tell me why this happens? Is there a potential mistake?
Answer 0 (score: 1)
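So: normalizing your data is very important. It seems that you haven't normalized your targets. Networks are usually initialized so that they produce small values at the start, which is why your loss is so huge in the first epoch. I would still recommend normalizing your targets (using StandardScaler or MinMaxScaler), because having to produce large-scale outputs pushes the network's weights toward large absolute values, which is something you should prevent.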
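A minimal sketch of target standardization, assuming one scalar target per image. It uses plain NumPy so it is self-contained; `TargetStandardizer` and the example targets are hypothetical. scikit-learn's `StandardScaler` performs the same mean and standard-deviation scaling through `fit`, `transform` and `inverse_transform`, but it expects a 2-D array such as `y.reshape(-1, 1)`.

```python
import numpy as np

class TargetStandardizer:
    """Hypothetical helper: the arithmetic StandardScaler applies, for a 1-D target."""

    def fit(self, y):
        y = np.asarray(y, dtype=np.float64)
        self.mean_ = y.mean()
        # Population standard deviation (ddof=0), as StandardScaler uses.
        self.scale_ = y.std()
        if self.scale_ == 0.0:
            self.scale_ = 1.0  # avoid division by zero for a constant target
        return self

    def transform(self, y):
        return (np.asarray(y, dtype=np.float64) - self.mean_) / self.scale_

    def inverse_transform(self, y_scaled):
        return np.asarray(y_scaled, dtype=np.float64) * self.scale_ + self.mean_

# Hypothetical regression targets in their original units.
y_train = np.array([120.0, 95.0, 180.0, 60.0])
y_test = np.array([110.0, 150.0])

scaler = TargetStandardizer().fit(y_train)   # statistics from the training set only
y_train_scaled = scaler.transform(y_train)   # train the network on these
y_test_scaled = scaler.transform(y_test)     # same statistics for validation
# After predicting, map outputs back before reporting RMSE in original units, e.g.:
# preds = scaler.inverse_transform(model.predict(x_test))
```

The network is then fit on `y_train_scaled` and validated on `y_test_scaled`, so outputs near zero at initialization are already on the right scale; losses are reported in standardized units until predictions go through `inverse_transform`.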