计算持续性图之间的wasserstein距离

时间:2019-06-09 16:29:04

标签: python arrays numpy distance diagram

我正在尝试计算通过Python Ripser库生成的两个持久性图之间的Wasserstein距离。我在Persim中发现了两个有趣的函数:sliced_wasserstein和wasserstein_matching。

我的图表生成如下:

    data = json.loads(data)
    data = pd.DataFrame.from_dict(data)
    rips = Rips()
    dgms = rips.fit_transform(data)
    for i in dgms:
        print(type(i))
        i.tofile(directory+"diagram.txt")
    plot_diagrams(dgms, show=False)
    plt.savefig("persistence_diagram.png")
    plt.close()

'dgms'是一个包含numpy数组的列表,所以我在'for'行中将它们取出。

我的Wasserstein函数用法如下:

with open(loc) as f:
    img1 = np.fromfile(f)
    f.close()
with open(loc2) as f:
    img2 = np.fromfile(f)
    f.close()
persim.sliced_wasserstein(img1, img2)

我试图传递给wasserstein_matching三种数据(.png,dgms列表和np.array中的图表),但我不断得到的只是一个错误“ IndexError:数组的索引过多”。 所以我切换到sliced_wasserstein,出现这样的错误:

Traceback (most recent call last):
  File "C:/Users/Patka/PycharmProjects/MGR/Mapper.py", line 26, in <module>
    persim.sliced_wasserstein(img1, img2)
  File "C:\Users\Patka\environmentpython\lib\site-packages\persim\sliced_wasserstein.py", line 53, in sliced_wasserstein
    sw += step * cityblock(sorted(V1), sorted(V2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

对我来说,一件奇怪的事情是,当我在保存文件之前打印i.shape时,我得到了二维,例如(12,2),但是当我使用numpy.fromfile()从同一个文件读取时,我得到一个元组(12,)。

有人能治愈吗?我的最终目标是计算许多图的距离并将它们聚类,但是我坚持比较两个图...

0 个答案:

没有答案