如何在切片中使用Tensorflows scatter_nd?

时间:2018-10-29 13:53:44

标签: python tensorflow range slice

我正在尝试仅优化变量的一部分。我发现this似乎有用的答案。

但是我的变量是图像,我只想更改其中的一部分,所以我试图将代码扩展到更大的维度。这似乎工作正常:

import tensorflow as tf
import tensorflow.contrib.opt as opt

X = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

# the next two lines need to change because
# manually specifying the values is not feasible
indexes = tf.constant([[0, 0], [1, 0]])
updates = [X[0, 0], X[1, 0]]

part_X = tf.scatter_nd(indexes, updates, [2, 2])
X_2 = part_X + tf.stop_gradient(-part_X + X)
Y = tf.constant([[2.5, -3.5], [5.5, -7.5]])
loss = tf.reduce_sum(tf.squared_difference(X_2, Y))
opt = opt.ScipyOptimizerInterface(loss, [X])

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    opt.minimize(sess)
    print("X: {}".format(X.eval()))

但是,由于我要选择的图像尺寸和区域要大得多,因此手动指定所有索引是不可行的。我想知道如何使用切片或范围分配。

1 个答案:

答案 0 :(得分:3)

您可以这样做:

import tensorflow as tf

# Input with size (50, 100)
X = tf.Variable([[0] * 100] * 50)
# Selected slice
row_start = 10
row_end = 30
col_start = 20
col_end = 50
# Make indices from meshgrid
indexes = tf.meshgrid(tf.range(row_start, row_end),
                      tf.range(col_start, col_end), indexing='ij')
indexes = tf.stack(indexes, axis=-1)
# Take slice
updates = X[row_start:row_end, col_start:col_end]
# Build tensor with "filtered" gradient
part_X = tf.scatter_nd(indexes, updates, tf.shape(X))
X_2 = part_X + tf.stop_gradient(-part_X + X)
# Continue as before...