如何获得scikit-learn partial_fit的迭代

时间:2017-03-08 16:48:39

标签: python scikit-learn

我正在尝试使用SGDClassifier训练HashingVectorizer文本数据。我想知道如何组装从多个文件中传递给partial_fit()的批次。

以下代码是否是通过迭代批量获取数据的合适方法?这样做是否有最佳实践或推荐方法?

class MyIterable:
def __init__(self, files, batch_size):
    self.files = files
    self.batch_size = batch_size
def __iter__(self):
    batchstartmark = 0
    for line in fileinput.input(self.files):
        while batchstartmark < self.batch_size
            yield line.split('\t')
            batchstartmark += 1

提前致谢!

1 个答案:

答案 0 :(得分:1)

在这里判断这种方法的理论:  这是一个非常糟糕的方法!

由于SGDClassifier正在使用随机梯度下降(如果需要,可以使用小批量),您应该尝试完成SGDs数学分析的假设。

SGD的基本思想是:选择一些随机元素和下降。你的代码明显偏离两点:

  • A)您在每个时期以相同的顺序挑选元素
  • B)你正在采样(不是真的)没有替换
    • 因此,在此纪元中挑选出所有其他x之前,不会选择x17

你对 A 的无知会导致非常糟糕的表现,而且很有可能。

B 很难分析。有不同的理论观点,主要取决于某些特定的问题(当然,凸和非凸问题之间存在差异),而采样与替换是经典的(最常见的)收敛证明),有时采样 - 无替换(也就是:在纪元/循环期间随机播放和迭代),并且通常会更快收敛。