使用布尔掩码和索引列表替换Tensorflow中变量的某些行

时间:2016-05-04 11:51:37

标签: tensorflow k-means

我在球体上实施k-means,从@ dga' s gist开始。 单位范数约束基本上意味着使用内积而不是成对距离,使用argmax代替argmin和求和+归一化而不是平均来更新质心。

现在我正在尝试用表现最少的数据点替换死质心。 对于死质心,unsorted_segment_sum将返回0的总和:

total = tf.unsorted_segment_sum(points, best_centroids, K)

从这里我得到一个死亡质心的布尔掩码:

deads = tf.equal(total, 0)

......死亡质心的数量:

 dead_count = tf.reduce_sum(tf.as_type(deads, 'int64'))

...最后是一个列表,其中包含当前模型表示最差的数据点索引:

_, dead_replacement_idx = tf.nn.top_k(-assignment_qualities, 
                                      k=dead_count, sorted=False)

现在我如何更换死质心? 在numpy中,现在可以归结为:

means[deads] = points[dead_replacement_idx]

我怎样才能在Tensorflow中做类似的事情?

1 个答案:

答案 0 :(得分:0)

如果您将工具存储在unicode中,则可以使用Variables

scatter_update

结果

tf.reset_default_graph()

means = tf.Variable(np.array([[1,1],[2,2],[3,3]]), dtype=np.float32)
indices = tf.constant([0, 2])
new_mean = tf.constant([-1, -1], dtype=np.float32)

new_mean_matrix = tf.reshape(new_mean, [1, -1])
tile_shape = tf.pack([tf.size(indices), 1])
new_mean_matrix_tiled = tf.tile(new_mean_matrix, tile_shape) 
update_op = tf.scatter_update(means, indices, new_mean_matrix_tiled)

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
print "Before update"
print sess.run(means)
print "Updating rows", indices.eval(), "to", new_mean.eval()
sess.run(update_op)
print "After update"
print sess.run(means)