张量切片和张量应用函数

时间:2016-06-22 12:59:37

标签: tensorflow

我有两个问题:

  1. 让张量T具有形状[n1, n2, n3, n4]。让另一个类型为IDX的形状[n1, n2]的张量int包含所需的索引。我如何获得形状张量[n1, n2, n4],我只想提取n3 dim T dim中的IDX指数x = [[[2, 3, 1, 2, 5], [7, 1, 5, 6, 0], [7, 8, 1, 3, 8]], [[0, 7, 7, 6, 9], [5, 6, 7, 8, 8], [2, 3, 2, 9, 6]]] idx = [[1, 0, 2], [4, 3, 3]] res = [[3, 7, 1], [9, 8, 9]]` 。简单的例子:

    :nth-of-class
    1. 给定一个采用1D张量函数func(x,y)的函数如何将其应用于最后一维上的4D张量X,Y,即结果 - 结果为[i,j,k] = f的3D张量(所有i,j,k的X [i,j,k,:],Y [i,j,k,:])。我发现了tf.py_func,但在我的情况下无法使用它。
  2. 提前感谢您的帮助!

1 个答案:

答案 0 :(得分:1)

我使用tf.gather_nd处理问题1。

输入是:

  • x:你的张量T从中提取形状[n1, n2, n3, n4]的值
    • 我使用了从0到size(T)
    • 的更清晰的值
  • idx:您要从T中提取的形状[n1, n2]且包含0n3 - 1
  • 的值的索引

结果是:

  • resTidx的{​​{1}}的提取值,[n1, n2, n4]

由于tf.gather_nd()期望您创建要在x中检索的整个索引(例如[1, 0, 4, 1]),我们必须先在indices_base中创建它。

论证indices需要具有res + R形状,即[n1, n2, n4, R],其中R=4是张量x的等级。

# Inputs:
n1 = 2
n2 = 3
n3 = 5
n4 = 2
x = tf.reshape(tf.range(n1*n2*n3*n4), [n1, n2, n3, n4])  # range(60) reshaped
idx = tf.constant([[1, 0, 2], [4, 3, 3]])  # shape [n1, n2]

range_n1 = tf.reshape(tf.range(n1), [n1, 1, 1, 1])
indices_base_1 = tf.tile(range_n1, [1, n2, n4, 1]) 

range_n2 = tf.reshape(tf.range(n2), [1, n2, 1, 1])
indices_base_2 = tf.tile(range_n2, [n1, 1, n4, 1])

range_n4 = tf.reshape(tf.range(n4), [1, 1, n4, 1])
indices_base_4 = tf.tile(range_n4, [n1, n2, 1, 1])

idx = tf.reshape(idx, [n1, n2, 1, 1])
idx = tf.tile(idx, [1, 1, n4, 1])

# Create the big indices needed of shape [n1, n2, n3, n4]
indices = tf.concat(3, [indices_base_1, indices_base_2, idx, indices_base_4])

# Finally we can apply tf.gather_nd
res = tf.gather_nd(x, indices)

无论如何,这是非常复杂的,我不确定它是否能产生良好的性能。

P.S:你应该在一个单独的帖子中发布问题2。