Extract several columns from a tf.tensor in a single operation

Posted: 2019-05-02 21:35:25

Tags: python tensorflow tensorflow2.0

In recent TensorFlow (1.13 or 2.0), is there a way to extract non-contiguous slices from a tensor in one go? How? For example, with the following tensor:

1 2 3 4
5 6 7 8 

I want to extract columns 1 and 3 (zero-based) in a single operation to get:

2 4
6 8

However, it seems I can't do this with a single slicing operation. What is the correct/fastest/most elegant way to do it?

2 answers:

Answer 0 (score: 2)

The first way uses indexing (works in TF1.x and TF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
# print(stacked.numpy()) # <-- TF2, TF1.x-eager

with tf.Session() as sess:  # <-- TF1.x
    print(sess.run(stacked))
# [[2. 4.]
#  [6. 8.]]

Wrapped in a function and timed with %timeit in tf.__version__=='2.0.0-alpha0':

154 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
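Because the first way indexes the transposed tensor once per entry in columns, the output follows the order of columns and repeats any index listed twice. The sketch below checks that in TF2 eager mode; the helper name extract_columns is hypothetical and simply wraps the code above.

```python
import tensorflow as tf

def extract_columns(tensor, columns):
    # Hypothetical helper wrapping the first way: rows of the transposed
    # tensor are columns of the original; stack them and transpose back.
    transposed = tf.transpose(tensor)
    sliced = [transposed[c] for c in columns]
    return tf.transpose(tf.stack(sliced, axis=0))

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)

in_order = extract_columns(tensor, [1, 3])        # [[2, 4], [6, 8]]
reversed_order = extract_columns(tensor, [3, 1])  # order of `columns` is kept: [[4, 2], [8, 6]]
repeated = extract_columns(tensor, [1, 1, 3])     # repeats are kept: [[2, 2, 4], [6, 6, 8]]
print(in_order.numpy())
print(reversed_order.numpy())
print(repeated.numpy())
```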

Decorating the function with @tf.function makes it more than 2x faster:

import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
@tf.function
def extract_columns(tensor=tensor, columns=columns):
    transposed = tf.transpose(tensor)
    sliced = [transposed[c] for c in columns]
    stacked = tf.transpose(tf.stack(sliced, axis=0))
    return stacked

%timeit -n 10000 extract_columns()
66.8 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
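%timeit is an IPython magic. Outside IPython, a rough equivalent of the eager vs. @tf.function comparison can be written with Python's standard timeit module. This is a minimal sketch assuming TF2 eager mode; the names extract_columns_eager and extract_columns_graph are hypothetical, and timings depend on the machine, so no numbers are claimed here.

```python
import timeit
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]
NUM_CALLS = 1000

def extract_columns_eager(t):
    # Hypothetical name: the first way, run op by op in eager mode.
    transposed = tf.transpose(t)
    sliced = [transposed[c] for c in columns]
    return tf.transpose(tf.stack(sliced, axis=0))

# Hypothetical name: the same Python function compiled into a graph.
extract_columns_graph = tf.function(extract_columns_eager)
extract_columns_graph(tensor)  # warm-up call so tracing is not included in the timing

for name, fn in [("eager", extract_columns_eager), ("tf.function", extract_columns_graph)]:
    seconds = timeit.timeit(lambda: fn(tensor), number=NUM_CALLS)
    print(f"{name}: {seconds / NUM_CALLS * 1e6:.1f} µs per call")
```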

The second way works only with eager execution (TF2, or TF1.x with eager execution enabled):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

res = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
                             if i in columns], 0))
print(res.numpy())
# [[2. 4.]
#  [6. 8.]]
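Unlike the first way, this version walks the columns in their original order and keeps a column if its index appears anywhere in columns, so the order of columns is ignored and duplicate indices collapse. A small TF2 eager sketch to illustrate; the helper name extract_by_membership is hypothetical.

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)

def extract_by_membership(tensor, columns):
    # Hypothetical helper wrapping the second way. Iterating an eager tensor
    # yields slices along axis 0, so iterating the transpose yields columns.
    return tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
                                  if i in columns], 0))

a = extract_by_membership(tensor, [1, 3])
b = extract_by_membership(tensor, [3, 1])     # same result as [1, 3]
c = extract_by_membership(tensor, [1, 1, 3])  # column 1 still appears only once
print(a.numpy())
```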

Timed with %timeit in tf.__version__=='2.0.0-alpha0':

242 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

The third way uses tf.one_hot() to mark the rows/columns to keep and then tf.boolean_mask() to extract them (TF1.x and TF2):

import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract

mask = tf.one_hot(columns, tensor.get_shape().as_list()[-1])
mask = tf.cast(tf.reduce_sum(mask, axis=0), tf.bool)  # tf.boolean_mask expects a bool mask
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# print(res.numpy()) # <-- TF2, TF1.x-eager

with tf.Session() as sess: # TF1.x
    print(sess.run(res))
# [[2. 4.]
#  [6. 8.]]
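To make the mask construction concrete: tf.one_hot produces one row per requested column, summing over axis 0 merges them into a single 0/1 vector over all columns, and casting gives the bool mask. tf.boolean_mask also accepts an axis argument, which avoids the two transposes. A sketch assuming TF2 eager mode (so tensor.shape[-1] is a plain int):

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3]
num_columns = tensor.shape[-1]  # 4

one_hot = tf.one_hot(columns, num_columns)                # [[0,1,0,0], [0,0,0,1]]
mask = tf.cast(tf.reduce_sum(one_hot, axis=0), tf.bool)  # [False, True, False, True]

# The third way as written: mask rows of the transpose, then transpose back.
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# Same selection applied directly along axis 1.
res_axis = tf.boolean_mask(tensor, mask, axis=1)
print(mask.numpy())
print(res.numpy())
```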

Timed with %timeit in tf.__version__=='2.0.0-alpha0':

494 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
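All three ways above assemble the result from per-column slices or a mask. For comparison, TensorFlow also provides tf.gather, whose axis argument selects arbitrary indices along one dimension in a single op. This is not one of the answerer's methods, just a supplementary sketch (TF2 eager mode assumed for the .numpy() calls).

```python
import tensorflow as tf

tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)

# axis=1 gathers columns; the output follows the order of the index list.
gathered = tf.gather(tensor, [1, 3], axis=1)
gathered_reversed = tf.gather(tensor, [3, 1], axis=1)
print(gathered.numpy())
print(gathered_reversed.numpy())
```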

Answer 1 (score: 0)

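You can use a combination of reshaping and slicing to get all the odd-indexed columns (the number of columns must be even):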

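import numpy as np
import tensorflow as tf

N = 4   # number of columns (must be even)
M = 10  # number of rows
x = tf.constant(np.random.rand(M, N))
# Pair up adjacent values in each row, keep the second of each pair, restore rows.
slice_odd = tf.reshape(tf.reshape(x, (-1, 2))[:, 1], (-1, N // 2))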
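Walking the reshape trick through the 2x4 tensor from the question shows why it works: a row-major reshape to (-1, 2) turns each row into consecutive (even, odd) column pairs, so taking index 1 of every pair yields the odd-indexed columns. For evenly spaced columns like these, a strided slice with start 1 and step 2 gives the same result directly. A sketch assuming TF2 eager mode:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
n = x.shape[-1]  # number of columns; the trick assumes it is even

pairs = tf.reshape(x, (-1, 2))                        # [[1,2], [3,4], [5,6], [7,8]]
second_of_pair = pairs[:, 1]                          # [2, 4, 6, 8]
odd_cols = tf.reshape(second_of_pair, (-1, n // 2))   # [[2,4], [6,8]]

# Same odd-indexed columns via a strided slice (start 1, step 2).
strided = x[:, 1::2]
print(odd_cols.numpy())
print(strided.numpy())
```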