在最近的TensorFlow(1.13
或2.0
)中,是否有一种方法可以一次通过张量提取非连续切片?怎么做?
例如具有以下张量:
1 2 3 4
5 6 7 8
我想在一个操作中提取第1列和第3列以获得:
2 4
6 8
但是,似乎我无法在单个操作中进行切片。 什么是正确/最快/最优雅的方法?
答案 0 :(得分:2)
第一种方法与索引(TF1.x
,TF2
)一起使用:
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
# print(stacked.numpy()) # <-- TF2, TF1.x-eager
with tf.Session() as sess: # <-- TF1.x
print(sess.run(stacked))
# [[2. 4.]
# [6. 8.]]
将其包装为函数并在%timeit
中运行tf.__version__=='2.0.0-alpha0'
:
154 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
用@tf.function
进行装饰的速度快2倍以上:
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
@tf.function
def extract_columns(tensor=tensor, columns=columns):
transposed = tf.transpose(tensor)
sliced = [transposed[c] for c in columns]
stacked = tf.transpose(tf.stack(sliced, axis=0))
return stacked
%timeit -n 10000 extract_columns()
66.8 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
第二种方式是急于执行(TF2
,TF1.x-eager
)的一种方式:
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
res = tf.transpose(tf.stack([t for i, t in enumerate(tf.transpose(tensor))
if i in columns], 0))
print(res.numpy())
# [[2. 4.]
# [6. 8.]]
%timeit
in tf.__version__=='2.0.0-alpha0'
:
242 µs ± 2.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
第三种方式是使用tf.one_hot()
指定行/列,然后使用tf.boolean_mask()
提取这些行/列(TF1.x
,TF2
):
import tensorflow as tf
tensor = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
columns = [1, 3] # <--columns you want to extract
mask = tf.one_hot(columns, tensor.get_shape().as_list()[-1])
mask = tf.reduce_sum(mask, axis=0)
res = tf.transpose(tf.boolean_mask(tf.transpose(tensor), mask))
# print(res.numpy()) # <-- TF2, TF1.x-eager
with tf.Session() as sess: # TF1.x
print(sess.run(res))
# [[2. 4.]
# [6. 8.]]
%timeit
in tf.__version__=='2.0.0-alpha0'
:
494 µs ± 4.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
答案 1 :(得分:0)
您可以使用整形和切片的组合来获得所有奇数列:
N = 4
M = 10
input = tf.constant(np.random.rand(M, N))
slice_odd = tf.reshape(tf.reshape(input, (-1, 2))[:,1], (-1, int(N/2)))