我正在尝试训练和测试一个暹罗网络,并且我有正负两副图像。 当我开始运行模型时,有时会得到:
InvalidArgumentError:摘要直方图中的Nan:距离
在某些运行中,它发生在开始的第4或5步,而在其他运行中,它发生在以后的某些步骤。 该模型的参数为:
n_steps:600 share_weights:正确 保证金:100 阈值:0.3 学习率:0.001 batch_size:64
我正在使用Adam Optimizer。 我计算距离的方式是:
loss, distance_0 = contrastive_loss(self.net0.output(), self.net1.output(), self.y, self.margin)
distance = network_layers.get_scaled_tensor(self.distance_0, 'scaled_distance')
我在此代码中使用对比损失:
def contrastive_loss(out0, out1, y_true, margin):
with tf.name_scope('contrastive_loss'):
d = tf.reduce_sum(tf.square(tf.subtract(out0, out1)), 1)
d_sqrt = tf.sqrt(1e-6 + d)
loss = (y_true * d) + ((1 - y_true) * tf.square(tf.maximum(tf.subtract(margin, d_sqrt), 0)))
loss = tf.reduce_mean(loss) # Note: constant component removed (/2)
return loss, d_sqrt
然后我对距离张量进行缩放。
def get_scaled_tensor(inputs, name, vmin=0., vmax=1.):
"""
:param inputs: input tensor
:param name: name of tensor
:param vmin: min value for scaling
:param vmax: max value for scaling
:return: a tensor to scale inputs to the interval of [vmin, vmax]
"""
inputs_min = tf.reduce_min(inputs)
inputs_max = tf.reduce_max(inputs)
if inputs_min == inputs_max:
return inputs
scaled = (inputs - inputs_min) / (inputs_max - inputs_min)
scaled = tf.add(scaled * (vmax - vmin), vmin, name=name)
return scaled
我尝试在每一步都打印出距离张量,以查看是否生成了任何NaN值。在错误发生之前,这就是我所拥有的:
[0。 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0。 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0。 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0。]
如您所见,在张量中,除一个值为1之外,其他所有值均为0。 相同的事情发生在不同的运行中,只是位置“ 1”发生了变化。 例如错误之前的另一个距离张量:
[0。 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0。 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0。 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0。]
以下是成功运行的先前步骤中的一些距离张量:
第0步: [0.03674867 0.00787617 0.2477202 0.3154928 0.10931009 0.2735154 0.13083686 0.14834054 0.09842198 0.09132127 0.0856637 0.6076441 0.100755 0.09179091 0.11605983 0.3297588 0.28090978 0.01025809 0.12361664 0.11507085 0.06570645 0.0616716 0.12276769 0.12957902 0.2234131 0.5456302 0.16503233 0.03914422 0.15807444 0.9947304 0.5870102 0.17302574 0.15670143 0.07926143 0.12775105 0.06380077 0.5533038 0.7977735 0.22019525 0.07887711 0.18530448 0.12723611 0.05254128 0.46612477 0.5007302 0. 0.03045994 0。 0.25195697 0.03005651 0.900458 0. 0.36718735 0.15492864 0.42535746 0.151013 0.7009924 0.12371618 0.05326976 0.04238863 0.08139887 1. 0.14789903 0.02012084]
步骤1: [0.00000000e + 00 5.47710210e-02 9.02850509e-01 8.88542950e-01 4.31583114e-02 6.15703311e-09 6.10682368e-01 9.25455615e-02 1.00000000e + 00 5.41838706e-01 4.48709205e-02 2.73277704e-02 2.96332520e-02 1.20817451e-02 2.27759629e-02 3.11016534e-02 6.85849637e-02 5.71546005e-03 6.70421779e-01 6.15703311e-09 1.96496379e-02 4.62664776e-02 2.84987278e-02 6.15703311e-09 5.13420582e-01 1.88366547e-02 1.45142935e-02 8.92327540e-03 6.32246491e-03 8.08481336e-01 0.00000000e + 00 5.63631430e-02 3.69409472e-02 7.36759901e-01 5.46008237e-02 1.17494063e-02 1.54513204e-02 1.68218538e-02 6.69234037e-01 3.33906822e-02 3.26752402e-02 3.83973792e-02 5.73891141e-02 1.11822821e-01 5.12272008e-02 6.93304315e-02 5.14047369e-02 1.80864055e-03 1.57557316e-02 1.11042978e-02 1.07334806e-02 6.15703311e-09 0.00000000e + 00 1.03809848e-01 0.00000000e + 00 6.15703311e-09 2.19233677e-01 1.49917230e-02 1.25507280e-01 1.75409820e-02 7.54028440e-01 5.22847399e-02 8.12488422e-02 1.12156458e-01]
步骤2: [0。 0.01425315 0.6806311 0.63107663 0.00653503 0.00653504 0.00653503 0.00653503 0.0405196 0.00653503 0.00653503 0.02334335 0.08763484 0.12318593 0.8585707 0. 0.00653504 0.01887737 0.00653503 0.78051955 0.00653503 0.02088702 0.00653503 0.03008028 0.03201023 0.36658844 0.00653503 0.01497512 0.0114087 0.03001456 0.00653503 0.00653503 0.00653503 0.95640105 0.01957083 0.00653503 0.08598939 0.00653503 0.49910328 0.00653503 0.00653503 0.02415803 0.02409992 0.71698534 0.00653503 0.00659794 0.01573122 0.6454014 0.00653503 0.81859267 0.00653503 0.00653503 0.03296572 0.00653504 0.07882535 0.2214024 1. 0.00653503 0.00653503 0.00653503 0.02502085 0.05808342 0.00653503 0.03955946]
步骤3: [0.01213129 0.01213101 0.01213105 0.01213101 0.01213105 0.01213116 0.01213094 0.01213101 0.01213147 0.01213094 0.01213105 0。 0.0121312 0.01213112 0.01213133 0.01213125 0.01213105 0.01213096 0.55621964 0.01213142 0.01213109 0.01213114 0.01213133 0。 0. 0. 0.01213133 0.01213101 0. 0 0.01213129 0. 0.01213153 0. 0. 0.01213114 0. 0.01245531 0. 0.01213101 0.01213105 0.01213105 0.01213129 0.01213101 0.0121312 0.01213079 0.35860753 0。 0.01213118 0.01213101 0.01213112 0.0.01213114 0.01213101 0.012131030。0.72302234 0.01213101 0.5268527 0。 0.01213123 1. 0. 0.]
我尝试搜索此问题,建议之一是增加批量。但是,将批处理大小增加到128会出现内存错误。
我希望您能提出一些前进的建议。
是由于距离张量中有太多0引起的问题吗?
始终添加一个较小的值以防止值变为0有意义吗?
我应该重新考虑损失函数和缩放函数吗?
尝试使用防止梯度爆炸的技术对我有意义吗?
我应该更改正在使用的优化程序吗?