从tensorflow中的两个现有张量构造一个新的张量

时间:2018-07-05 09:15:53

标签: python tensorflow

我想从形状为y的现有张量(b,n,c)x和索引张量{{来构造形状为(b,m,c)的新张量m<n 1}}的形状为idx,它告诉我(b,m)中每一行(长度x)在c中的放置位置。

numpy示例:

y

这将导致一个数组import numpy as np b=2 n=100 m=4 c=3 idx=np.array([[0,31,5,66],[1,73,34,80]]) # shape b x m x=np.random.random((b,m,c)) y=np.zeros((b,n,c)) for i,cur_idx in enumerate(idx): y[i,cur_idx]=x[i] 到处都是零,除了y给定的位置,其中插入了idx的值。

我需要帮助来将此代码片段“翻译”为张量流。

编辑: 我不想创建一个变量,而是创建一个常数张量,因此无法使用tf.scatter_update。

1 个答案:

答案 0 :(得分:2)

您需要tf.scatter_nd

import tensorflow as tf
import numpy as np

b = 2
n = 100
m = 4
c = 3

# Synthetic data
x = tf.reshape(tf.range(b * m * c), (b, m, c))
# Arbitrary indices: [0, 25, 50, 75], [1, 26, 51, 76]
idx = tf.convert_to_tensor(
    np.stack([np.arange(0, n, n // m) + i for i in range(b)], axis=0))

# Add index for the first dimension
idx = tf.concat([
    tf.tile(tf.range(b, dtype=idx.dtype)[:, tf.newaxis, tf.newaxis], (1, m, 1)),
    idx[:, :, tf.newaxis]], axis=2)

# Scatter operation
y = tf.scatter_nd(idx, x, (b, n, c))
with tf.Session() as sess:
    y_val = sess.run(y)
    print(y_val[:, 20:30, :])

输出:

[[[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 3  4  5]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]

 [[ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]
  [15 16 17]
  [ 0  0  0]
  [ 0  0  0]
  [ 0  0  0]]]