我正在训练一个三元组(triplet)模型,它与 “Keras model params are all "NaN"s after reloading” 这个问题中的模型类似,区别在于我的模型建立在 InceptionV3 之上。
(我使用的是以 TensorFlow 为后端的 Keras。)
但仅训练 2 个 epoch 后,模型权重就变成了 NaN。当我输入图像来提取学到的特征时,得到的特征全为 0。
模型架构:
def Triplet_loss(x, ALPHA=0.2):
    """Compute the batch-averaged triplet margin loss.

    Args:
        x: sequence ``(anchor, positive, negative)`` of embedding tensors.
            Distances are summed over axis 1, so each tensor is presumably
            ``(batch, dim)`` -- TODO confirm against the embedding model.
        ALPHA: margin added to ``d(a, p) - d(a, n)``.

    Returns:
        Tensor equal to the mean over axis 0 of
        ``max(d(a, p) - d(a, n) + ALPHA, 0)``, where ``d`` is the squared
        Euclidean distance.
    """
    anchor, positive, negative = x
    # Squared distances only: no sqrt, so there is no infinite gradient at d == 0.
    d_pos = tf.reduce_sum(tf.square(anchor - positive), axis=1)
    d_neg = tf.reduce_sum(tf.square(anchor - negative), axis=1)
    hinge = tf.maximum(d_pos - d_neg + ALPHA, 0.0)
    return tf.reduce_mean(hinge, axis=0)
class StyleNet():
    """InceptionV3-based embedding network trained with a triplet loss.

    ``create_model`` builds a shared embedding model (``self.triplet_model``)
    and a three-input training wrapper (``self.triplet_model_worker``) whose
    single output is the triplet loss value; ``fit_model`` trains the wrapper.
    """

    def __init__(self, input_shape_x, input_shape_y, input_shape_z, n_classes, reg_lambda):
        """Store input geometry and hyper-parameters.

        Args:
            input_shape_x, input_shape_y, input_shape_z: dimensions of one
                input image, used in this order as the Keras ``Input`` shape.
            n_classes: number of classes; stored on the instance but not used
                by the triplet model built in ``create_model``.
            reg_lambda: L2 regularization factor for the bottleneck kernel.
        """
        self.input_shape_x = input_shape_x
        self.input_shape_y = input_shape_y
        self.input_shape_z = input_shape_z
        self.n_classes = n_classes
        self.reg_lambda = reg_lambda

    def create_model(self):
        """Build and compile the triplet training model.

        Sets ``self.triplet_model`` (image -> 256-unit embedding) and
        ``self.triplet_model_worker`` ([anchor, positive, negative] -> loss),
        compiled with Adam (lr=1e-5, clipnorm=1.0).
        """
        shape = (self.input_shape_x, self.input_shape_y, self.input_shape_z)
        anchor_example = Input(shape=shape, name='input_1')
        positive_example = Input(shape=shape, name='input_2')
        negative_example = Input(shape=shape, name='input_3')
        input_image = Input(shape=shape)

        # weights=None: the backbone starts from random initialization.
        base_inception = InceptionV3(input_tensor=input_image, input_shape=shape,
                                     weights=None, include_top=False, pooling='avg')
        base_pool5 = base_inception.output

        # Bottleneck head: Dense(256) -> BatchNorm -> ReLU -> Dropout.
        bottleneck_layer = Dense(256, kernel_regularizer=l2(self.reg_lambda),
                                 name='bottleneck_layer')(base_pool5)
        bottleneck_norm = BatchNormalization(name='bottleneck_norm')(bottleneck_layer)
        bottleneck_relu = Activation('relu', name='bottleneck_relu')(bottleneck_norm)
        bottleneck_drop = Dropout(0.5)(bottleneck_relu)

        # Shared embedding network: the same weights embed all three inputs.
        # NOTE(review): the embedding is taken after Dropout, so during training
        # each of the three calls below draws its own independent dropout mask.
        self.triplet_model = Model(input_image, bottleneck_drop)
        positive_embedding = self.triplet_model(positive_example)
        negative_embedding = self.triplet_model(negative_example)
        anchor_embedding = self.triplet_model(anchor_example)

        adam_opt = optimizers.Adam(lr=0.00001, clipnorm=1.0, amsgrad=False)
        # The Lambda output *is* the loss value; compiling with MAE then
        # presumably relies on all-zero dummy targets from the data generator
        # (generator not visible here -- confirm).
        loss = Lambda(Triplet_loss, output_shape=(1,))(
            [anchor_embedding, positive_embedding, negative_embedding])
        self.triplet_model_worker = Model(
            inputs=[anchor_example, positive_example, negative_example], outputs=loss)
        self.triplet_model_worker.compile(loss='mean_absolute_error', optimizer=adam_opt)

    def fit_model(self, pathname='./models/'):
        """Train ``self.triplet_model_worker`` for 60 epochs.

        A checkpoint is written to ``<pathname>/weights/NN.hdf5`` after every
        epoch and TensorBoard logs go to ``<pathname>/tb``. ``create_model``
        must have been called first.

        Args:
            pathname: output directory; works with or without a trailing slash.
        """
        weights_dir = os.path.join(pathname, 'weights')
        tb_dir = os.path.join(pathname, 'tb')
        # makedirs also creates missing parents, so `pathname` itself is covered.
        for directory in (weights_dir, tb_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        # Build the checkpoint path from the same directory that was created
        # above, so they cannot disagree when `pathname` lacks a trailing slash.
        filepath = os.path.join(weights_dir, '{epoch:02d}.hdf5')
        checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1,
                                     save_best_only=False, mode='auto')
        tensorboard = TensorBoard(log_dir=tb_dir, write_graph=True, write_images=True)
        callbacks_list = [checkpoint, tensorboard]

        params = {'dim': (224, 224), 'batch_size': 32, 'n_classes': 11,
                  'n_channels': 3, 'shuffle': True}

        # Close the dataset files deterministically instead of leaking handles.
        with open('../../../data/bam_2_partition_triplet.pkl', 'rb') as handle:
            partition = pickle.load(handle)
        with open('../../../data/bam_2_labels_triplet.pkl', 'rb') as handle:
            labels = pickle.load(handle)

        training_generator = DataGenerator(partition['train'], labels, **params)
        self.triplet_model_worker.fit_generator(
            generator=training_generator, epochs=60, use_multiprocessing=True,
            workers=10, callbacks=callbacks_list, verbose=1)
