批量的one_hot编码将是不完整的tensorflow

时间:2018-04-27 23:45:04

标签: python tensorflow one-hot-encoding

如你所知tf.one_hot可以做一个热门编码。但是,当我的数据集非常大时,我需要进行批量训练。这样,当我使用for循环遍历所有批次时,在每次迭代中,当我执行tf.one_hot时,一个热矩阵的维度将小于我预期的值。

例如,对于列' a'我们有47个类别,但是在一个批次中它们可能只显示了20个,当我在这个批次上执行one_hot时,它将创建一个矩形,其行数为* 20而不是行* 47的维度。

如何获得行的维度* 47每批中的一个热矩阵?

谢谢!

1 个答案:

答案 0 :(得分:1)

tf.one_hot()接受一个参数depth作为第二个参数,它决定了单热矢量应该有多长。如果您按照以下方式运行操作:

b = tf.one_hot( a, 47 )

它应该给你最后一个维度47。

很难说没有代码,但有些人不会硬编码one_hot大小,但尝试从标签张量中获取它,例如

max_class = tf.reduce_max( a )
b = tf.one_hot( a, max_class )

如果你的代码就是这种情况,那么批量只能达到20级。

否则需要看你的代码说些什么。

如果TensorFlow内存不足,它会因错误而停止,不会只是默默地咬掉一半的数据。 :)