我正在尝试计算通过Python Ripser库生成的两个持久性图之间的Wasserstein距离。我在Persim中发现了两个有趣的函数:sliced_wasserstein和wasserstein_matching。
我的图表生成如下:
data = json.loads(data)
data = pd.DataFrame.from_dict(data)
rips = Rips()
dgms = rips.fit_transform(data)
for i in dgms:
print(type(i))
i.tofile(directory+"diagram.txt")
plot_diagrams(dgms, show=False)
plt.savefig("persistence_diagram.png")
plt.close()
'dgms'是一个包含numpy数组的列表,所以我在'for'行中将它们取出。
我的Wasserstein函数用法如下:
with open(loc) as f:
img1 = np.fromfile(f)
f.close()
with open(loc2) as f:
img2 = np.fromfile(f)
f.close()
persim.sliced_wasserstein(img1, img2)
我试图传递给wasserstein_matching三种数据(.png,dgms列表和np.array中的图表),但我不断得到的只是一个错误“ IndexError:数组的索引过多”。 所以我切换到sliced_wasserstein,出现这样的错误:
Traceback (most recent call last):
File "C:/Users/Patka/PycharmProjects/MGR/Mapper.py", line 26, in <module>
persim.sliced_wasserstein(img1, img2)
File "C:\Users\Patka\environmentpython\lib\site-packages\persim\sliced_wasserstein.py", line 53, in sliced_wasserstein
sw += step * cityblock(sorted(V1), sorted(V2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
对我来说,一件奇怪的事情是,当我在保存文件之前打印i.shape时,我得到了二维,例如(12,2),但是当我使用numpy.fromfile()从同一个文件读取时,我得到一个元组(12,)。
有人能治愈吗?我的最终目标是计算许多图的距离并将它们聚类,但是我坚持比较两个图...