如何在尺寸0处扫描张量?

时间:2018-05-07 20:43:40

标签: python tensorflow

tensorflow文档指出tf.scan扫描在尺寸0上从elems解包的张量列表。 最简单的扫描版本重复将可调用的fn应用于从头到尾的元素序列。这些元素由尺寸为0的elems打包的张量制成。

我的问题是: 如何扫描其他维度上的张量列表而不是维度0? 例如, 我有一个张量,参考,定义如下。

>>> ref = tf.Variable(tf.ones([2,3,3],tf.int32))
....
>>> print(ref.eval())
[[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[1 1 1]
  [1 1 1]
  [1 1 1]]]

我想扫描ref [1,0],ref [1,1],ref [1,2]并将一个函数应用于每个,例如add 1。 也就是说,我希望ref在操作之后

>>> print(ref.eval())
[[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[2 2 2]
  [2 2 2]
  [2 2 2]]]

我可以使用tf.scan来做到这一点吗?如果有,怎么样? 如果没有,任何其他方式怎么做?

感谢。

1 个答案:

答案 0 :(得分:0)

您可以在tf.transpose()之前tf.scan()张量,并在之后转置。

此外,如果您希望变量ref在操作后包含新值,则需要tf.assign()返回值。

请注意,只需应用直接扫描即可实现所需的示例值,但无需转置。请参阅此代码(已测试)(请参阅答案底部的转置另一个示例):

import tensorflow as tf

ref = tf.Variable(tf.ones([2,3,3],tf.int32))
ref2 = tf.scan( lambda y, x: x + 1, ref )

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer() )
    print ( sess.run( ref2 ) ) # ref2 is calculated
    print ( "====================")
    print ( sess.run( ref ) ) # the original ref is unchanged
    print ( "====================")
    sess.run( tf.assign( ref, ref2 ) ) # assign the value back to ref
    print ( sess.run( ref ) )

输出:

  

[[[1 1 1]
    [1 1 1]
    [1 1 1]]

     

[[2 2 2]
    [2 2 2]
    [2 2 2]]]
   ====================
  [[[1 1 1]
    [1 1 1]
    [1 1 1]]

     

[[1 1 1]
    [1 1 1]
    [1 1 1]]]
   ====================
  [[[1 1 1]
    [1 1 1]
    [1 1 1]]

     

[[2 2 2]
    [2 2 2]
    [2 2 2]]]

如果您使用tf.transpose(),则可以沿任何维度进行扫描:

import tensorflow as tf

ref = tf.Variable(tf.ones([2,3,3],tf.int32))
ref2 = tf.transpose( tf.scan( lambda y, x: x + 1,
                              tf.transpose( ref, [ 1, 0, 2 ] ) ),
                     [ 1, 0, 2 ] )

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer() )
    sess.run( tf.assign( ref, ref2 ) )
    print ( sess.run( ref ) )

将输出:

  

[[[1 1 1]
    [2 2 2]
    [2 2 2]]

     

[[1 1 1]
    [2 2 2]
    [2 2 2]]]