tf.segment_max错误:段ID不会增加1

时间:2016-10-18 13:38:02

标签: tensorflow

tf.segment_max()和其他细分操作通常require the segment IDs to be consecutive。将此op应用于动态生成的批处理并使用tf.unique()定义段时,段ID可能不连续,从而产生错误。这两种情况(错误/无错误)如下所示:

with tf.Session() as sess:
    data = tf.constant([[1, 3, 5], [6, 2, 7], [9, 9, 2], [9, 5, 1]])
    #labels = ['a','r','r','d'] # this works
    labels = ['a','r','d','r']  # error: segment ids are not increasing by 1
    y, idx = tf.unique(labels) 
    maxs = tf.segment_max(data, idx)
    rval = sess.run([idx, maxs])
    print('indices: ', rval[0])
    print('maxs: ', rval[1])

如何处理非连续段ID的一般情况?

1 个答案:

答案 0 :(得分:1)

考虑到这个问题是很久以前发布的,我不确定 OP 是否已经找到了答案。但是,我正在为可能最终遇到相同问题的任何人发布解决方案。

这里的诀窍是使用 Attribute。它非常简单。例如,考虑以下代码-

tf.math.unsorted_segment_max()

在上面的代码中,我们首先使用 # Data data = tf.constant([0.4, 0.5, 0.2, 0.7, 0.1], dtype=tf.float32) # Unique segements with Segment IDs unique_elems, segment_ids = tf.unique(tf.constant([1, 2, 1, 1, 2], dtype=tf.int32)) # Count the number of unique segments. nb_segments = tf.size(tf.unique(segment_ids)[0]) # Result result = tf.math.unsorted_segment_max( data=data, segment_ids=segment_ids, num_segments=nb_segments ) tf.print(result) # result turns out to be [0.7, 0.5] 方法查找具有相同值的索引。然后,我们使用 tf.unique() 方法找到 tf.unsorted_segment_max() 张量中对应于这些索引段的最大值。

您可以阅读更多相关信息here