张量流找到到实点的最小距离

时间:2019-09-11 14:16:49

标签: python tensorflow

我有一个Bx3张量,foo的B =批处理大小的3D点。出于某种幻想,我得到了另一个张量bar,其形状为Bx6x3,每个B 6x3矩阵都对应于foo中的一个点。该6x3矩阵由6个复数值3D点组成。我想做的是,对于我的每个B点,在bar的6个点中找到与foo中的对应点最接近的实值点,最终得到一个新的Bx3 min_bar,该Bx3由bar中最接近的点组成foo中的点。

numpy中,我可以使用带掩码的数组来完成以下任务:

foo = np.array([
    [1,2,3],
    [4,5,6],
    [7,8,9]])
# here bar is only Bx2x3 for simplicity, but the solution generalizes
bar = np.array([
    [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],
    [[6,5,4],[4,5,7]],
    [[1j,1j,1j],[0,0,0]],
])

#mask complex elements of bar
bar_with_masked_imag = np.ma.array(bar)
candidates = bar_with_masked_imag.imag == 0
bar_with_masked_imag.mask = ~candidates

dists = np.sum(bar_with_masked_imag**2, axis=1)
mindists = np.argmin(dists, axis=1)
foo_indices = np.arange(foo.shape[0])
min_bar = np.array(
    bar_with_masked_imag[foo_indices,mindists,:], 
    dtype=float
)

print(min_bar)
#[[2. 3. 4.]
# [4. 5. 7.]
# [0. 0. 0.]]

但是,tensorflow没有屏蔽数组等。如何将其转换为张量流?

1 个答案:

答案 0 :(得分:2)

这是一种实现方法:

import tensorflow as tf
import math

def solution_tf(foo, bar):
    foo = tf.convert_to_tensor(foo)
    bar = tf.convert_to_tensor(bar)
    # Get real and imaginary parts
    bar_r = tf.cast(tf.real(bar), foo.dtype)
    bar_i = tf.imag(bar)
    # Mask of all real-valued points
    m = tf.reduce_all(tf.equal(bar_i, 0), axis=-1)
    # Distance to every corresponding point
    d = tf.reduce_sum(tf.squared_difference(tf.expand_dims(foo, 1), bar_r), axis=-1)
    # Replace distances of complex points with infinity
    d2 = tf.where(m, d, tf.fill(tf.shape(d), tf.constant(math.inf, d.dtype)))
    # Find smallest distances
    idx = tf.argmin(d2, axis=1)
    # Get points with smallest distances
    b = tf.range(tf.shape(foo, out_type=idx.dtype)[0])
    return tf.gather_nd(bar_r, tf.stack([b, idx], axis=1))

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    foo = tf.constant([
        [1,2,3],
        [4,5,6],
        [7,8,9]], dtype=tf.float32)
    bar = tf.constant([
        [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],
        [[6,5,4],[4,5,7]],
        [[1j,1j,1j],[0,0,0]]], dtype=tf.complex64)
    sol_tf = solution_tf(foo, bar)
    print(sess.run(sol_tf))
    # [[2. 3. 4.]
    #  [4. 5. 7.]
    #  [0. 0. 0.]]