我刚刚开始学习CNN和tensorflow。我在测试一个完全卷积网络(基于VGG16)时遇到了一个案例,其中我发现训练损失和验证损失都是零,并且永远不会改变。因此,网络永远无法学习任何积极的功能。希望有人能告诉我网络中发生了什么。
Accuracy and loss plotted in tensorboard
我使用最新的tflearn实现了网络(直接从github获取)。在conv5层,我使用了atrous_conv_2d而不是conv_2d,并且在将它们上采样到相同比例(通过使用upscore_layer)后,我合并了前四个池和最后一个完全卷积层。
代码如下所示。顺便说一下,我正在做的是通过使用tflearn或tensorflow来实现本文中提出的FCN(http://i.cs.hku.hk/~gbli/deep_saliency.html)。
import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected, activation
from tflearn.layers.conv import conv_2d, max_pool_2d, atrous_conv_2d, upscore_layer
from tflearn.layers.merge_ops import merge
from tflearn.layers.normalization import local_response_normalization,batch_normalization
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
from tflearn.data_augmentation import ImageAugmentation
def InitMSFCN(t_xWSize,t_xHSize,t_channel=1,img_aug=None,img_prep=None, l_rate = 0.0002):
data = input_data(shape=[None, t_xHSize, t_xWSize, t_channel],
data_augmentation=img_aug,
data_preprocessing=img_prep)
data_conv = conv_2d(data, 128, 3, strides=1, activation='relu', name='data_conv')
data_conv = dropout(data_conv, 0.5, name='drop_data_conv')
data_fc = conv_2d(data_conv, 128, 1, strides=1, activation='relu', name='data_fc')
data_fc = dropout(data_fc, 0.5, name='drop_data_fc')
data_ms_silhouette = conv_2d(data_fc, 1, 1, name='data_ms_silhouette')
conv1_1 = conv_2d(data, 64, 3, activation='relu', name='conv1_1')
conv1_2 = conv_2d(conv1_1, 64, 3, activation='relu', name='conv1_2')
pool1 = max_pool_2d(conv1_2, 3, strides=2, name='pool1')
pool1_conv = conv_2d(pool1, 128, 3, strides=1, activation='relu', name='pool1_conv')
pool1_conv = dropout(pool1_conv, 0.5, name='drop_pool1_conv')
pool1_fc = conv_2d(pool1_conv, 128, 1, activation='relu', name='pool1_fc')
pool1_fc = dropout(pool1_fc, 0.5, name='drop_pool1_fc')
pool1_ms_silhouette = conv_2d(pool1_fc, 1, 1, name='pool1_ms_silhouette')
conv2_1 = conv_2d(pool1, 128, 3, activation='relu', name='conv2_1')
conv2_2 = conv_2d(conv2_1, 128, 3, activation='relu', name='conv2_2')
pool2 = max_pool_2d(conv2_2, 3, strides=2, name='pool2')
pool2_conv = conv_2d(pool2, 128, 3, strides=1, activation='relu', name='pool2_conv')
pool2_conv = dropout(pool2_conv, 0.5, name='drop_pool2_conv')
pool2_fc = conv_2d(pool2_conv, 128, 1, activation='relu', name='pool2_fc')
pool2_fc = dropout(pool2_fc, 0.5, name='drop_pool2_fc')
pool2_ms_silhouette = conv_2d(pool2_fc, 1, 1, name='pool2_ms_silhouette')
conv3_1 = conv_2d(pool2, 256, 3, activation='relu', name='conv3_1')
conv3_2 = conv_2d(conv3_1, 256, 3, activation='relu', name='conv3_2')
conv3_3 = conv_2d(conv3_2, 256, 3, activation='relu', name='conv3_3')
pool3 = max_pool_2d(conv3_3, 3, strides=2, name='pool3')
pool3_conv = conv_2d(pool3, 128, 3, strides=1, activation='relu', name='pool3_conv')
pool3_conv = dropout(pool3_conv, 0.5, name='drop_pool3_conv')
pool3_fc = conv_2d(pool3_conv, 128, 1, strides=1, activation='relu', name='pool3_fc')
pool3_fc = dropout(pool3_fc, 0.5, name='drop_pool3_fc')
pool3_ms_silhouette = conv_2d(pool3_fc, 1, 1, name='pool3_ms_silhouette')
conv4_1 = conv_2d(pool3, 512, 3, activation='relu', name='conv4_1')
conv4_2 = conv_2d(conv4_1, 512, 3, activation='relu', name='conv4_2')
conv4_3 = conv_2d(conv4_2, 512, 3, activation='relu', name='conv4_3')
pool4 = max_pool_2d(conv4_3, 3, strides=1, name='pool4')
pool4_conv = conv_2d(pool4, 128, 3, strides=1, activation='relu', name='pool4_conv')
pool4_conv = dropout(pool4_conv, 0.5, name='drop_pool4_conv')
pool4_fc = conv_2d(pool4_conv, 128, 1, strides=1, activation='relu', name='pool4_fc')
pool4_fc = dropout(pool4_fc, 0.5, name='drop_pool4_fc')
pool4_ms_silhouette = conv_2d(pool4_fc, 1, 1, name='pool4_ms_silhouette')
conv5_1 = atrous_conv_2d(pool4, 512, 3, rate=2, activation='relu', name='conv5_1')
conv5_2 = atrous_conv_2d(conv5_1, 512, 3, rate=2, activation='relu', name='conv5_2')
conv5_3 = atrous_conv_2d(conv5_2, 512, 3, rate=2, activation='relu', name='conv5_3')
pool5 = max_pool_2d(conv5_3, 3, strides=1, name='pool5')
fc6 = atrous_conv_2d(pool5, 4096, 4, rate=4, activation='relu', name='fc6')
fc6 = dropout(fc6, 0.5, name='drop6')
fc7 = conv_2d(fc6, 4096, 1, activation='relu', name='fc7')
fc7 = dropout(fc7, 0.5, name='drop7')
fc8_silhouette = conv_2d(fc7, 1, 1, name='fc7')
pool1_interp = upscore_layer(pool1_ms_silhouette, 1, strides=2, shape=tf.shape(data_ms_silhouette))
pool2_interp = upscore_layer(pool2_ms_silhouette, 1, strides=4, shape=tf.shape(data_ms_silhouette))
pool3_interp = upscore_layer(pool3_ms_silhouette, 1, strides=8, shape=tf.shape(data_ms_silhouette))
pool4_interp = upscore_layer(pool4_ms_silhouette, 1, strides=8, shape=tf.shape(data_ms_silhouette))
fc8_interp = upscore_layer(fc8_silhouette, 1, strides=8, shape=tf.shape(data_ms_silhouette))
fc_fusion = merge([data_ms_silhouette,pool1_interp,pool2_interp,pool3_interp,pool4_interp,fc8_interp], 'elemwise_sum', 0, name='fc_fusion')
fc_silhouette_reg = activation(fc_fusion, activation='sigmoid', name='fc_silhouette_reg')
network = regression(fc_silhouette_reg, optimizer='adam',loss='categorical_crossentropy',
learning_rate=l_rate, name='fc_silhouette')
return network
if __name__ == '__main__':
# Real-time data preprocessing
img_prep = ImagePreprocessing()
img_prep.add_featurewise_zero_center()
img_prep.add_featurewise_stdnorm()
network = InitMSFCN(104,136,3,img_prep=img_prep)
model = tflearn.DNN(network,best_checkpoint_path='./bestCheck/',max_checkpoints=100,tensorboard_verbose=3)
from glob import glob
from os.path import join
import cv2
import numpy as np
imgs = []
labels = []
img_paths = glob(join('./ori','*.jpg'))
for img_path in img_paths:
name = img_path.split('/')[-1]
print name
label_path = join('./sil',name)
img=cv2.imread(img_path)
label = cv2.imread(label_path,0)
imgs.append(img)
labels.append(label)
imgs = np.array(imgs).astype('float64')
labels = np.array(labels)
labels.shape=(labels.shape[0],labels.shape[1],labels.shape[2],1)
print imgs.shape
print labels.shape
model.fit(X_inputs = imgs, Y_targets = labels, n_epoch=2000,snapshot_step=10000, batch_size=16, validation_set=0.2,
snapshot_epoch=False, shuffle=True, show_metric=True, run_id='MSFCN')
提前谢谢你,
哲