I have a program that defines a loss function as follows:
def loss(hypes, decoded_logits, labels):
    """Calculate the loss from the logits and the labels.

    Args:
      logits: Logits tensor, float - [batch_size, NUM_CLASSES].
      labels: Labels tensor, int32 - [batch_size].

    Returns:
      loss: Loss tensor of type float.
    """
    logits = decoded_logits['logits']
    with tf.name_scope('loss'):
        logits = tf.reshape(logits, (-1, 2))
        shape = [logits.get_shape()[0], 2]
        epsilon = tf.constant(value=hypes['solver']['epsilon'])
        # logits = logits + epsilon
        labels = tf.to_float(tf.reshape(labels, (-1, 2)))

        softmax = tf.nn.softmax(logits) + epsilon

        if hypes['loss'] == 'xentropy':
            cross_entropy_mean = _compute_cross_entropy_mean(hypes, labels,
                                                             softmax)
        elif hypes['loss'] == 'softF1':
            cross_entropy_mean = _compute_f1(hypes, labels, softmax, epsilon)
        elif hypes['loss'] == 'softIU':
            cross_entropy_mean = _compute_soft_ui(hypes, labels, softmax,
                                                  epsilon)

        reg_loss_col = tf.GraphKeys.REGULARIZATION_LOSSES

        print('******' * 10)
        print('loss type ', hypes['loss'])
        print('type ', type(tf.get_collection(reg_loss_col)))
        print("Regression loss collection: {}".format(tf.get_collection(reg_loss_col)))
        print('******' * 10)

        weight_loss = tf.add_n(tf.get_collection(reg_loss_col))

        total_loss = cross_entropy_mean + weight_loss

        losses = {}
        losses['total_loss'] = total_loss
        losses['xentropy'] = cross_entropy_mean
        losses['weight_loss'] = weight_loss

    return losses
Running the program raises the following error:
  File "/home/ decoder/kitti_multiloss.py", line 86, in loss
    name='reg_loss')
  File "/devl /tensorflow/tf_0.12/lib/python3.4/site-packages/tensorflow/python/ops/math_ops.py", line 1827, in add_n
    raise ValueError("inputs must be a list of at least one Tensor with the "
ValueError: inputs must be a list of at least one Tensor with the same dtype and shape
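I looked at the implementation of tf.add_n, which is shown below. My question is: how can I inspect the first argument passed to tf.add_n, that is, tf.get_collection(reg_loss_col), and print its information to find out why the error is raised?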
def add_n(inputs, name=None):
    """Adds all input tensors element-wise.

    Args:
      inputs: A list of `Tensor` objects, each with same shape and type.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of same shape and type as the elements of `inputs`.

    Raises:
      ValueError: If `inputs` don't all have same shape and dtype or the shape
      cannot be inferred.
    """
    if not inputs or not isinstance(inputs, (list, tuple)):
        raise ValueError("inputs must be a list of at least one Tensor with the "
                         "same dtype and shape")
    inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
    if not all(isinstance(x, ops.Tensor) for x in inputs):
        raise ValueError("inputs must be a list of at least one Tensor with the "
                         "same dtype and shape")
Answer 0 (score: 0)
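Why do you need to step into add_n at all to see what tf.get_collection(reg_loss_col) is? You can assign tmp = tf.get_collection(reg_loss_col) and inspect it directly. By the way, it looks like your graph has no regularization losses, in which case tf.get_collection(reg_loss_col) returns an empty list. An empty list fails the first check in add_n (not inputs), which raises this ValueError.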
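To make the empty-list case concrete, here is a minimal sketch that copies only the first guard from the add_n source quoted in the question. The names check_add_n_inputs and try_inputs are hypothetical stand-ins written for illustration; they are not part of TensorFlow, and a plain float is used in place of a tensor.

```python
def check_add_n_inputs(inputs):
    """Stand-in for the first guard in add_n; NOT the TensorFlow function."""
    if not inputs or not isinstance(inputs, (list, tuple)):
        raise ValueError("inputs must be a list of at least one Tensor with the "
                         "same dtype and shape")
    return inputs


def try_inputs(inputs):
    """Return the error message raised for `inputs`, or None if the guard passes."""
    try:
        check_add_n_inputs(inputs)
    except ValueError as err:
        return str(err)
    return None


# An empty list is what tf.get_collection returns when nothing was registered
# under the requested key, so the guard rejects it.
print(try_inputs([]))

# A non-empty list passes this first guard (a float stands in for a tensor here).
print(try_inputs([0.0]))
```

If the first call prints the same message as your traceback, the collection being empty is the likely cause, rather than a dtype or shape mismatch between tensors.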
To check the type of an object in Python, use the built-in function type. For example, to see the type of tmp: print(type(tmp))
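Beyond the type, it usually helps to print the length of the list and a summary of each entry. The helper below is a rough sketch of that idea; describe_collection is a hypothetical name, and it assumes each entry exposes name, dtype and get_shape() the way a TensorFlow tensor does, falling back to placeholders for anything else.

```python
def describe_collection(items):
    """Return a header line plus one summary line per entry in `items`.

    `items` is meant to be the list returned by
    tf.get_collection(reg_loss_col), but any list works.
    """
    lines = ["type={} length={}".format(type(items).__name__, len(items))]
    for index, item in enumerate(items):
        name = getattr(item, "name", "<no name>")
        dtype = getattr(item, "dtype", "<no dtype>")
        # Tensors expose get_shape(); use a placeholder for other objects.
        get_shape = getattr(item, "get_shape", None)
        shape = get_shape() if callable(get_shape) else "<no shape>"
        lines.append("[{}] name={} dtype={} shape={}".format(
            index, name, dtype, shape))
    return lines


# In the question's code this would be called as:
#     tmp = tf.get_collection(reg_loss_col)
#     print("\n".join(describe_collection(tmp)))
# An empty list stands in for the collection here, so TensorFlow is not needed.
print("\n".join(describe_collection([])))
```

A reported length of 0 means nothing was ever added to the regularization collection, which points at how the layers were built rather than at add_n itself.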
Answer 1 (score: 0)
As a workaround, you can replace the tf.add_n line with the following:
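# Read the same collection the question's code uses.
temp = tf.get_collection(reg_loss_col)
if not temp:
    # Empty collection: fall back to a single float zero so that tf.add_n
    # has an input and the result can be added to the float loss.
    temp = [tf.constant(0.0)]
weight_loss = tf.add_n(temp, name='reg_loss')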
Adding a zero does not change the final result, but it lets the program run. What do you think?