如你所知tf.one_hot可以做一个热门编码。但是,当我的数据集非常大时,我需要进行批量训练。这样,当我使用for循环遍历所有批次时,在每次迭代中,当我执行tf.one_hot时,一个热矩阵的维度将小于我预期的值。
例如,对于列' a'我们有47个类别,但是在一个批次中它们可能只显示了20个,当我在这个批次上执行one_hot时,它将创建一个矩形,其行数为* 20而不是行* 47的维度。
如何获得行的维度* 47每批中的一个热矩阵?
谢谢!
答案 0 :(得分:1)
tf.one_hot()
接受一个参数depth
作为第二个参数,它决定了单热矢量应该有多长。如果您按照以下方式运行操作:
b = tf.one_hot( a, 47 )
它应该给你最后一个维度47。
很难说没有代码,但有些人不会硬编码one_hot大小,但尝试从标签张量中获取它,例如
max_class = tf.reduce_max( a )
b = tf.one_hot( a, max_class )
如果你的代码就是这种情况,那么批量只能达到20级。
否则需要看你的代码说些什么。
如果TensorFlow内存不足,它会因错误而停止,不会只是默默地咬掉一半的数据。 :)