Tensorflow tf.expand_dims

时间:2016-11-20 17:22:57

标签: tensorflow

original Tensorflow tutorial包含以下代码:

batch_size = tf.size(labels)
labels = tf.expand_dims(labels, 1)
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
concated = tf.concat(1, [indices, labels])
onehot_labels = tf.sparse_to_dense(concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)

第二行为labels张量添加维度。但是,labels通过Feed字典输入,因此它应该已经具有形状[batch_size, NUM_CLASSES]。如果是,那么为什么expand_dims在这里使用?

1 个答案:

答案 0 :(得分:2)

那个教程很老了。您引用的是版本0.6,而截至(本帖子的11-20-2016时间)它们的值为0.11。所以当时有许多不同的函数v0.6。

无论如何要回答你的问题:

mnist中的标签只是编号为0-9。但是,损失函数期望标签被编码为一个热矢量。

在该示例中,标签不仅仅是[batch_size, NUM_CLASSES],它只是[batch_size]

这可以通过类似的numpy函数完成。此外,他们还提供了将张量流中的mnist数据集中的标签作为一个已经具有您所述形状的热矢量的函数。