How to assign values to a tf.Variable in TensorFlow without using tf.assign

Date: 2018-03-26 09:03:22

Tags: python tensorflow

I have a variable that holds a 4x4 identity matrix. I want to assign some values to this matrix (values that are learned by the model).

When I use tf.assign(), I get an error saying that the strided slice has no gradient. My question is: how can I do this without using tf.assign()?

Here is sample code showing the desired behavior (it runs without error because the values are not being learned here):

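import tensorflow as tf

params = [[1.0, 2.0, 3.0]]
M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
# Write params into rows 0-2 of the last column of the (single) batched matrix
M = tf.assign(M[:, 0:3, 3], params)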

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
output_val = sess.run(M)

Note: the variable is created only to hold these parameters.
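Since the variable only exists to hold the parameters, one way to sidestep `tf.assign` entirely is to treat the parameters themselves as the learned quantity and build the 4x4 matrix from them with ordinary operations: a constant identity plus a matrix that is zero except for the first rows of the last column. The NumPy sketch below shows that construction so the arithmetic is easy to check. `build_matrix` is a hypothetical helper, and in a TensorFlow graph the same computation would use ops such as `tf.eye` and `tf.pad` on a trainable `params` tensor (an assumption about how you might port it, not code from the question).

```python
import numpy as np


def build_matrix(params, n=4):
    """Hypothetical helper: an n x n identity whose last column holds
    `params` in its first len(params) rows, built without mutating anything."""
    params = np.asarray(params, dtype=np.float32)
    k = params.shape[0]
    # The diagonal entry (n-1, n-1) must stay untouched, so params must be shorter than n.
    assert k < n, "params must have fewer than n entries"
    # Column vector of params, zero-padded at the bottom to n rows: shape (n, 1).
    column = np.pad(params.reshape(k, 1), ((0, n - k), (0, 0)), mode="constant")
    # Pad on the left so that column becomes the last column of an (n, n) matrix.
    offset = np.pad(column, ((0, 0), (n - 1, 0)), mode="constant")
    # The identity is zero in rows 0..k-1 of the last column, so adding writes params there.
    return np.eye(n, dtype=np.float32) + offset


M = build_matrix([1.0, 2.0, 3.0])
print(M)  # last column is [1, 2, 3, 1]; the first three columns are the identity's
```

Because the result is a pure function of `params`, nothing is assigned in place, so there is no assignment op in the path from the loss back to the parameters.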

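Update: I'm adding a minimal working example that reproduces the error. (Obviously, training like this accomplishes nothing; it only illustrates the error, since my actual code is too long to copy here.)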

import numpy as np
import tensorflow as tf

params = [[1.0, 2.0, 3.0]]
M_gt = np.eye(4)
M_gt[0:3, 3] = [4.0, 5.0, 6.0]

M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
M = tf.assign(M[:, 0:3, 3], params)

loss = tf.nn.l2_loss(M - M_gt)
optimizer = tf.train.AdamOptimizer(0.001)
train_op = optimizer.minimize(loss)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_op)
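For the minimal example above, the gradient that a differentiable construction should deliver is easy to derive. `tf.nn.l2_loss(t)` computes `sum(t ** 2) / 2`, so if the first three entries of the last column equal `params`, the partial derivative of the loss with respect to `params` is the residual in those entries, `M[0:3, 3] - M_gt[0:3, 3]`, which is `[1-4, 2-5, 3-6] = [-3, -3, -3]` at the starting values. The NumPy sketch below (not TensorFlow; all helper names are hypothetical) checks that formula against central finite differences and then runs plain gradient descent, which moves `params` toward `[4, 5, 6]`.

```python
import numpy as np

# Target matrix from the question: identity with [4, 5, 6] in rows 0-2 of the last column.
M_GT = np.eye(4)
M_GT[0:3, 3] = [4.0, 5.0, 6.0]


def build_matrix(params):
    # Writes into a fresh local array, so the result is a pure function of params.
    m = np.eye(4)
    m[0:3, 3] = params
    return m


def l2_loss(t):
    # Same definition as tf.nn.l2_loss: half the sum of squares.
    return np.sum(t ** 2) / 2.0


def loss(params):
    return l2_loss(build_matrix(params) - M_GT)


def analytic_grad(params):
    # d loss / d params = residual in the entries that params occupy.
    return build_matrix(params)[0:3, 3] - M_GT[0:3, 3]


def numeric_grad(params, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    params = np.asarray(params, dtype=np.float64)
    grad = np.zeros_like(params)
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = eps
        grad[i] = (loss(params + step) - loss(params - step)) / (2 * eps)
    return grad


params = np.array([1.0, 2.0, 3.0])
print(analytic_grad(params), numeric_grad(params))  # both are -3 in every entry

# Plain gradient descent: with lr = 0.5 each step halves the distance to [4, 5, 6].
lr = 0.5
for _ in range(50):
    params = params - lr * analytic_grad(params)
print(params)
```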

1 answer:

Answer 0 (score: 0)

Here is an example of how to do what you want:

import tensorflow as tf
import numpy as np

with tf.Graph().as_default(), tf.Session() as sess:
    params = [[1.0, 2.0, 3.0]]
    M_gt = np.eye(4)
    M_gt[0:3, 3] = [4.0, 5.0, 6.0]

    M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
    # A constant for brevity; to learn these values, make params_t a tf.Variable instead.
    params_t = tf.constant(params, dtype=tf.float32)

    shape_m = tf.shape(M)
    batch_size = shape_m[0]
    num_m = shape_m[1]
    num_params = tf.shape(params_t)[1]

    # Matrix that is zero everywhere except the last column,
    # which holds params in its first num_params rows.
    last_column = tf.concat([tf.tile(tf.transpose(params_t)[tf.newaxis], (batch_size, 1, 1)),
                             tf.zeros((batch_size, num_m - num_params, 1), dtype=params_t.dtype)], axis=1)
    replace = tf.concat([tf.zeros((batch_size, num_m, num_m - 1), dtype=params_t.dtype), last_column], axis=2)

    # Boolean mask: True in rows [0, num_params) of the last column.
    r = tf.range(num_m)
    ii = r[tf.newaxis, :, tf.newaxis]
    jj = r[tf.newaxis, tf.newaxis, :]
    mask = tf.tile((ii < num_params) & (tf.equal(jj, num_m - 1)), (batch_size, 1, 1))
    # tf.where has a gradient: it flows to replace where mask is True and to M elsewhere.
    M_replaced = tf.where(mask, replace, M)

    loss = tf.nn.l2_loss(M_replaced - M_gt[np.newaxis])
    optimizer = tf.train.AdamOptimizer(0.001)
    train_op = optimizer.minimize(loss)
    # Use the session opened by the `with` statement.
    init = tf.global_variables_initializer()
    sess.run(init)
    M_val, M_replaced_val = sess.run([M, M_replaced])
    print('M:')
    print(M_val)
    print('M_replaced:')
    print(M_replaced_val)

Output:

M:
[[[ 1.  0.  0.  0.]
  [ 0.  1.  0.  0.]
  [ 0.  0.  1.  0.]
  [ 0.  0.  0.  1.]]]
M_replaced:
[[[ 1.  0.  0.  1.]
  [ 0.  1.  0.  2.]
  [ 0.  0.  1.  3.]
  [ 0.  0.  0.  1.]]]
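The core of the answer is the boolean mask: it is `True` exactly where the row index satisfies `ii < num_params` and the column index equals `num_m - 1`, and `tf.where` picks from `replace` there and from `M` everywhere else. Because `tf.where` has a registered gradient, the loss gradient reaches `replace` (and hence `params`, if they were a trainable tensor rather than the constant used in the answer) on masked positions, and `M` on the rest. The sketch below reproduces the same mask arithmetic in NumPy, with `np.where` standing in for `tf.where`, so the selection can be checked without a TensorFlow session; it is an illustration, not part of the original answer.

```python
import numpy as np

params = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)  # shape (1, 3), as in the answer
M = np.eye(4, dtype=np.float32)[np.newaxis]             # shape (1, 4, 4): a batch of one identity

batch_size, num_m = M.shape[0], M.shape[1]
num_params = params.shape[1]

# Last column: params (transposed to a column) on top, zeros below -> shape (1, 4, 1).
last_column = np.concatenate(
    [np.tile(params.T[np.newaxis], (batch_size, 1, 1)),
     np.zeros((batch_size, num_m - num_params, 1), dtype=params.dtype)], axis=1)
# Replacement matrix: zeros everywhere except the last column -> shape (1, 4, 4).
replace = np.concatenate(
    [np.zeros((batch_size, num_m, num_m - 1), dtype=params.dtype), last_column], axis=2)

# Row indices ii (1, 4, 1) and column indices jj (1, 1, 4) broadcast to a (1, 4, 4) mask.
r = np.arange(num_m)
ii = r[np.newaxis, :, np.newaxis]
jj = r[np.newaxis, np.newaxis, :]
mask = np.tile((ii < num_params) & (jj == num_m - 1), (batch_size, 1, 1))

# Elementwise selection: replace where mask is True, M elsewhere.
M_replaced = np.where(mask, replace, M)
print(mask[0].astype(int))
print(M_replaced[0])
```

The answer tiles the mask up to the full batch shape; in TensorFlow 1.x, `tf.where` requires the condition to have the same shape as `x` and `y` (or be a vector matching their first dimension), which is presumably why.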