张量流中的output_data [output_ids == i] = input_data [input_ids == i]

时间:2019-05-10 08:41:20

标签: python tensorflow

在numpy代码中,如果希望相同的id获得相同的值,则可以例如:

input_data = np.array([0.1, 0.2, 0.3])
input_ids = np.array([0, 1, 2])
output_ids = np.array([2, 0, 1, 0])
output_data = np.array([0.1, 0.1, 0.1, 0.1])
for i in input_ids:
    output_data[output_ids == i] = input_data[input_ids == i]
print(output_data)

输出:[0.3 0.1 0.2 0.1]

注意:input_ids = unique(input_ids),开头是唯一的。

在tensorflow中,我如何执行这种代码,我应该使用哪种功能。有类似的例子吗?

  • input_data:一个张量,可以是float64,float 32

  • output_data:一个张量,与input_data类型相同

  • input_ids:一个张量,必须为int32或int64。

  • output_ids:张量,必须为int32或int64。

1 个答案:

答案 0 :(得分:2)

我将按照复杂性的升序为您提供一些选择。在最简单的情况下,input_ids始终是从0开始的整数序列,对应于input_data[0, 1, 2, ...])的索引。在这种情况下,您只需执行以下操作即可:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    input_data = tf.constant([0.1, 0.2, 0.3])
    output_ids = tf.constant([2, 0, 1, 0])
    output_data = tf.gather(input_data, output_ids)
    print(sess.run(output_data))
    # [0.3 0.1 0.2 0.1]

如果input_idsinput_data的索引不对应,但仍按升序排序,则可以执行以下操作:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    input_data = tf.constant([0.1, 0.2, 0.3])
    input_ids = tf.constant([-2, 0, 4])
    output_ids = tf.constant([4, -2, 0, -2])
    output_idx = tf.searchsorted(input_ids, output_ids)
    output_data = tf.gather(input_data, output_idx)
    print(sess.run(output_data))
    # [0.3 0.1 0.2 0.1]

最一般的情况是input_ids是未排序的整数数组。在这种情况下,您可以执行以下操作:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    input_data = tf.constant([0.1, 0.2, 0.3])
    input_ids = tf.constant([3, 1, 6])
    output_ids = tf.constant([6, 3, 1, 3])
    # From TF v1.13
    s = tf.argsort(input_ids)
    # Before TF v1.13
    s = tf.contrib.framework.argsort(input_ids)
    output_idx_s = tf.searchsorted(tf.gather(input_ids, s), output_ids)
    output_data = tf.gather(input_data, tf.gather(s, output_idx_s))
    print(sess.run(output_data))
    # [0.3 0.1 0.2 0.1]

当然,在所有情况下,您都可以使用二次解,将input_ids中的每个值与output_ids中的每个值进行比较。我将在下面编写它以供参考,但是与以前相比,它的时间和内存效率较低,因此实际上没有理由更喜欢它。

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    input_data = tf.constant([0.1, 0.2, 0.3])
    input_ids = tf.constant([3, 1, 6])
    output_ids = tf.constant([6, 3, 1, 3])
    eq = tf.equal(tf.expand_dims(output_ids, 1), input_ids)
    output_idx = tf.argmax(tf.cast(eq, tf.int8), axis=1)
    output_data = tf.gather(input_data, output_idx)
    print(sess.run(output_data))
    # [0.3 0.1 0.2 0.1]

编辑:正如giser_yugang所指出的,也可能存在并非output_ids中的所有值都在input_ids中的情况。在这种情况下,将使用output_data的初始值。您可以通过以下方式实现该目标:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    input_data = tf.constant([0.1, 0.2, 0.3])
    input_ids = tf.constant([3, 1, 6])
    output_data = tf.constant([0., 0., 0., 0., 0.])
    output_ids = tf.constant([6, 3, 1, 3, 0])
    # From TF v1.13
    s = tf.argsort(input_ids)
    # Before TF v1.13
    s = tf.contrib.framework.argsort(input_ids)
    input_ids_s = tf.gather(input_ids, s)
    n = tf.size(input_ids)
    output_idx_s = tf.minimum(tf.searchsorted(input_ids_s, output_ids), n - 1)
    output_data = tf.where(tf.equal(output_ids, tf.gather(input_ids_s, output_idx_s)),
                           tf.gather(input_data, tf.gather(s, output_idx_s)),
                           output_data)
    print(sess.run(output_data))
    # [0.3 0.1 0.2 0.1 0. ]