像在pytorch中一样,如何执行张量切片更新?

时间:2018-10-04 02:04:09

标签: tensorflow tensor

在Pytorch中,您可以轻松地更新张量,如下所示:

 for i in range(x_len):
     tensor_abc[:, i, i] = 0

我们如何更新张量流编码中的张量? 我尝试了tf.assign,无法执行切片更新。 尝试过tf.scatter_update,效果不佳...

2 个答案:

答案 0 :(得分:0)

此答案仅与变量有关。

import tensorflow as tf

sess = tf.InteractiveSession()
v = tf.zeros((5,5,5))
var = tf.Variable(initial_value=v)


init = tf.variables_initializer([var])
sess.run(init)


var = var[ 1 : 2 ,
           1 : 2 ,
           1 : 2 ].assign(tf.ones((1,1,1)))

print(sess.run(var))

这产生

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]]

还有这个

var = var[ 1 : 2 ,
           0 : 1 ,
           0 : 1 ].assign(tf.ones((1,1,1)))

产生

  [[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

  ....
  ....]]

另一个例子是

var = var[ 1 : 2 ,
             : 2 ,
             : 2 ].assign(tf.ones((1,2,2)))

[[[0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

 [[1. 1. 0. 0. 0.]
  [1. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0.]]

      ....
      ....]]

您应该探索tf.scatter_nd的张量。

答案 1 :(得分:-1)

tf.Variable是唯一可以更新的张量(https://www.tensorflow.org/guide/variables)。对于变量,您将使用gatherscatter_update之类的代码进行切片。

请注意,其他张量不适合分配。如果您要这样做,我想知道为什么这样做是必要的。但是,仍然可以使用有点费解的代码来创建具有所需值的新张量(而不是就地分配)。例如,以下操作无效:

index = ... tensor = tf.constant([0,1,2,3,4]) 
tensor[i] = 0  
## Doesn't work (TypeError: `Tensor` object does not support item assignment)

但这可以等效:

tensor = tf.constant([0,1,2,3,4]) 
tensor = tf.concat([tensor[:i], tf.zeros_like(tensor[i:i+1]), tensor[i+1:]], 0)  
## This works, creates a new tensor

OR:

tensor = tf.constant([0,1,2,3,4]) 
tensor = tf.concat([tensor[:i], tf.fill([1], 0), tensor[i+1:]], 0)  
## This works, creates a new tensor