用于最佳传输的Sinkhorn算法

时间:2019-12-24 08:39:52

标签: python algorithm optimization

我正在尝试编写Sinkhorn算法,尤其是当熵正则化强度收敛为0时,我是否想计算两个度量之间的最优传输。

举例来说,让我们将统一度量$ U $超过$ [0; 1] $转换为统一度量$ V $超过$ [1; 2] $。 二次海岸的最佳度量是$(x,x-1)_ {#} U $。

让我们离散化$ [0; 1] $,度量$ U $,$ [1; 2] $和度量$ V $。我应该使用Sinkhorn进行测量,以使支撑位于线$ y = x-1 $的图形中。但事实并非如此,因此我正在研究以查找问题所在。我将向您展示我的代码,我的结果也许有人可以帮助我。

import numpy as np
import math
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.colors as colors

#Parameters
N = 10                        #Step of the discritization of [0,1]
stop = 10**-3
Niter = 10**3


def Sinkhorn(C, mu, nu, lamb):
    # lam : strength of the entropic regularization

    #Initialization
    a1 = np.zeros(N)
    b1 = np.zeros(N)
    a2 = np.ones(N)
    b2 = np.ones(N)
    Iter = 0

    GammaB = np.exp(-lamb*C)

    #Sinkhorn
    while (np.linalg.norm(a2) > stop and  np.linalg.norm(b2) > stop and np.linalg.norm(a2) < 1/stop and np.linalg.norm(b2) < 1/stop and Iter < Niter and np.linalg.norm(a1-a2) + np.linalg.norm(b1-b2) > stop ):
        a1 = a2
        b1 = b2
        a2 = mu/(np.dot(GammaB,b1))
        b2 = nu/(np.dot(GammaB.T,a2))
        Iter +=1

    # Compute gamma_star
    Gamma = np.zeros((N,N))
    for i in range(N):
        for j in range(N):
            Gamma[i][j] = a2[i]*b2[j]*GammaB[i][j]
    Gamma /= Gamma.sum()

    return Gamma

## Test between uniform([0;1]) over uniform([1;2])

S = np.linspace(0,1,N, False)  #discritization of [0,1]
T = np.linspace(1,2,N,False)  #discritization of [1,2]

# Discretization of uniform([0;1])
U01 = np.ones(N)
Mass = np.sum(U01)
U01 = U01/Mass

# Discretization uniform([1;2])
U12 = np.ones(N)
Mass = np.sum(U12)
U12 = U12/Mass

# Cost function
X,Y = np.meshgrid(S,T)
C = (X-Y)**2               #Matrix of c[i,j]=(xi-yj)²

def plot_Sinkhorn_U01_U12():

    #plot optimal measure and convergence
    fig = plt.figure()
    for i in range(4):
        ax = fig.add_subplot(2, 2, i+1, projection='3d')
        Gamma_star = Sinkhorn(C, U01, U12, 1/10**i)
        ax.scatter(X, Y, Gamma_star, cmap='viridis', linewidth=0.5)
        plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(1/10**i))

    plt.show()

    plt.figure()
    for i in range(4):
        plt.subplot(2,2,i+1)
        Gamma_star = Sinkhorn(C, U01, U12, 1/10**i)
        plt.imshow(Gamma_star,interpolation='none')
        plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(1/10**i))

    plt.show()

    return

# The transport between U01 ans U12 is x -> x-1 so the support of gamma^* is contained in the graph of the function x -> (x,x+1) which is the line y = x+1

plot_Sinkhorn_U01_U12()

我得到了什么。

enter image description here enter image description here

根据建议,这是我考虑1个羊羔时代码的输出。

Output

Output

方法更好,但仍然不正确。这是Gamma_star(125)

Gamma_star(125) : 
[[0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.09 0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.01 0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]]

我们可以看到,度量gamma_star的支持未包含在$ y = x-1 $

行中

感谢和问候。

1 个答案:

答案 0 :(得分:0)

这不是最终的答案,但我们正在接近。

根据建议,我减轻了我的休息时间。以唯一条件为例

while (Iter < Niter):

这就是我得到的:

enter image description here

enter image description here

这是我得到的gamma_star(125)矩阵:

[[0.08 0.02 0.   0.   0.   0.   0.   0.   0.   0.  ]
 [0.02 0.06 0.02 0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.02 0.06 0.02 0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.02 0.06 0.02 0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.02 0.06 0.02 0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.02 0.06 0.02 0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.02 0.06 0.02 0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.02 0.06 0.02 0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.02 0.06 0.02]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.02 0.08]]

它离我的期望值更近:$ \ text {Gamma_star}(i,j)= 0 $ for $ j \ ne i-1 $

新代码是:

import numpy as np
import math
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.colors as colors

#Parameters
N = 10                        #Step of the discritization of [0,1]
Niter = 10**5

def Sinkhorn(C, mu, nu, lamb):
    # lam : strength of the entropic regularization

    #Initialization
    a1 = np.zeros(N)
    b1 = np.zeros(N)
    a2 = np.ones(N)
    b2 = np.ones(N)
    Iter = 0

    GammaB = np.exp(-lamb*C)

    #Sinkhorn
    while (Iter < Niter):
        a1 = a2
        b1 = b2
        a2 = mu/(np.dot(GammaB,b1))
        b2 = nu/(np.dot(GammaB.T,a2))
        Iter +=1

    # Compute gamma_star
    Gamma = np.zeros((N,N))
    for i in range(N):
        for j in range(N):
            Gamma[i][j] = a2[i]*b2[j]*GammaB[i][j]
    Gamma /= Gamma.sum()

    return Gamma

    ## Test between uniform([0;1]) over uniform([1;2])

    S = np.linspace(0,1,N, False)  #discritization of [0,1]
    T = np.linspace(1,2,N,False)  #discritization of [1,2]

    # Discretization of uniform([0;1])
    U01 = np.ones(N)
    Mass = np.sum(U01)
    U01 = U01/Mass

    # Discretization uniform([1;2])
    U12 = np.ones(N)
    Mass = np.sum(U12)
    U12 = U12/Mass

    # Cost function
    X,Y = np.meshgrid(S,T)
    C = (X-Y)**2               #Matrix of c[i,j]=(xi-yj)²

    def plot_Sinkhorn_U01_U12():

        #plot optimal measure and convergence
        fig = plt.figure()
        for i in range(4):
            ax = fig.add_subplot(2, 2, i+1, projection='3d')
            Gamma_star = Sinkhorn(C, U01, U12, 5**i)
            ax.scatter(X, Y, Gamma_star, cmap='viridis', linewidth=0.5)
            plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(5**i))
        plt.show()

        plt.figure()
        for i in range(4):
            plt.subplot(2,2,i+1)
            Gamma_star = Sinkhorn(C, U01, U12, 5**i)
            plt.imshow(Gamma_star,interpolation='none')
            plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(5**i))

        plt.show()

        return

    # The transport between U01 ans U12 is x -> x-1 so the support of gamma^* is contained in the graph of the function x -> (x,x-1) which is the line y = x-1

    plot_Sinkhorn_U01_U12()