上面链接中的回答并没有解决我的问题:即使设置了 `clipnorm=1.0`,梯度仍然会爆炸,权重仍会变成 NaN。
保存并重新加载模型后打印权重,可以清楚地看到 NaN。加载代码如下:
import numpy as np

m = load_model('/scratch/models_inception_stage2/yo/weights/02.hdf5', custom_objects={"tf": tf})


def _iter_leaf_layers(layer):
    """Yield the non-container layers under *layer*, descending into nested models.

    The checkpointed training model holds the whole embedding network as a
    single nested ``Model`` layer, so iterating ``m.layers`` alone returns
    all of its weights as one unlabeled list.
    """
    children = getattr(layer, 'layers', None)
    if not children:
        yield layer
        return
    for child in children:
        for leaf in _iter_leaf_layers(child):
            yield leaf


# Report every weight tensor by name, with its shape and NaN count, so the
# layers that contain NaN can be identified.
for layer in _iter_leaf_layers(m):
    for variable, value in zip(layer.weights, layer.get_weights()):
        nan_count = int(np.isnan(value).sum())
        print('%s shape=%s nan=%d' % (variable.name, value.shape, nan_count))
以下是打印出的部分权重:
[array([ 3.4517611e-04, 1.3431008e-03, -1.1081886e-03, 2.6104850e-04,
-2.1620051e-04, 1.6816283e-03, 8.8927911e-05, -3.8964470e-04,
1.7968584e-03, 1.0259283e-03, 5.0400384e-04, -3.6578919e-04,
-1.1292399e-03, 1.1509922e-03, 3.2478449e-04, -3.6580343e-05,
-4.4458261e-04, 4.8210021e-04, -9.5213606e-04, -6.4406055e-04,
5.0959276e-04, -3.4098624e-04, -7.0486858e-05, 2.8134760e-04,
-8.0100907e-04, 8.2962180e-04, -6.4140803e-04, 9.4872032e-04,
-3.3409546e-05, -3.0277384e-04, 5.2237371e-04, -8.3427120e-04,
-2.5856070e-04, -1.0346439e-03, 4.3354488e-05, -8.8099617e-04,
-6.8233605e-04, -1.2386916e-04, 8.2019303e-04, -1.9070004e-03,
1.5571159e-03, -3.4599879e-04, 6.2088901e-04, -8.4720332e-06,
1.6024955e-04, -1.2059419e-03, -1.4946899e-04, -6.7080715e-04,
-2.8154058e-05, 5.1517348e-04, 5.9993083e-05, 2.8555689e-04,
3.9626448e-04, -5.1538437e-04, 1.9132573e-04, 1.1226863e-03,
1.1591403e-03, -6.3404470e-04, 2.8910063e-04, -7.9366821e-04,
-1.7228167e-04, 6.2899920e-04, 1.7438219e-04, 1.1385380e-04],
dtype=float32), array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
dtype=float32), array([0.50248814, 0.48732147, 0.64627343, 0.49432775, 0.45906776,
0.5168214 , 0.8561428 , 0.7308014 , 0.5067555 , 0.516555 ,
1.3287057 , 0.5746422 , 0.55597156, 1.0038179 , 0.9859771 ,
0.6110601 , 0.7357226 , 0.6123694 , 0.90676117, 0.5439505 ,
0.48629472, 0.5434108 , 0.4934845 , 0.5407317 , 0.6443982 ,
1.0403991 , 0.48624724, 0.83786434, 0.72478205, 0.7294607 ,
0.536994 , 0.38235992, 1.0484552 , 0.45833316, 0.48205158,
0.48236838, 0.71035874, 0.9472658 , 0.78085536, 1.0207686 ,
0.5089741 , 0.97984046, 0.86524594, 0.9828817 , 0.49027866,
0.7367909 , 0.57438385, 0.5011991 , 0.47189236, 0.52376693,
0.45648402, 0.40523565, 0.8375675 , 0.57908285, 0.6055632 ,
1.0325785 , 0.5377976 , 0.47033092, 0.83586556, 1.2780553 ,
0.503384 , 0.54509026, 0.5375585 , 0.6091993 ], dtype=float32)]
感谢您的帮助。