我有一个张量(batch, num_points, 2)
,它由2D点组成。我想将k-means或GMM群集批量应用于这些点,并生成k
是预先定义的k
群集。也就是说,我将得到形状为(batch, k, 2)
的聚类质心,以及所有点的聚类标签。
我注意到TensorFlow中同时存在k-means和GMM集群的实现,但是它们都是针对单个批处理的,并且GMM似乎有问题。我想将其中一种算法应用于数据集,并希望确保实现是可区分的。 (例如,我认为运行EM算法最多5次将使计算图可区分。)
任何人都可以向我推荐实施方案吗?