Tensorflow Estimator logits和标签必须具有相同的形状

时间:2018-01-19 17:26:05

标签: python tensorflow

只是fyi ......这不是已报告和修复的#4715号错误。

print('Tensorflow version {} is loaded.'.format(tf.__version__))
#Tensorflow version 1.4.0 is loaded.

我已经整理了一个只有2个功能和二进制分类的自定义Estimator。以下代码可以正常工作,但确实会集中在有用的东西上。

input_layer = tf.feature_column.input_layer(features, feature_columns)
h1 = tf.layers.Dense(h1_size, activation=tf.nn.relu)(input_layer)
h2 = tf.layers.Dense(h2_size, activation=tf.nn.relu)(h1)
logits = tf.layers.Dense(NUM_CLASSES)(h2)
labels = tf.squeeze(labels, 1)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

为了改进模型,我想改变计算损失的方式。具体来说,我宁愿不使用softmax而是使用sigmoid_cross_entropy。因此,我将损失改为;

loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

并得到这些错误;

ValueError: Shapes (?,) and (?, 2) must have the same rank
....
During handling of the above exception, another exception occurred:
....
ValueError: Shapes (?,) and (?, 2) are not compatible
....
During handling of the above exception, another exception occurred:
....
ValueError: logits and labels must have the same shape ((?, 2) vs (?,))

因为我真的不明白为什么挤压()是必要的,所以我将它删除了,我得到了它;

ValueError: Dimensions 1 and 2 are not compatible
....
During handling of the above exception, another exception occurred:
....
ValueError: logits and labels must have the same shape ((?, 2) vs (?, 1))

这让我觉得我可以将one_hot传递给损失计算以解决形状问题。我尝试了这个,以及在one_hot()之前的squeeze()。虽然,错误变得更加复杂,表明我可能会走错路。 (对于记录,one_hot为矩阵添加了一个维度,因此没有挤压,张量变为(?,1,2)。

loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.one_hot(labels, depth=2), logits=logits)
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 288 values, but the requested shape has 1

有人可以帮我理解为什么我的标签和logits适用于sparse_softmax_cross_entropy。但不是sigmoid_cross_entropy_with_logits? 我有什么办法可以重塑张量以允许sigmoid损失函数吗?

1 个答案:

答案 0 :(得分:0)

答案是首先创建一个one_hot然后挤压()得到的张量。

这适用于损失计算。

loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.squeeze(tf.one_hot(labels, depth=2), axis=1),
        logits=logits)

sigmoid_cross_entropy继续生成sigmoid_cross_entropy_with_logits。在搜索期间,我最终只是通过懒惰切换到发布的内容。