我已经看到转置和重塑在一起可以提供帮助,但我不知道如何使用。
EG。 dimshuffle(0,'x')
使用转置和重塑它的等价物是什么?或者,还有更好的方法? 谢谢。
答案 0 :(得分:10)
在TensorFlow中有三个相关的操作来实现Theano的dimshuffle
:
tf.transpose()
用于置换张量的维度。如果dimshuffle
的参数中指定的模式是输入张量维度的排列(即没有'x'
或缺少维度),则可以使用tf.transpose()
来实现dimshuffle()
tf.expand_dims()
用于将一个或多个size-1维度添加到张量。这处理了'x'
被指定为dimshuffle()
模式的一部分但不重新排序现有维度的情况。
tf.squeeze()
用于从张量中删除一个或多个size-1维度。这样可以处理从dimshuffle()
模式中省略维度的情况,但不会对现有维度重新排序。
假设输入是矢量,您的示例(dimshuffle(0, 'x')
)只能使用tf.expand_dims()
表示:
input = tf.placeholder(tf.float32, [None]) # Defines an arbitrary-sized vector.
result = tf.expand_dims(input, 1)
print result.get_shape() # ==> TensorShape([Dimension(None), Dimension(1)])
采用更复杂的例子,dimshuffle(1, 'x', 0)
应用于矩阵将是:
input = tf.placeholder(tf.float32, [128, 32]) # Defines a matrix.
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1)
print output.get_shape()
# ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)])
答案 1 :(得分:0)
我在our framework Returnn(here)中为TensorFlow实施了dimshuffle
。代码是这样的:
def expand_multiple_dims(x, axes, name="expand_multiple_dims"):
"""
:param tf.Tensor x:
:param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes
:param str name: scope name
:return: y where we have a new broadcast axis for each axis in axes
:rtype: tf.Tensor
"""
with tf.name_scope(name):
for i in sorted(axes):
x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i)
return x
def dimshuffle(x, axes, name="dimshuffle"):
"""
Like Theanos dimshuffle.
Combines tf.transpose, tf.expand_dims and tf.squeeze.
:param tf.Tensor x:
:param list[int|str]|tuple[int|str] axes:
:param str name: scope name
:rtype: tf.Tensor
"""
with tf.name_scope(name):
assert all([i == "x" or isinstance(i, int) for i in axes])
real_axes = [i for i in axes if isinstance(i, int)]
bc_axes = [i for (i, j) in enumerate(axes) if j == "x"]
if x.get_shape().ndims is None:
x_shape = tf.shape(x)
x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)]) # will have static ndims
assert x.get_shape().ndims is not None
# First squeeze missing axes.
i = 0
while i < x.get_shape().ndims:
if i not in real_axes:
x = tf.squeeze(x, axis=i)
real_axes = [(j if (j < i) else (j - 1)) for j in real_axes]
else:
i += 1
# Now permute.
assert list(sorted(real_axes)) == list(range(x.get_shape().ndims))
if real_axes != list(range(x.get_shape().ndims)):
x = tf.transpose(x, real_axes)
# Now add broadcast dimensions.
if bc_axes:
x = expand_multiple_dims(x, bc_axes)
assert len(axes) == x.get_shape().ndims
return x
答案 2 :(得分:0)
如果tensorflow是你的后端
from keras import baskend as K
K.permute_dimension should do
答案 3 :(得分:0)
tf.transpose可能就是你要找的东西。它需要一个任意的排列。