在Tensorflow中进行急切的执行训练期间修复变量的一部分

时间:2019-08-27 19:12:17

标签: python tensorflow

有没有办法在急切的执行更新步骤中仅更新变量的 some ?考虑这个最小的工作示例:

import tensorflow as tf
tf.enable_eager_execution()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

x = tf.Variable([1.0, 2.0])

def train(x):
    with tf.GradientTape() as tape:
        loss = x[0]**2 + x[1]**2 + 1/(x[0]+x[1])
        variables = [x]
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

for _ in range(2000):
    train(x)
    print(x.numpy())

收敛到[0.5, 0.5]。我想将x[0]的值固定为初始值,同时将其他所有内容保持原样。到目前为止,我已经尝试过:

  • 在训练步骤中添加x[0].assign(1.0)操作,从而不必要地增加图形
  • 更改variables = [x[:-1]]会得到ValueError: No gradients provided for any variable: ['tf.Tensor([1.], shape=(1,), dtype=float32)']
  • 添加grads = [grads[0][1:]]会得到tensorflow.python.framework.errors_impl.InvalidArgumentError: var and delta do not have the same shape[2] [1] [Op:ResourceApplyGradientDescent]
  • 同时做这两项,得到TypeError: 'NoneType' object is not subscriptable

对于此MWE,我可以轻松地使用两个单独的变量,但是我对只希望更新数组的已知切片的一般情况感兴趣。

1 个答案:

答案 0 :(得分:1)

您可以将不想更新的索引的梯度设置为0。在下面的代码段中,mask张量指示我们要更新的元素(值1),以及我们不想更新的元素(值0)。

import tensorflow as tf
tf.enable_eager_execution()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

x = tf.Variable([1.0, 2.0])
mask = tf.constant([0.0, 1.0])

def train(x):
    with tf.GradientTape() as tape:
        loss = x[0]**2 + x[1]**2 + 1/(x[0]+x[1])
        variables = [x]

        grads = tape.gradient(loss, variables) * mask
        optimizer.apply_gradients(zip(grads, variables))

for _ in range(100):
    train(x)
    print(x.numpy())

针对您的问题的另一种可能的解决方案是停止x[0]所依赖的操作上的渐变。例如:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

x = tf.Variable([1.0, 2.0])

def train(x):
    with tf.GradientTape() as tape:
        loss = tf.stop_gradient(x[0])**2 + x[1]**2 + 1/(tf.stop_gradient(x[0])+x[1])
        variables = [x]

        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

for _ in range(100):
    train(x)
    print(x.numpy())