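I want to implement a function that takes a variable as input, mutates some of its rows or columns, and writes them back into the original variable. I was able to implement it for row slices using tf.gather and tf.scatter_update, but not for column slices, because apparently tf.scatter_update only updates row slices and has no axis argument. I'm not an expert in TensorFlow, so I may be missing something. Can anyone help?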
import numpy as np
import tensorflow as tf

def matrix_reg(t, percent_t, beta):
    ''' Takes a variable tensor t as input and regularizes some of its rows.
    The number of rows to be regularized is specified by percent_t. Returns the original tensor after updating its rows indexed by row_ind.
    Arguments:
    t -- input tensor
    percent_t -- percentage of the total rows
    beta -- the regularization factor
    Output:
    the regularized tensor
    '''
    # Pick a random subset of row indices without repetition
    row_ind = np.random.choice(int(t.shape[0]), int(percent_t*int(t.shape[0])), replace = False)
    t_ = tf.gather(t,row_ind)
    t_reg = (1+beta)*t_-beta*(tf.matmul(tf.matmul(t_,tf.transpose(t_)),t_))
    return tf.scatter_update(t, row_ind, t_reg)
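Answer 0 (score: 2)
Here is a small example of how to update rows or columns. The idea is to specify, for every element of the update, the row index and column index of the variable where that element should end up. This is easy to do with tf.meshgrid.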
import tensorflow as tf
var = tf.get_variable('var', [4, 3], tf.float32, initializer=tf.zeros_initializer())
updates = tf.placeholder(tf.float32, [None, None])
indices = tf.placeholder(tf.int32, [None])
# Update rows
var_update_rows = tf.scatter_update(var, indices, updates)
# Update columns
col_indices_nd = tf.stack(tf.meshgrid(tf.range(tf.shape(var)[0]), indices, indexing='ij'), axis=-1)
var_update_cols = tf.scatter_nd_update(var, col_indices_nd, updates)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('Rows updated:')
    print(sess.run(var_update_rows, feed_dict={updates: [[1, 2, 3], [4, 5, 6]], indices: [3, 1]}))
    print('Columns updated:')
    print(sess.run(var_update_cols, feed_dict={updates: [[1, 5], [2, 6], [3, 7], [4, 8]], indices: [0, 2]}))
Output:
Rows updated:
[[0. 0. 0.]
[4. 5. 6.]
[0. 0. 0.]
[1. 2. 3.]]
Columns updated:
[[1. 0. 5.]
[2. 5. 6.]
[3. 0. 7.]
[4. 2. 8.]]
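To make the index grid easier to picture, here is a minimal sketch that rebuilds it with NumPy standing in for TensorFlow (np.meshgrid and np.stack accept the same indexing='ij' and axis=-1 arguments used above). The names n_rows, col_ids, index_grid and target are illustrative and not part of the answer, and the fancy-index assignment only emulates what tf.scatter_nd_update does.

```python
import numpy as np

n_rows, n_cols = 4, 3
col_ids = np.array([0, 2])  # columns to overwrite (illustrative choice)

# Same construction as col_indices_nd in the answer, in NumPy.
rows, cols = np.meshgrid(np.arange(n_rows), col_ids, indexing='ij')
index_grid = np.stack([rows, cols], axis=-1)  # shape (4, 2, 2)

# index_grid[i, j] holds the (row, column) position that updates[i, j] is written to.
updates = np.array([[1, 5], [2, 6], [3, 7], [4, 8]], dtype=np.float32)
target = np.zeros((n_rows, n_cols), dtype=np.float32)

# Emulate the scatter: write each update element to its (row, column) position.
target[index_grid[..., 0], index_grid[..., 1]] = updates
print(index_grid.shape)
print(target)
```

In this sketch the target starts from zeros, so its middle column stays zero. In the answer's output the middle column shows 5 and 2 in rows 1 and 3 because the variable still holds the row updates written by the first sess.run call.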
Answer 1 (score: 0)
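See the TensorFlow 2 documentation for tf.Variable:
__getitem__(var, slice_spec)
Creates a slice helper object given a variable.
This allows creating a sub-tensor from part of the current contents of a variable. See tf.Tensor.__getitem__ for detailed examples of slicing.
This function also allows assignment to a sliced range. This is similar to __setitem__ functionality in Python. However, the syntax is different so that the user can capture the assignment operation for grouping or passing to sess.run(). For example, ...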
Here is a minimal working example:
import tensorflow as tf
import numpy as np
var = tf.Variable(np.random.rand(3,3,3))
print(var)
# update the last column of each of the three (3x3) matrices with random integer values
# note that the update values need to have the same shape as the slice,
# as broadcasting is not supported as of TF2
var[:,:,2].assign(np.random.randint(10,size=(3,3)))
print(var)
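Applied to the original question, slice assignment could be used to write regularized columns back one at a time. The following is only a sketch under assumptions: matrix_reg_cols is a hypothetical column counterpart of the question's matrix_reg, it reuses the question's formula unchanged on the gathered columns (whether that is the intended regularization for columns is for the asker to decide), and it assumes TF2 eager execution.

```python
import tensorflow as tf

def matrix_reg_cols(var, col_ind, beta):
    """Hypothetical column variant of the question's matrix_reg (TF2 eager)."""
    col_ind = [int(c) for c in col_ind]  # plain Python ints for slicing
    # Gather the selected columns: shape [n_rows, len(col_ind)]
    t_ = tf.gather(tf.convert_to_tensor(var), col_ind, axis=1)
    # Same formula as in the question, applied to the gathered columns
    t_reg = (1 + beta) * t_ - beta * tf.matmul(tf.matmul(t_, t_, transpose_b=True), t_)
    # Write each regularized column back with slice assignment
    for j, c in enumerate(col_ind):
        var[:, c].assign(t_reg[:, j])
    return var

var = tf.Variable(tf.ones([4, 3]))
matrix_reg_cols(var, [0, 2], beta=0.5)
print(var.numpy())
```

The loop issues one assignment per column, which is simple but not the most efficient option. For many columns, the tf.meshgrid plus scatter_nd_update approach from the other answer updates all of them in a single op.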