Validation accuracy does not improve during training

Time: 2020-04-30 19:39:50

Tags: tensorflow tensorflow2.0

I have a question about validation accuracy. I am using ResNet50 transfer learning to train on ultrasound images and classify them as benign or malignant.

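I have tried changing the number of epochs, the learning rate, and the batch size, adding more layers, and applying data augmentation, but nothing improved. My folders are set up as data -> train or test -> benign or malignant (inside each of the train and test folders).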
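With the data -> train/test -> benign/malignant layout, `flow_from_directory` treats each subfolder as one class and, when no `classes` list is passed, assigns label indices in alphanumeric order of the folder names. The stdlib-only sketch below mimics that mapping on a throwaway directory tree; `infer_class_indices` is a hypothetical helper, not a Keras function. Under `class_mode='binary'` this means `benign` becomes label 0 and `malignant` label 1.

```python
import os
import tempfile

def infer_class_indices(directory):
    """Hypothetical helper mimicking how flow_from_directory infers labels:
    every subdirectory is a class, indexed in sorted (alphanumeric) order."""
    subdirs = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(subdirs)}

# Build a throwaway copy of the layout data -> train/test -> benign/malignant.
with tempfile.TemporaryDirectory() as data_dir:
    for split in ("train", "test"):
        for label in ("malignant", "benign"):  # creation order does not matter
            os.makedirs(os.path.join(data_dir, split, label))
    train_indices = infer_class_indices(os.path.join(data_dir, "train"))
    test_indices = infer_class_indices(os.path.join(data_dir, "test"))

print(train_indices)  # {'benign': 0, 'malignant': 1}
```

Because both splits use the same folder names, the train and test label indices agree.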

I found most of the code online and tried to adapt it to my training task. I am using TensorFlow 2.1 and training a CNN. My code is:

# -*- coding: utf-8 -*-

from keras.layers import Input, Lambda, Dense, Flatten, Dropout, BatchNormalization
from keras.models import Model
from keras.applications.resnet50 import ResNet50
import tensorflow as tf
# from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import os
from tensorflow.keras.callbacks import EarlyStopping

# re-size all the images to this
IMAGE_SIZE = [224, 224]


data_dir = r'C:\Spring 2020\Machine Learning and Computer Vision\data_resize'  # raw string: backslashes are not escapes
os.listdir(data_dir)
valid_path = data_dir+'\\test\\'
train_path = data_dir+'\\train\\'


# Load ResNet50 with ImageNet weights, dropping its top (classification) layers.
# No preprocessing layer is added here; images are only rescaled to [0, 1] by the ImageDataGenerator below.

resnet = ResNet50(input_shape=IMAGE_SIZE + [3], weights='imagenet', include_top=False)

# don't train existing weights
for layer in resnet.layers:
    layer.trainable = False

# useful for getting number of output classes
folders = glob(os.path.join(train_path, '*'))

# our layers - you can add more if you want
x = Flatten()(resnet.output)
#x = Flatten()(base_model.output)
x = Dense(4096, activation='relu')(x)
x = Dense(2048, activation='relu')(x)
x = Dense(1024, activation='relu')(x)
x = Dense(512, activation='relu')(x)
x = Dense(256, activation='relu')(x)
x = Dense(128, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(32, activation='relu')(x)
x = Dense(16, activation='relu')(x)
x = Dense(8, activation='relu')(x)
x = Dense(4, activation='relu')(x)
x = Dropout(0.5)(x)
x = BatchNormalization()(x)
prediction = Dense(1, activation = 'softmax')(x)



# prediction = Dense(len(folders), activation='sigmoid')(x)

# create a model object
model = Model(inputs=resnet.input, outputs=prediction)

# view the structure of the model
model.summary()

# tell the model what cost and optimization method to use
adam = tf.keras.optimizers.Adam(
    learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False,
    name='adam' 
)
model.compile(
  loss='binary_crossentropy',
  optimizer=adam,
  metrics=['accuracy']
)

# Use the Image Data Generator to import the images from the dataset
from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

# Make sure you provide the same target size as IMAGE_SIZE above
training_set = train_datagen.flow_from_directory(train_path,
                                                 target_size = (224, 224),
                                                 batch_size = 32,
                                                 class_mode = 'binary')

test_set = test_datagen.flow_from_directory(valid_path,
                                            target_size = (224, 224),
                                            batch_size = 32,
                                            class_mode = 'binary',
                                            shuffle = False)


# fit the model
# Run the cell. It will take some time to execute
early_stop = EarlyStopping(monitor='val_loss',patience=20)
r = model.fit_generator(
  training_set,
  validation_data=test_set,
  epochs=30,
  steps_per_epoch=len(training_set),
  validation_steps=len(test_set),
  callbacks = [early_stop]
)


# plot the loss (save before show(): once the window is closed, savefig() would write an empty figure)
plt.figure()
plt.plot(r.history['loss'], label='train loss')
plt.plot(r.history['val_loss'], label='val loss')
plt.legend()
plt.savefig('LossVal_loss')
plt.show()

# plot the accuracy
plt.figure()
plt.plot(r.history['accuracy'], label='train acc')
plt.plot(r.history['val_accuracy'], label='val acc')
plt.legend()
plt.savefig('AccVal_acc')
plt.show()

model.save('model_vgg16.h5')

Here is the output during training:

runfile('C:/Spring 2020/Machine Learning and Computer Vision/need_to_try.py', wdir='C:/Spring 2020/Machine Learning and Computer Vision')
Using TensorFlow backend.
C:\Users\binhd\Anaconda3\lib\site-packages\keras_applications\resnet50.py:265: UserWarning: The output shape of `ResNet50(include_top=False)` has been changed since Keras 2.2.0.
  warnings.warn('The output shape of `ResNet50(include_top=False)` '
Model: "model_1"
__________________________________________________________________________________________________

Total params: 445,822,513
Trainable params: 422,234,793
Non-trainable params: 23,587,720
__________________________________________________________________________________________________
Found 2198 images belonging to 2 classes.
Found 800 images belonging to 2 classes.
Epoch 1/10
69/69 [==============================] - 22s 325ms/step - loss: 7.6573 - accuracy: 0.5000 - val_loss: 0.0000e+00 - val_accuracy: 0.5000
Epoch 2/10
69/69 [==============================] - 18s 260ms/step - loss: 7.6604 - accuracy: 0.5000 - val_loss: 0.0000e+00 - val_accuracy: 0.5000
Epoch 3/10
69/69 [==============================] - 19s 274ms/step - loss: 7.6791 - accuracy: 0.5000 - val_loss: 0.0000e+00 - val_accuracy: 0.5000
Epoch 4/10
69/69 [==============================] - 19s 275ms/step - loss: 7.6604 - accuracy: 0.5000 - val_loss: 0.0000e+00 - val_accuracy: 0.5000
Epoch 5/10
69/69 [==============================] - 19s 275ms/step - loss: 7.6791 - accuracy: 0.5000 - val_loss: 0.0000e+00 - val_accuracy: 0.5000
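The `69/69` on each epoch line follows from `steps_per_epoch=len(training_set)`: a Keras directory iterator yields ceil(images / batch_size) batches per pass. A quick check with the image counts reported above and the batch size of 32 from the code:

```python
import math

batch_size = 32
train_images, test_images = 2198, 800  # counts printed by flow_from_directory

steps_per_epoch = math.ceil(train_images / batch_size)
validation_steps = math.ceil(test_images / batch_size)
print(steps_per_epoch, validation_steps)  # 69 25
```

The last training batch therefore holds only 2198 - 68 * 32 = 22 images.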

Here are the plots: the x-axis is the epoch and the y-axis is the loss or accuracy (an accuracy of 0.5 means 50%).

[Images: loss and accuracy plots, not reproduced here]

Because of the post length limit, I am only posting some of the first and last layers of the summary below:


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 224, 224, 3)  0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 112, 112, 64) 0           bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 56, 56, 64)   0           pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 56, 56, 64)   4160        max_pooling2d_1[0][0]
__________________________________________________________________________________________________
...
activation_47 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_47[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_48 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_48[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]
                                                                 activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 7, 7, 2048)   0           add_16[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 100352)       0           activation_49[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 4096)         411045888   flatten_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2048)         8390656     dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1024)         2098176     dense_2[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 512)          524800      dense_3[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 256)          131328      dense_4[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 128)          32896       dense_5[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 64)           8256        dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 32)           2080        dense_7[0][0]
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 16)           528         dense_8[0][0]
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 8)            136         dense_9[0][0]
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 4)            36          dense_10[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 4)            0           dense_11[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 4)            16          dropout_1[0][0]
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 1)            5           batch_normalization_1[0][0]
==================================================================================================
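The summary totals can be reproduced by hand: a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` parameters, and `Flatten` turns the 7×7×2048 ResNet50 output into 100352 features. The sketch below recomputes the head from the layer widths in the code; `BatchNormalization` on 4 features contributes 8 trainable parameters (gamma, beta) and 8 non-trainable ones (moving mean and variance). It shows that about 411 million of the roughly 422 million trainable parameters sit in the first `Dense(4096)` layer.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

flatten_size = 7 * 7 * 2048  # output of ResNet50(include_top=False) at 224x224
widths = [flatten_size, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4]

# dense_1 ... dense_11, in the order they are built in the question's code.
hidden = [dense_params(a, b) for a, b in zip(widths, widths[1:])]
output_layer = dense_params(4, 1)   # dense_12 after BatchNormalization
bn_trainable = 2 * 4                # gamma and beta for 4 features
bn_non_trainable = 2 * 4            # moving mean and moving variance

head_trainable = sum(hidden) + output_layer + bn_trainable
print(flatten_size, hidden[0], head_trainable)
# 100352 411045888 422234793

reported_non_trainable = 23_587_720     # from model.summary() above
frozen_resnet = reported_non_trainable - bn_non_trainable
print(frozen_resnet)  # 23587712: the frozen ResNet50 weights
```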

I hope you can help me. Please help! Thank you for the time and attention you spent reading and replying to my post.

1 answer:

Answer 0 (score: 0)

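Technically, softmax is used for multi-class classification over S classes: it normalizes the S outputs so that they sum to 1. With a single output unit (S = 1), that normalization makes the output exactly 1.0 for every input, so the model cannot learn. When I used a softmax activation in the last layer, I got results similar to yours. Please check the code here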
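To see why a one-unit softmax cannot learn: softmax divides each exponentiated logit by the sum over all units, so with a single unit the output is exp(z) / exp(z) = 1 for every input, and its derivative with respect to z is 0. The plain-Python sketch below (not the Keras implementation) checks both facts numerically. A constant prediction of 1.0 puts every image in class 1 and passes no gradient back to the layers below, which is consistent with accuracy stuck at 0.5 on a balanced dataset.

```python
import math

def softmax(logits):
    """Plain-Python softmax over a list of logits (max-shifted for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# With one output unit the result is always [1.0], whatever the logit.
outputs = [softmax([z])[0] for z in (-50.0, -3.0, 0.0, 2.5, 40.0)]
print(outputs)  # [1.0, 1.0, 1.0, 1.0, 1.0]

# Central finite difference of the single output with respect to its logit.
h = 1e-4
slope = (softmax([0.7 + h])[0] - softmax([0.7 - h])[0]) / (2 * h)
print(slope)  # 0.0: no gradient reaches the layers below
```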

In the same code, when I replaced softmax with sigmoid, I got the expected good results. Please check the gist here

You need to change the last layer from

prediction = Dense(1, activation = 'softmax')(x)

to

prediction = Dense(1, activation = 'sigmoid')(x)
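With `sigmoid`, the single output p = 1 / (1 + exp(-z)) varies with the logit z, and for binary cross-entropy L = -[y log p + (1 - y) log(1 - p)] the gradient with respect to z simplifies to p - y, which is nonzero whenever the prediction is off. A small plain-Python check of that identity by finite differences (the helper names are illustrative, not Keras functions):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce(p, y):
    """Binary cross-entropy for one example with probability p and label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

h = 1e-6
checks = []
for z in (-2.0, 0.0, 1.5):
    for y in (0, 1):
        # Central difference of the loss with respect to the logit z.
        numeric = (bce(sigmoid(z + h), y) - bce(sigmoid(z - h), y)) / (2 * h)
        analytic = sigmoid(z) - y
        checks.append((z, y, numeric, analytic))
        print(f"z={z:+.1f} y={y} dL/dz={numeric:+.4f} (p - y = {analytic:+.4f})")
```

Unlike the one-unit softmax, this gradient flows back into the Dense layers, so training can make progress.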

Apart from that, everything looks fine.

If the above does not work, you can try replacing the last layer

prediction = Dense(1, activation = 'softmax')(x)

with

prediction = Dense(1)(x)

and then changing the compile call from

model.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])

to

model.compile(loss = tf.keras.losses.BinaryCrossentropy(from_logits = True),optimizer = adam,metrics = ['accuracy'])
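With `Dense(1)` and `from_logits=True`, the model outputs a raw logit z and the loss applies the sigmoid internally. TensorFlow documents `tf.nn.sigmoid_cross_entropy_with_logits` as computing the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)). The plain-Python sketch below (not TensorFlow code) shows that this form matches the naive sigmoid-then-log computation for moderate logits and stays finite for extreme ones. Because sigmoid(0) = 0.5, a logit above 0 corresponds to a probability above 0.5, so raw predictions from this model should be thresholded at 0 rather than 0.5.

```python
import math

def bce_naive(z, y):
    """Apply the sigmoid first, then binary cross-entropy (can hit log(0))."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_from_logits(z, y):
    """Numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# Both forms agree for moderate logits.
for z in (-3.0, -0.5, 0.0, 0.5, 3.0):
    for y in (0, 1):
        assert math.isclose(bce_naive(z, y), bce_from_logits(z, y), rel_tol=1e-9)

# For an extreme logit, 1 - sigmoid(z) rounds to 0.0 and the naive form fails.
try:
    bce_naive(40.0, 0)
    naive_failed = False
except ValueError:  # math.log(0.0) raises ValueError
    naive_failed = True
print("naive form failed:", naive_failed)        # True
print("stable form:", bce_from_logits(40.0, 0))  # 40.0
```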