如何使用tf.scatter_nd"没有"积累?

时间:2017-05-22 16:04:41

标签: tensorflow

让我说我有一个(无,2) - 形状张量SELECT type, REFERENCES FROM ( SELECT type, REFERENCES, STATUS, rank() OVER (PARTITION BY Type ORDER BY REFERENCES DESC) AS Rank FROM Tables ) a WHERE rank <= 2 AND Type IN (A,B,C,D,E) AND STATUS = Y; 和(无) - 形状张量indices。这些实际的行#和值将在运行时确定。

我想设置一个4x5张量values,每个索引元素都有值的值。我发现我可以像这样使用tf.scatter_nd:

t

我的问题是:当索引有重复时,将累积值。

t = tf.scatter_np(indices, values, [4, 5])
# E.g., indices = [[1,2],[2,3]], values = [100, 200] 
# t[1,2] <-- 100; t[2,3] <-- 200     

我想只分配一个,即,无论是无知(所以,第一个值)还是覆盖(所以,最后一个值)。

我觉得我需要检查索引中的重复项,或者我需要使用tensorflow循环。有人可以建议吗? (希望是一个最小的示例代码?)

3 个答案:

答案 0 :(得分:3)

您可以使用tf.unique:唯一的问题是此操作需要1D张量。 因此,为了克服这个问题,我决定使用Cantor pairing function。 简而言之,它存在一个双射函数,它将元组(在这种情况下是一对值,但它适用于任何N维元组)映射到单个值。

一旦坐标减少到标量的1-D张量,则tf.unique可用于查找唯一数字的索引。

Cantor配对函数是可逆的,因此现在我们不仅知道1-D张量内非重复值的索引,而且我们也可以回到坐标的2-D空间并使用{{ 1}}执行更新而不会出现累加器问题。

TL; DR:

scatter_nd

答案 1 :(得分:0)

这可能不是最佳解决方案,我使用tf.unsorted_segment_max来避免累积

with tf.Session() as sess:

    # #########
    # Examples:
    # ##########
    width, height, depth = [3, 3, 2]
    indices = tf.cast([[0, 1, 0], [0, 1, 0], [1, 1, 1]], tf.int32)
    values  = tf.cast([1, 2, 3], tf.int32)

    # ########################
    # Filter duplicated indices
    # #########################
    flatten = tf.matmul(indices, [[height * depth], [depth], [1]])
    filtered, idx = tf.unique(tf.squeeze(flatten))

    # #####################
    # Obtain updated result
    # #####################
    def reverse(index):
        """Map from 1-D to 3-D """
        x = index / (height * depth)
        y = (index - x * height * depth) / depth
        z = index - x * height * depth - y * depth
        return tf.stack([x, y, z], -1)        

    # This will pick the maximum value instead of accumulating the result
    updated_values  = tf.unsorted_segment_max(values, idx, tf.shape(filtered_idx)[0])
    updated_indices = tf.map_fn(fn=lambda i: reverse(i), elems=filtered)

    # Now you can scatter_nd without accumulation
    result = tf.scatter_nd(updated_indices, 
                           updated_values, 
                           tf.TensorShape([3, 3, 2]))

答案 2 :(得分:0)

也将tf.scatter_nd应用于1的矩阵。这将为您提供累加的元素数量,您可以将结果除以得出平均值。 (但是要注意零,因为那些应该除以一)。

counter = tf.ones(tf.shape(values))
t = tf.scatter_nd(indices,values,shape)
t_counter = tf.scatter_nd(indices,counter,shape)

然后将t除以t_counter(但仅在t_counter不为零的情况下)。