使用tf.expand_dims()
或tf.squeeze()
...而不是tf.reshape()
可以提高性能吗?
感觉tf.reshape()
通常是最佳选择,因为您可以在一行中执行任意数量/组合的整形步骤,并且您完全确定最终形状将是什么。
但是,我读到tf.reshape()
是在内部复制数据的。 tf.expand_dims()
或tf.squeeze()
不这样做吗?是否有性能提升或其他原因要使用竞争对手来tf.reshape()
?
答案 0 :(得分:0)
在TF1.x
中,尤其是在TF1.12.0
中,所有方法在 CPU 上具有相同的性能:
import tensorflow as tf
with tf.device('cpu:0'):
tensor = tf.random.normal(shape=(1, 3, 2))
newaxis = tensor[tf.newaxis, ...]
expanded_dims = tf.expand_dims(tensor, 0)
reshaped = tf.reshape(tensor, (1, ) + tuple(tensor.get_shape().as_list()))
squeezed = tf.squeeze(tensor)
reshaped2 = tf.reshape(tensor, (3, 2))
sess = tf.Session()
%timeit -n 10000 sess.run(newaxis) # 84.3 µs ± 767 ns per loop
%timeit -n 10000 sess.run(expanded_dims) # 83.3 µs ± 837 ns per loop
%timeit -n 10000 sess.run(reshaped) # 83.5 µs ± 946 ns per loop
%timeit -n 10000 sess.run(squeezed) # 81.9 µs ± 852 ns per loop
%timeit -n 10000 sess.run(reshaped2) # 83.9 µs ± 852 ns per loop
在 GPU 上,tf.newaxis
和tf.squeeze()
是最快的:
import tensorflow as tf
with tf.device('gpu:0'):
tensor = tf.random.normal(shape=(1, 3, 2))
newaxis = tensor[tf.newaxis, ...] # <-- Fastest to add new axis
expanded_dims = tf.expand_dims(tensor, 0)
reshaped = tf.reshape(tensor, (1, ) + tuple(tensor.get_shape().as_list()))
squeezed = tf.squeeze(tensor) # <-- Fastest to remove unit-sized dims
reshaped2 = tf.reshape(tensor, (3, 2))
sess = tf.Session()
%timeit -n 10000 sess.run(newaxis) # 133 µs ± 1.58 µs per loop
%timeit -n 10000 sess.run(expanded_dims) # 140 µs ± 1.4 µs per loop
%timeit -n 10000 sess.run(reshaped) #153 µs ± 1.22 µs per loop
%timeit -n 10000 sess.run(squeezed) # 134 µs ± 1.86 µs per loop
%timeit -n 10000 sess.run(reshaped2) # 153 µs ± 1.19 µs per loop
在TF2.0
tf.expand_dims()
中添加维度,而tf.squeeze()
是最快的( CPU ):
import tensorflow as tf
tensor = tf.random.normal(shape=(1, 3, 2))
%timeit -n 10000 tf.expand_dims(tensor, 0) # 7.07 µs ± 162 ns per loop
%timeit -n 10000 tf.reshape(tensor, (1, ) + tuple(tensor.shape.as_list())) # 21.3 µs ± 326 ns per loop
%timeit -n 10000 tensor[tf.newaxis, ...] # 42.9 µs ± 565 ns per loop
%timeit -n 10000 tf.squeeze(tensor) # 9.85 µs ± 166 ns per loop
%timeit -n 10000 tf.reshape(tensor, shape=(3, 2)) # 18.2 µs ± 386 ns per loop