我正在尝试根据层之间部分连接的最后一个维度收集张量的切片。由于输出张量的形状为[batch_size, h, w, depth]
,我想根据最后一个维度选择切片,例如
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
然而,tf.gather(L, [0, 2,3,8])
似乎只适用于第一个维度(对吗?)任何人都可以告诉我该怎么做?
答案 0 :(得分:19)
截至TensorFlow 1.3 'Pusher' => Pusher\Pusher::class,
有一个tf.gather
参数,因此不再需要此处的各种变通方法。
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
答案 1 :(得分:9)
这里有一个跟踪错误来支持这个用例:https://github.com/tensorflow/tensorflow/issues/206
现在你可以:
转置您的矩阵,以便首先收集维度(转置费用昂贵)
将你的张量重塑为1d(重塑很便宜)并将你的聚集列索引转换为线性索引的单个元素索引列表,然后重新整形
gather_nd
。仍然需要将列索引转换为单个元素索引列表。答案 2 :(得分:7)
使用gather_nd,您现在可以执行以下操作:
cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)
另外,正如用户Nova在@Yaroslav Bulatov所引用的主题中报道的那样:
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]
要点是使张量变平,并使用tf.gather(...)进行大步寻址。
答案 3 :(得分:3)
使用tf.unstack(...),tf.gather(...)和tf.stack(..)的另一种解决方案
代码:
import tensorflow as tf
import numpy as np
shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)
indices = [0, 2, 3, 8]
axis = -1 # last dimension
def gather_axis(params, indices, axis=0):
return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
print(L)
with tf.Session() as sess:
partL = sess.run(gather_axis(L, indices, axis))
print(partL)
结果:
L =
[[[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]]
[[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]]]
[[[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]]
[[60 61 62 63 64 65 66 67 68 69]
[70 71 72 73 74 75 76 77 78 79]]]]
partL =
[[[[ 0 2 3 8]
[10 12 13 18]]
[[20 22 23 28]
[30 32 33 38]]]
[[[40 42 43 48]
[50 52 53 58]]
[[60 62 63 68]
[70 72 73 78]]]]
答案 4 :(得分:3)
@ Andrei的答案正确版本为
cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
答案 5 :(得分:2)
您可以尝试这种方式,例如(在大多数情况下至少在NLP中),
参数的形状为[batch_size, depth]
,索引为[i,j,k,n,m],其长度为batch_size。然后gather_nd
可能会有所帮助。
parameters = tf.constant([
[11, 12, 13],
[21, 22, 23],
[31, 32, 33],
[41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number
items = tf.gather_nd(parameters, indices)
# which is what we want: [13, 22, 31, 42]
此代码段首先通过batch_num找到第一个维度,然后按目标号码沿该维度获取项目。
答案 6 :(得分:0)
答案 7 :(得分:0)
Tensor没有属性形状,但是get_shape()方法。以下内容可由Python 2.7
运行import tensorflow as tf
import numpy as np
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]