I have a variable containing a 4x4 identity matrix. I want to assign some values to this matrix (the values are learned by the model).
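When I use tf.assign(), I get an error saying that there is no gradient for the strided slice.
My question is: how can I assign these values without using tf.assign()?
Below is sample code showing the desired behavior (it runs without errors, because the values are not being learned here):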
import tensorflow as tf

params = [[1.0, 2.0, 3.0]]
M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
# Write params into rows 0-2 of the last column (fine here, nothing is trained)
M = tf.assign(M[:, 0:3, 3], params)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
output_val = sess.run(M)
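Note: the variable is created only to hold these parameters.
Update: I am adding a minimal working example that produces the error. (Obviously, training like this achieves nothing. It is only meant to illustrate the error, since my actual code is too long to reproduce here.)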
import numpy as np
import tensorflow as tf

params = [[1.0, 2.0, 3.0]]
M_gt = np.eye(4)
M_gt[0:3, 3] = [4.0, 5.0, 6.0]
M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
# The loss below depends on the result of this stateful sliced assign
M = tf.assign(M[:, 0:3, 3], params)
loss = tf.nn.l2_loss(M - M_gt)
optimizer = tf.train.AdamOptimizer(0.001)
train_op = optimizer.minimize(loss)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_op)
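Why this fails, in one step: tf.assign is a stateful write into the variable, and the loss is computed from the result of that write, so the gradient path has to pass through the sliced-assignment op, which is what the error complains about. What training needs instead is a matrix that is an ordinary, differentiable function of params. The following is a minimal sketch of that idea in plain NumPy rather than TensorFlow; the helpers `build_matrix`, `l2_loss` and `numerical_grad` are hypothetical names chosen for illustration. It builds the matrix from params with a mask and `np.where`, then checks with central finite differences that the loss has a well-defined gradient with respect to params.

```python
import numpy as np


def l2_loss(t):
    # Same definition as tf.nn.l2_loss: sum(t ** 2) / 2
    return np.sum(t ** 2) / 2


def build_matrix(params, n=4):
    """Hypothetical helper: n x n identity with params in rows 0..k-1 of the last column.

    No in-place writes: the result is a pure function of params.
    """
    params = np.asarray(params, dtype=np.float64)
    k = params.shape[0]
    r = np.arange(n)
    # True only at positions (i, n - 1) with i < k
    mask = (r[:, None] < k) & (r[None, :] == n - 1)
    # Zeros everywhere except the last column, which holds params followed by zeros
    last_col = np.concatenate([params, np.zeros(n - k)])[:, None]
    replace = np.concatenate([np.zeros((n, n - 1)), last_col], axis=1)
    # Take replace where mask is True, the identity elsewhere
    return np.where(mask, replace, np.eye(n))


def numerical_grad(f, x, eps=1e-6):
    # Central finite differences, one coordinate at a time
    x = np.asarray(x, dtype=np.float64)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


M_gt = np.eye(4)
M_gt[0:3, 3] = [4.0, 5.0, 6.0]
params = np.array([1.0, 2.0, 3.0])


def loss(p):
    return l2_loss(build_matrix(p) - M_gt)


grad = numerical_grad(loss, params)
print(build_matrix(params))
print(grad)  # close to params - [4, 5, 6], i.e. about [-3, -3, -3]
```

Because the only entries that depend on params are rows 0 to 2 of the last column, the gradient of the L2 loss there is simply params minus the target values [4, 5, 6], and the finite differences reproduce that.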
Answer 0 (score: 0)
Here is an example of how to do what you want without tf.assign:
import tensorflow as tf
import numpy as np

with tf.Graph().as_default(), tf.Session() as sess:
    params = [[1.0, 2.0, 3.0]]
    M_gt = np.eye(4)
    M_gt[0:3, 3] = [4.0, 5.0, 6.0]
    M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
    # In a real model, params_t would be the learned tensor
    params_t = tf.constant(params, dtype=tf.float32)
    shape_m = tf.shape(M)
    batch_size = shape_m[0]
    num_m = shape_m[1]
    num_params = tf.shape(params_t)[1]
    # Last column of the replacement: params on top, zeros below
    last_column = tf.concat([tf.tile(tf.transpose(params_t)[tf.newaxis], (batch_size, 1, 1)),
                             tf.zeros((batch_size, num_m - num_params, 1), dtype=params_t.dtype)], axis=1)
    # Replacement matrix: zeros except for the last column
    replace = tf.concat([tf.zeros((batch_size, num_m, num_m - 1), dtype=params_t.dtype), last_column], axis=2)
    # Mask: True at rows < num_params of the last column
    r = tf.range(num_m)
    ii = r[tf.newaxis, :, tf.newaxis]
    jj = r[tf.newaxis, tf.newaxis, :]
    mask = tf.tile((ii < num_params) & (tf.equal(jj, num_m - 1)), (batch_size, 1, 1))
    # Values from replace where mask is True, from M elsewhere (no stateful write)
    M_replaced = tf.where(mask, replace, M)
    loss = tf.nn.l2_loss(M_replaced - M_gt[np.newaxis])
    optimizer = tf.train.AdamOptimizer(0.001)
    train_op = optimizer.minimize(loss)
    # Reuse the session opened by the with statement
    init = tf.global_variables_initializer()
    sess.run(init)
    M_val, M_replaced_val = sess.run([M, M_replaced])
    print('M:')
    print(M_val)
    print('M_replaced:')
    print(M_replaced_val)
Output:
M:
[[[ 1. 0. 0. 0.]
[ 0. 1. 0. 0.]
[ 0. 0. 1. 0.]
[ 0. 0. 0. 1.]]]
M_replaced:
[[[ 1. 0. 0. 1.]
[ 0. 1. 0. 2.]
[ 0. 0. 1. 3.]
[ 0. 0. 0. 1.]]]
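To make the shape bookkeeping in the answer easier to follow, here is a sketch that transcribes the same construction line by line into NumPy. NumPy stands in for TensorFlow only so that it runs without a graph or session: `np.tile`, `np.concatenate` and `np.where` play the roles of `tf.tile`, `tf.concat` and `tf.where`. It also shows an equivalent arithmetic blend, `M * (1 - mask) + replace * mask`, which is not part of the original answer.

```python
import numpy as np

params = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)  # shape (1, 3), as in the answer
M = np.eye(4, dtype=np.float32)[np.newaxis]             # shape (1, 4, 4), like tf.eye(4, batch_shape=[1])

batch_size, num_m = M.shape[0], M.shape[1]
num_params = params.shape[1]

# Last column: params on top, zeros below -> shape (batch, num_m, 1)
last_column = np.concatenate(
    [np.tile(params.T[np.newaxis], (batch_size, 1, 1)),
     np.zeros((batch_size, num_m - num_params, 1), dtype=params.dtype)],
    axis=1)

# Replacement matrix: zeros except for the last column -> shape (batch, num_m, num_m)
replace = np.concatenate(
    [np.zeros((batch_size, num_m, num_m - 1), dtype=params.dtype), last_column],
    axis=2)

# Boolean mask: True at rows < num_params of the last column
r = np.arange(num_m)
ii = r[np.newaxis, :, np.newaxis]  # shape (1, num_m, 1)
jj = r[np.newaxis, np.newaxis, :]  # shape (1, 1, num_m)
mask = np.tile((ii < num_params) & (jj == num_m - 1), (batch_size, 1, 1))

# Select from replace where mask is True, from M elsewhere; M itself is untouched
M_replaced = np.where(mask, replace, M)

# Equivalent arithmetic blend (sketch, not from the answer)
mask_f = mask.astype(M.dtype)
M_blend = M * (1 - mask_f) + replace * mask_f

print(M_replaced[0])  # same matrix as M_replaced in the output above
```

Like `tf.where` in the answer, `np.where` takes each entry from `replace` where `mask` is True and from `M` elsewhere, so no entry of `M` is overwritten. In TensorFlow that selection is differentiable with respect to both inputs, which is what lets gradients flow without tf.assign. The blend form gives the same values here; in TensorFlow it would need the boolean mask cast to float first, for example with `tf.cast`.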