I have a CNN architecture that outputs the coordinates of a box around an object.
However, when I implement it in TF, the loss becomes NaN even after a single epoch. I have tried gradient clipping and batch normalization, but neither helped. I suspect something is wrong with my loss; here is the relevant code:
...
output = tf.layers.dense(dense, 4, name="output")
# Loss
error = output-y
error_sq = tf.square(error)
loss = tf.reduce_mean(error_sq, axis=-1)
# Training operation
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, momentum=0.9, decay=0.0, epsilon=1e-08)
training_op = optimizer.minimize(loss)
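One difference worth checking: with axis=-1, loss is a per-example vector rather than a scalar, and minimize() will implicitly sum it. A minimal NumPy sketch (with made-up box coordinates) of what the two reductions produce:

```python
import numpy as np

# Hypothetical batch of predicted vs. true box coordinates (batch=2, coords=4)
output = np.array([[10., 20., 30., 40.],
                   [12., 22., 32., 42.]], dtype=np.float32)
y = np.array([[11., 21., 31., 41.],
              [10., 20., 30., 40.]], dtype=np.float32)

error_sq = np.square(output - y)

per_example = error_sq.mean(axis=-1)  # shape (2,): one loss per example
scalar_loss = error_sq.mean()         # shape (): a single scalar over the batch

print(per_example)   # [1. 4.]
print(scalar_loss)   # 2.5
```

Keras computes the per-example mean and then averages over the batch, so the value it reports matches the scalar reduction.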
I implemented the same architecture with Keras, and it works fine:
...
model.add(Dense(4))
# Optimizer
optimizer = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
# Compile
model.compile(optimizer = optimizer , loss = "mean_squared_error", metrics=["mean_squared_error"])
I cannot see any difference between the two.
Note 1: If I remove the axis=-1 argument, I still get NaN values, but I included it because the Keras mean is computed with the same argument.
Note 2: Even when I train only the dense layers on the inputs, the Keras model converges (slowly), while the TensorFlow model does not converge at all.
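One way an MSE loss on raw pixel coordinates can surface as NaN (a hedged sketch of the mechanism, not necessarily what happens in this exact graph): with unnormalized targets, a few bad updates can push the squared error past the float32 range, after which inf values propagate through the graph as NaN:

```python
import numpy as np

# Hypothetical: a prediction that has diverged far from its target, in float32
error = np.float32(3e19)       # |output - y| after a few runaway updates
error_sq = np.square(error)    # 9e38 exceeds float32 max (~3.4e38) -> inf
print(error_sq)                # inf
print(error_sq - error_sq)     # inf - inf -> nan
```

Normalizing inputs and targets keeps the error magnitudes small enough that this overflow cannot occur.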
Answer 0 (score: 1)
The catch may be the line error = output - y, because output holds the raw logits, i.e. unbounded predicted values for each class. If we print it, it looks like this:
output/logits = [[-4.55290842e+00 9.54713643e-01 2.04970908e+00 ... 1.06385863e+00
-1.76558220e+00 5.84793314e-02]
[ 1.42444344e+01 -3.09316659e+00 4.31246233e+00 ... -1.64039159e+00
-4.75767326e+00 2.69032687e-01]
[-3.66746974e+00 -1.05631983e+00 1.63249350e+00 ... 2.34054995e+00
-2.86306214e+00 -1.29766455e-02]
...
[ 1.92035064e-01 2.18118310e+00 1.05751991e+01 ... -3.32132912e+00
2.23277748e-01 -4.14045334e+00]
[-3.95318937e+00 7.54375601e+00 5.60657620e-01 ... 3.35071832e-02
2.31437039e+00 -3.36187315e+00]
[-4.37104368e+00 4.23799706e+00 1.20920219e+01 ... -1.18962801e+00
2.23617482e+00 -3.06528354e+00]]
Executing the steps error = output - y, error_sq = tf.square(error), and loss = tf.reduce_mean(error_sq, axis=-1) in sequence on such unbounded values may then lead to NaN.
The code below should fix your problem:
l1 = tf.layers.dense(normed_train_data, 64, activation=tf.nn.relu)
l2 = tf.layers.dense(l1, 64, activation=tf.nn.relu)
l3 = tf.layers.dense(l2, 4, name="output")
# tf.losses.mean_squared_error takes (labels, predictions) and already
# returns a scalar under the default reduction
mse = tf.losses.mean_squared_error(labels=y, predictions=l3)
loss = tf.reduce_mean(mse, name="loss")
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, momentum=0.9, decay=0.0, epsilon=1e-08)
training_op = optimizer.minimize(loss)
# Note: in_top_k expects integer class labels, so these two lines only
# make sense if y holds classification targets
correct = tf.nn.in_top_k(l3, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
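Since in_top_k is a classification metric, it does not apply to 4-coordinate regression targets; a plain regression metric such as mean absolute error may be a better fit here. A minimal NumPy sketch with hypothetical predictions and targets:

```python
import numpy as np

# Hypothetical predicted vs. true box coordinates (batch=2, coords=4)
pred = np.array([[10., 20., 30., 40.],
                 [ 0.,  0., 10., 10.]], dtype=np.float32)
true = np.array([[11., 19., 30., 40.],
                 [ 1.,  1.,  9., 11.]], dtype=np.float32)

mae = np.abs(pred - true).mean()  # mean absolute error, in coordinate units
print(mae)                        # 0.75
```

For box regression specifically, mean IoU between predicted and true boxes is another common evaluation metric.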