Sinkhorn algorithm does not converge

Time: 2017-12-04 11:47:38

Tags: python algorithm optimization machine-learning linear-programming

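I am trying to implement the regularized Sinkhorn algorithm (Wilson 62), which computes the optimal transport between two uniform distributions over (lat, lon) points. The basic idea is a fixed-point iteration. If you are not familiar with it, see http://www.numerical-tours.com/matlab/optimaltransp_5_entropic/ (my setting is exactly the same).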
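For readers unfamiliar with the iteration, here is a minimal sketch of the entropic Sinkhorn fixed-point updates on synthetic points. The function name `sinkhorn_plan`, the random test data, and the stopping tolerance are illustrative assumptions and are not part of the code in this question.

```python
import numpy as np


def sinkhorn_plan(a, b, cost, gamma, max_iter=1000, tol=1e-9):
    """Return an approximate entropic transport plan with marginals a and b.

    Hypothetical helper for illustration: a and b are 1-D weight vectors
    that each sum to 1, and cost has shape (len(a), len(b)).
    """
    K = np.exp(-cost / gamma)          # Gibbs kernel
    v = np.ones_like(b)
    for _ in range(max_iter):
        u = a / (K @ v)                # rescale rows towards a
        v = b / (K.T @ u)              # rescale columns towards b
        plan = u[:, None] * K * v[None, :]
        # After the v-update the column sums equal b exactly,
        # so the row-marginal error is used as the convergence check.
        if np.abs(plan.sum(axis=1) - a).max() < tol:
            break
    return plan


rng = np.random.default_rng(0)
x = rng.random((5, 2))                 # synthetic source points
y = rng.random((7, 2))                 # synthetic target points
cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=2)
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)

plan = sinkhorn_plan(a, b, cost, gamma=0.5)
distance = float(np.sum(plan * cost))  # cost of the regularized plan
```

Stopping on the marginal error rather than on a fixed iteration count makes it easier to tell a slow run from one that has actually converged.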

import pandas as pd
from sklearn.cluster.hierarchical import AgglomerativeClustering
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
from shapely.geometry import shape, Point
import shapefile
from sklearn.metrics.pairwise import pairwise_distances


if __name__ == '__main__':
    inst = pd.read_csv('inst.csv', encoding='utf-8')[['lat', 'lng']]
    ht = pd.read_csv('ht.csv', encoding='utf-8')[['lat', 'lng']]

for k in range(0,4):
    print('k = {}'.format(k))
    gamma = 0.5
    M = pairwise_distances(inst[['lat','lng']],ht[['lat','lng']])
    K = np.exp(-M/0.5)
    N1 = len(inst)
    N2 = len(ht)
    v = np.ones((N1,1))
    i = 0
    M = 99999
    a = np.divide(np.ones((N2,1)),N2)
    b = np.divide(np.ones((N1,1)),N1)
    while i<M:
        print('i = {}'.format(i))
        u = np.divide(a,np.dot(K.T,v))
        v = np.divide(b,np.dot(K,u))
        i = i+1
    print('cluster {} has wassersetein distance {}'.format(k,np.sum(np.multiply(np.dot(np.dot(np.diagflat(u),K.T),np.diagflat(v)),M).reshape(-1,1))))

The final distance blows up and always equals the number of iterations. What is wrong with my algorithm?

cluster 2 has wassersetein distance 99999.0
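One likely explanation, judging only from the snippet above: `M` first holds the cost matrix and is then rebound to the iteration limit `99999`, so the final expression multiplies the transport plan by a scalar. Since the plan's entries sum to 1 after the last update, the printed value is exactly the iteration limit. The sketch below reproduces this on synthetic data; the names `cost` and `max_iter` and the random points are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 2))                       # stands in for inst[['lat', 'lng']]
y = rng.random((6, 2))                       # stands in for ht[['lat', 'lng']]
cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=2)   # shape (4, 6)
K = np.exp(-cost / 0.5)

a = np.full((6, 1), 1 / 6)                   # weights on y, as in the question
b = np.full((4, 1), 1 / 4)                   # weights on x
v = np.ones((4, 1))
max_iter = 200                               # separate name, so `cost` is not overwritten
for _ in range(max_iter):
    u = a / (K.T @ v)
    v = b / (K @ u)

plan = np.diagflat(v) @ K @ np.diagflat(u)   # shape (4, 6), same as cost
total_mass = float(plan.sum())               # equals 1 up to rounding
buggy = float(np.sum(plan * 99999))          # plan times a scalar: the scalar itself
distance = float(np.sum(plan * cost))        # cost-weighted sum instead
```

The plan is built here as diag(v) K diag(u) so that its shape matches the cost matrix. The transposed form used in the question has shape (N2, N1) and would not line up elementwise with an (N1, N2) cost matrix unless N1 equals N2.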

I also tried this library. It does not work either.

import ot  # assumed to be the POT (Python Optimal Transport) package

for k in range(0,4):
    N1 = len(inst)
    N2 = len(ht)
    b = np.divide(np.ones((N2,1)),N2)
    a = np.divide(np.ones((N1,1)),N1)
    M = ot.dist(inst[inst.lb==k][['lat','lng']],ht[ht.lb==k][['lat','lng']])
    print(ot.sinkhorn2(a,b,M,1e-3))

It gives a broadcasting error:

ValueError: operands could not be broadcast together with shapes (132,1) (2011,93) 
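A possible reading of this error, based only on the snippet above: `a` and `b` are sized from `len(inst)` and `len(ht)`, while `M` is built from the rows filtered by `lb == k`, so their dimensions need not agree. One way to keep them consistent is to derive the weights from the cost matrix itself, as sketched below. The data is synthetic, the helper `uniform_weights` is hypothetical, and the commented call assumes that `ot` is the POT package and that 1-D weight vectors are acceptable to it.

```python
import numpy as np


def uniform_weights(cost):
    """Return 1-D uniform weights sized to the rows and columns of cost."""
    n_rows, n_cols = cost.shape
    return np.full(n_rows, 1.0 / n_rows), np.full(n_cols, 1.0 / n_cols)


rng = np.random.default_rng(0)
src = rng.random((8, 2))        # stands in for inst[inst.lb == k][['lat', 'lng']]
dst = rng.random((3, 2))        # stands in for ht[ht.lb == k][['lat', 'lng']]
M = np.linalg.norm(src[:, None, :] - dst[None, :, :], axis=2)

a, b = uniform_weights(M)       # lengths follow M: 8 rows, 3 columns
# With POT installed, these could then be passed on, for example:
# import ot
# value = ot.sinkhorn2(a, b, M, 1e-3)
```

Sizing the weights from `M.shape` means a change to the filtering cannot silently desynchronize the three inputs.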

0 answers:

No answers