我在球体上实施k-means,从@ dga' s gist开始。
单位范数约束基本上意味着使用内积而不是成对距离,使用argmax
代替argmin
和求和+归一化而不是平均来更新质心。
现在我正在尝试用表现最少的数据点替换死质心。
对于死质心,unsorted_segment_sum
将返回0的总和:
total = tf.unsorted_segment_sum(points, best_centroids, K)
从这里我得到一个死亡质心的布尔掩码:
deads = tf.equal(total, 0)
......死亡质心的数量:
dead_count = tf.reduce_sum(tf.as_type(deads, 'int64'))
...最后是一个列表,其中包含当前模型表示最差的数据点索引:
_, dead_replacement_idx = tf.nn.top_k(-assignment_qualities,
k=dead_count, sorted=False)
现在我如何更换死质心? 在numpy中,现在可以归结为:
means[deads] = points[dead_replacement_idx]
我怎样才能在Tensorflow中做类似的事情?
答案 0 :(得分:0)
如果您将工具存储在unicode
中,则可以使用Variables
scatter_update
结果
tf.reset_default_graph()
means = tf.Variable(np.array([[1,1],[2,2],[3,3]]), dtype=np.float32)
indices = tf.constant([0, 2])
new_mean = tf.constant([-1, -1], dtype=np.float32)
new_mean_matrix = tf.reshape(new_mean, [1, -1])
tile_shape = tf.pack([tf.size(indices), 1])
new_mean_matrix_tiled = tf.tile(new_mean_matrix, tile_shape)
update_op = tf.scatter_update(means, indices, new_mean_matrix_tiled)
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
print "Before update"
print sess.run(means)
print "Updating rows", indices.eval(), "to", new_mean.eval()
sess.run(update_op)
print "After update"
print sess.run(means)