我想用3个维度更新张量的一个切片。在How to do slice assignment in Tensorflow之后,我会做类似的事情
import tensorflow as tf
with tf.Session() as sess:
init_val = tf.Variable(tf.zeros((2, 3, 3)))
indices = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]])
update = tf.scatter_nd_add(init_val, indices, tf.ones(4))
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(update))
这可行,但是由于我的实际问题更加复杂,因此我想通过定义切片的开始和大小自动生成索引集,例如是否使用tf.slice(...)
。你有什么想法?预先感谢!
我正在使用TensorFlow 1.12,它是当前最新版本。
答案 0 :(得分:1)
tf.strided_slice
支持传递一个var
参数来指示切片所引用的变量,因此当您传递它时,它将返回一个可分配的对象(我不确定为什么他们不这样做取决于输入的类型,但无论如何)。您可以执行以下操作:
import tensorflow as tf
import numpy as np
var = tf.Variable(np.ones((3, 4), dtype=np.float32))
s = tf.strided_slice(var, [0, 2], [2, 3], var=var, name='var_slice')
s2 = s.assign([[2], [3]])
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print(sess.run(s2))
输出:
[[1. 1. 2. 1.]
[1. 1. 3. 1.]
[1. 1. 1. 1.]]
请注意,在tf.strided_slice
中,您给出了开始和结束索引(不包括end),而在tf.slice
中,您给出了开始和大小。另外,按照当前的代码,您必须在slice或assign操作中提供一个名称值(我认为这应该是一个错误,并且会发生,因为该API的该部分几乎全部在内部使用)。