Python Minibatch Dictionary Learning

Time: 2016-09-22 03:50:04

Tags: python numpy image-processing machine-learning scikit-learn

I want to use sklearn's MiniBatchDictionaryLearning to implement error tracking for dictionary learning in Python, so that I can record how the error decreases over the iterations. I have two ways of doing it, neither of which really worked. Setup:

  • Input data X: a numpy array of shape (n_samples, n_features) = (298143, 300). These are patches of shape (10, 10), generated from an image of shape (642, 480, 3).
  • Dictionary-learning parameters: number of columns (or atoms) = 100, alpha = 2, transform algorithm = OMP, total no. of iterations = 500 (kept small at first, as a test case).
  • Computing the error: after learning the dictionary, I encode the original image again using the learned dictionary. Since both the encoding and the original are numpy arrays of the same shape (642, 480, 3), for now I just take the element-wise Euclidean distance:

    err = np.sqrt(np.sum((reconstruction - original)**2))
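The decode-and-compare step behind this error measure can be sketched with plain NumPy; all names, shapes, and the noise level below are illustrative stand-ins, not the actual data or the asker's helper code:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the learned dictionary and the patch codes
n_samples, n_atoms, n_features = 50, 100, 300
V = rng.standard_normal((n_atoms, n_features))     # plays the role of dico.components_
codes = rng.standard_normal((n_samples, n_atoms))  # plays the role of dico.transform(patches)

# "Original" data that the dictionary almost, but not exactly, represents
original = codes @ V + 0.01 * rng.standard_normal((n_samples, n_features))
reconstruction = codes @ V                         # decode: codes x dictionary

# Element-wise Euclidean (Frobenius) distance, as in the question
err = np.sqrt(np.sum((original - reconstruction) ** 2))
```

A perfect reconstruction would give err = 0; here the small additive noise leaves a small positive residual.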

I did a test run with these parameters, and a full fit was able to produce a very good reconstruction with low error, so that part is fine. Now the two approaches:

Approach 1: save the learned dictionary every 100 iterations and record the error. For 500 iterations, this gives us 5 runs of 100 iterations each. After each run, I compute the error and then use the currently learned dictionary as the initialization for the next run.

# Fit an initial dictionary, V, as a first run
dico = MiniBatchDictionaryLearning(n_components = 100,
                                   alpha = 2,
                                   n_iter = 100,
                                   transform_algorithm='omp')
dl = dico.fit(patches)
V = dl.components_

# Now do another 4 runs of 100 iterations each.
# Note the warm restart parameter, dict_init = V.
n_runs = 4
n_iterations = 100
for i in range(n_runs):
    print("Run %s..." % i, end = "")
    dico = MiniBatchDictionaryLearning(n_components = 100,
                                       alpha = 2,
                                       n_iter = n_iterations,
                                       transform_algorithm='omp',
                                       dict_init = V)
    dl = dico.fit(patches)
    V = dl.components_

    img_r = reconstruct_image(dico, V, patches)
    err = np.sqrt(np.sum((img - img_r)**2))
    print("Err = %s" % err)

Problem: the error does not decrease, and it stays very high. The dictionary is not learned well either.

Approach 2: slice the input data X into batches of 500, then do a partial fit using the partial_fit() method.

batch_size = 500
n_batches = X.shape[0] // batch_size
print(n_batches) # 596

for iternum in range(n_batches):
    batch = patches[iternum*batch_size : (iternum+1)*batch_size]
    dico.partial_fit(batch)  # partial_fit returns the estimator itself,
    V = dico.components_     # so read the dictionary from components_

Problem: this seems to take about 5000 times longer.
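Back-of-the-envelope arithmetic shows why this blows up, assuming the old scikit-learn behavior (described in the answers on this page) where each partial_fit() call runs its own n_iter inner iterations, defaulting to 1000:

```python
# Hypothetical update counts; the 1000-inner-iteration default per partial_fit()
# is the old MiniBatchDictionaryLearning behavior described in the answers.
n_samples, batch_size = 298143, 500
fit_updates = 500                    # n_iter passed to the single fit() above

n_batches = n_samples // batch_size  # number of partial_fit() calls
inner_iters = 1000                   # assumed default n_iter per partial_fit() call
partial_fit_updates = n_batches * inner_iters

print(n_batches)            # 596
print(partial_fit_updates)  # 596000 -- roughly 1000x the work of the single fit()
```

So the loop does on the order of a thousand times more dictionary updates than the single fit() call, which is consistent with the huge slowdown the question reports.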

I would like to know: is there a way to retrieve the error during the fitting process?

2 answers:

Answer 0 (score: 2)

Every call to fit re-initializes the model and forgets any previous calls to fit: this is the expected behavior of all estimators in scikit-learn.
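This fit-forgets-everything / partial_fit-accumulates contract can be illustrated with a toy running-mean estimator (pure Python for illustration only, not sklearn code):

```python
class ToyMeanEstimator:
    """Toy estimator that tracks the mean of the samples it has seen."""

    def fit(self, X):
        # fit() re-initializes: any previous state is discarded
        self.sum_, self.n_ = float(sum(X)), len(X)
        return self

    def partial_fit(self, X):
        # partial_fit() accumulates: previous state is kept
        if not hasattr(self, "sum_"):
            self.sum_, self.n_ = 0.0, 0
        self.sum_ += float(sum(X))
        self.n_ += len(X)
        return self

    @property
    def mean_(self):
        return self.sum_ / self.n_


est = ToyMeanEstimator()
est.fit([1, 2, 3]).fit([10, 20, 30])
print(est.mean_)  # 20.0 -- the first fit() was forgotten

est2 = ToyMeanEstimator()
est2.partial_fit([1, 2, 3]).partial_fit([10, 20, 30])
print(est2.mean_)  # 11.0 -- state accumulated across both calls
```

The same contract is why Approach 1 in the question never improves: each new MiniBatchDictionaryLearning.fit() starts learning over, except for the dict_init warm start.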

I think calling partial_fit in a loop is the right solution, but you should call it on small mini-batches (as is done inside the fit method, where the default batch_size value is only 3), and then only compute the cost every e.g. 100 or 1000 calls to partial_fit:

batch_size = 3
n_epochs = 20
n_batches = X.shape[0] // batch_size
print(n_batches) # 596


n_updates = 0
for epoch in range(n_epochs):
    for i in range(n_batches):
        batch = patches[i * batch_size:(i + 1) * batch_size]
        dico.partial_fit(batch)
        n_updates += 1
        if n_updates % 100 == 0:
            img_r = reconstruct_image(dico, dico.components_, patches)
            err = np.sqrt(np.sum((img - img_r)**2))
            print("[epoch #%02d] Err = %s" % (epoch, err))

Answer 1 (score: 0)

I ran into the same problem and was finally able to make the code faster. Adding the solution here in case it is still useful to someone. The catch is that when constructing the MiniBatchDictionaryLearning object, we need to set n_iter to a low value (e.g. 1), so that each partial_fit does not run many epochs over its single batch.

# Construct an initial dictionary object; the partial fits are done later inside
# the loop. Here we only specify that each partial_fit() should run just 1
# epoch (n_iter=1) with batch_size=batch_size on the batch provided.
# (Otherwise, by default it can run up to 1000 iterations with batch_size=3 for
# a single partial_fit() on each batch, which makes a single call to
# partial_fit() very slow. Since we control the epochs ourselves and restart
# when all the batches are done, we need not ask for more than 1 iteration here.
# This makes the code execute fast.)

batch_size = 128 # e.g.,
dico = MiniBatchDictionaryLearning(n_components = 100,
                                   alpha = 2,
                                   n_iter = 1,  # epoch per partial_fit()
                                   batch_size = batch_size,
                                   transform_algorithm='omp')

followed by @ogrisel's code:

n_updates = 0
for epoch in range(n_epochs):
    for i in range(n_batches):
        batch = patches[i * batch_size:(i + 1) * batch_size]
        dico.partial_fit(batch)
        n_updates += 1
        if n_updates % 100 == 0:
            img_r = reconstruct_image(dico, dico.components_, patches)
            err = np.sqrt(np.sum((img - img_r)**2))
            print("[epoch #%02d] Err = %s" % (epoch, err))
