original Tensorflow tutorial包含以下代码:
batch_size = tf.size(labels)
labels = tf.expand_dims(labels, 1)
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
concated = tf.concat(1, [indices, labels])
onehot_labels = tf.sparse_to_dense(concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
第二行为labels
张量添加维度。但是,labels
通过Feed字典输入,因此它应该已经具有形状[batch_size, NUM_CLASSES]
。如果是,那么为什么expand_dims
在这里使用?
答案 0 :(得分:2)
那个教程很老了。您引用的是版本0.6,而截至(本帖子的11-20-2016时间)它们的值为0.11。所以当时有许多不同的函数v0.6。
无论如何要回答你的问题:
mnist中的标签只是编号为0-9。但是,损失函数期望标签被编码为一个热矢量。
在该示例中,标签不仅仅是[batch_size, NUM_CLASSES]
,它只是[batch_size]
。
这可以通过类似的numpy函数完成。此外,他们还提供了将张量流中的mnist数据集中的标签作为一个已经具有您所述形状的热矢量的函数。