Theano dimshuffle equivalent in Google's TensorFlow?

Posted: 2016-02-02 21:04:52

Tags: python numpy theano tensorflow

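I have seen that transpose and reshape together can help, but I don't know how to use them.

E.g. dimshuffle(0, 'x')

What is its equivalent using transpose and reshape? Or is there a better way? Thanks.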

4 answers:

Answer 0 (score: 10)

There are three TensorFlow ops that are relevant to implementing Theano's dimshuffle:

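  • tf.transpose() permutes the dimensions of a tensor. If the pattern passed to dimshuffle() is a permutation of the input tensor's dimensions (i.e. it contains no 'x' and omits no dimension), you can use tf.transpose() to implement dimshuffle().

  • tf.expand_dims() inserts a size-1 dimension into a tensor (call it once for each new dimension). This handles the case where 'x' appears in the dimshuffle() pattern but the existing dimensions are not reordered.

  • tf.squeeze() removes one or more size-1 dimensions from a tensor. This handles the case where a dimension is omitted from the dimshuffle() pattern but the existing dimensions are not reordered.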
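To see how the three ops divide the work, here is a minimal sketch that uses NumPy as a stand-in so it runs without a TensorFlow install. The assumption is that np.transpose, np.expand_dims and np.squeeze change shapes the same way as tf.transpose, tf.expand_dims and tf.squeeze; the arrays m and t are illustrative.

```python
import numpy as np

m = np.zeros((2, 3))      # a 2x3 matrix
t = np.zeros((2, 1, 3))   # a tensor whose middle axis has size 1

# Pure permutation, like dimshuffle(1, 0): only a transpose is needed.
print(np.transpose(m, (1, 0)).shape)    # (3, 2)

# New size-1 axis, like dimshuffle(0, 'x', 1): only expand_dims is needed.
print(np.expand_dims(m, 1).shape)       # (2, 1, 3)

# Dropping a size-1 axis, like dimshuffle(0, 2) on t: only squeeze is needed.
print(np.squeeze(t, axis=1).shape)      # (2, 3)
```

Patterns that mix these cases need the ops combined, typically squeeze first, then transpose, then expand_dims.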

Assuming the input is a vector, your example (dimshuffle(0, 'x')) can be expressed with tf.expand_dims() alone:

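import tensorflow as tf  # TensorFlow 1.x API; in TF 2.x tf.placeholder lives at tf.compat.v1.placeholder

input = tf.placeholder(tf.float32, [None])  # Defines an arbitrary-sized vector.
result = tf.expand_dims(input, 1)  # Insert a size-1 axis at position 1.

print(result.get_shape())  # ==> TensorShape([Dimension(None), Dimension(1)])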
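The question asked specifically about reshape. For this particular pattern no data moves, so appending a size-1 axis is also a plain reshape. A short NumPy sketch (NumPy stands in for TensorFlow here; tf.reshape(x, [-1, 1]) should produce the same shape for a vector):

```python
import numpy as np

v = np.arange(5.0)                 # a length-5 vector

a = np.expand_dims(v, 1)           # like dimshuffle(0, 'x')
b = v.reshape(-1, 1)               # the same result via reshape
c = v[:, np.newaxis]               # NumPy indexing shorthand

print(a.shape, b.shape, c.shape)   # (5, 1) (5, 1) (5, 1)
print(np.array_equal(a, b) and np.array_equal(b, c))  # True
```

Reshape alone cannot reorder axes, though: once the pattern permutes dimensions, a transpose is needed as well.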

For a more complex example, dimshuffle(1, 'x', 0) applied to a matrix would be:

input = tf.placeholder(tf.float32, [128, 32])  # Defines a matrix.
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1)  # Transpose to 32x128, then insert the new axis at position 1.

print(output.get_shape())
# ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)])
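The same composite pattern can be checked element by element with a NumPy sketch (again NumPy as a stand-in for the TensorFlow ops; the array contents are illustrative). The invariant for dimshuffle(1, 'x', 0) is that out[j, 0, i] equals x[i, j]:

```python
import numpy as np

x = np.arange(128 * 32, dtype=np.float32).reshape(128, 32)

# Transpose first (swap axes 0 and 1), then insert the 'x' axis at position 1.
out = np.expand_dims(np.transpose(x, (1, 0)), 1)
print(out.shape)                 # (32, 1, 128)

# Spot-check the invariant out[j, 0, i] == x[i, j].
print(out[5, 0, 7] == x[7, 5])   # True
```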

Answer 1 (score: 0)

I implemented dimshuffle for TensorFlow in our framework Returnn (here). The code looks like this:

import tensorflow as tf


def expand_multiple_dims(x, axes, name="expand_multiple_dims"):
  """
  :param tf.Tensor x:
  :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes
  :param str name: scope name
  :return: y where we have a new broadcast axis for each axis in axes
  :rtype: tf.Tensor
  """
  with tf.name_scope(name):
    for i in sorted(axes):
      x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i)
    return x


def dimshuffle(x, axes, name="dimshuffle"):
  """
  Like Theano's dimshuffle.
  Combines tf.transpose, tf.expand_dims and tf.squeeze.

  :param tf.Tensor x:
  :param list[int|str]|tuple[int|str] axes:
  :param str name: scope name
  :rtype: tf.Tensor
  """
  with tf.name_scope(name):
    assert all([i == "x" or isinstance(i, int) for i in axes])
    real_axes = [i for i in axes if isinstance(i, int)]
    bc_axes = [i for (i, j) in enumerate(axes) if j == "x"]
    if x.get_shape().ndims is None:
      x_shape = tf.shape(x)
      x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)])  # will have static ndims
    assert x.get_shape().ndims is not None

    # First squeeze missing axes.
    i = 0
    while i < x.get_shape().ndims:
      if i not in real_axes:
        x = tf.squeeze(x, axis=i)
        real_axes = [(j if (j < i) else (j - 1)) for j in real_axes]
      else:
        i += 1

    # Now permute.
    assert list(sorted(real_axes)) == list(range(x.get_shape().ndims))
    if real_axes != list(range(x.get_shape().ndims)):
      x = tf.transpose(x, real_axes)

    # Now add broadcast dimensions.
    if bc_axes:
      x = expand_multiple_dims(x, bc_axes)
    assert len(axes) == x.get_shape().ndims
    return x
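For experimenting without TensorFlow, the same three-step strategy (squeeze the dropped axes, transpose, then expand) can be sketched in NumPy. This is a hypothetical port: dimshuffle_np is an illustrative name, not part of Returnn, and unlike the code above it assumes the input already has a known number of dimensions. It also assumes every dropped axis has size 1; np.squeeze raises a ValueError otherwise, mirroring Theano's rule that only broadcastable dimensions can be dropped.

```python
import numpy as np


def dimshuffle_np(x, pattern):
    """Hypothetical NumPy analogue of Theano's dimshuffle.

    pattern holds input axis indices (ints) and 'x' for new size-1 axes.
    """
    real_axes = [p for p in pattern if p != "x"]
    if len(set(real_axes)) != len(real_axes):
        raise ValueError("an input axis may appear at most once")

    # 1. Squeeze the input axes that the pattern leaves out.
    dropped = tuple(i for i in range(x.ndim) if i not in real_axes)
    if dropped:
        x = np.squeeze(x, axis=dropped)  # raises if a dropped axis is not size 1
        # Renumber the kept axes so they index into the squeezed array.
        real_axes = [a - sum(d < a for d in dropped) for a in real_axes]

    # 2. Permute the remaining axes into the requested order.
    x = np.transpose(x, real_axes)

    # 3. Insert a size-1 axis at each 'x' position, in ascending order.
    for pos, p in enumerate(pattern):
        if p == "x":
            x = np.expand_dims(x, pos)
    return x


print(dimshuffle_np(np.zeros(5), (0, "x")).shape)             # (5, 1)
print(dimshuffle_np(np.zeros((128, 32)), (1, "x", 0)).shape)  # (32, 1, 128)
print(dimshuffle_np(np.zeros((2, 1, 3)), ("x", 2, 0)).shape)  # (1, 3, 2)
```

Inserting the new axes in ascending order works for the same reason as the sorted loop in expand_multiple_dims: by the time position pos is filled, every earlier output position is already in place.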

Answer 2 (score: 0)

If TensorFlow is your backend:

from keras import backend as K

K.permute_dimensions(x, pattern) should do it.

Answer 3 (score: 0)

tf.transpose is probably what you're looking for. It accepts an arbitrary permutation.
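Note that tf.transpose on its own only covers dimshuffle patterns that are a pure permutation: no 'x' entries and no dropped axes. A small hypothetical helper (needs_only_transpose is an illustrative name, not a TensorFlow function) makes that condition explicit:

```python
def needs_only_transpose(pattern, ndim):
    """Return True if a dimshuffle pattern can be done with a single transpose.

    Hypothetical helper: the pattern must list every input axis exactly once
    and contain no 'x' entries.
    """
    # The 'x' check runs first, so sorted() never compares str with int.
    return "x" not in pattern and sorted(pattern) == list(range(ndim))


print(needs_only_transpose((1, 0), 2))        # True
print(needs_only_transpose((0, "x"), 1))      # False: needs expand_dims
print(needs_only_transpose((1, "x", 0), 2))   # False: needs transpose and expand_dims
```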