Tensorflow,Keras:如何创建仅在特定位置更新的可训练变量?

时间:2018-08-08 11:28:34

标签: python tensorflow keras gradient-descent tensor

例如,y=Ax

其中A是对角矩阵,其可训练权重(w1, w2, w3)在对角线上。

A = [w1 ... ...
    ...  w2 ...
    ... ... w3]

如何在Tensorflow或Keras中创建这种可训练的A

如果我尝试A = tf.Variable(np.eye(3)),则可训练砝码的总数将是3 * 3 = 9,而不是3。因为我想要更新(w1,w2,w3) 3个重量。

一个技巧可能是使用A = tf.Variable([1, 1, 1]) * np.eye(3),以便将3个可训练的权重映射到A的对角线中。

我的问题是:

  1. 该技巧可以达到我的目的吗?可以正确计算梯度吗?

  2. 如果A的情况更加复杂怎么办?例如。如果我要创建:

More complex Example

w1, w2, ..., w6是要更新的权重。

3 个答案:

答案 0 :(得分:2)

您有两种不同的工具来解决此问题。

  1. 您可以创建所需的变量并将其重新排列为所需的形式。
  2. 您可以创建超出所需数量的变量,然后丢弃某些变量以达到所需的形式。

两种方法都不是排他的,您可以混合使用#1和#2类型的连续步骤。

例如,对于第一个示例(对角矩阵),我们可以使用方法1。

w = tf.Variable(tf.zeros(n))
A = tf.diag(w) # creates a diagonal matrix with elements of w

对于第二个更复杂的示例,我们可以使用方法2。

A = tf.Variable(tf.zeros((n, n)))
A = tf.matrix_band_part(A, 1, 1) # keep only the central band of width 3
A = tf.matrix_set_diag(A, tf.ones(n)) # set diagonal to 1

答案 1 :(得分:0)

创建向量或矩阵变量都可以正常工作

问题1。

不用担心,渐变会正确计算

问题2。

如果它变得更加复杂(如您提到的那样),您仍然可以创建向量变量,然后从该变量构建矩阵。

或者,您可以创建一个矩阵变量,然后仅使用int而不是&&来更新其中的一部分

答案 2 :(得分:0)

对于更复杂的情况,其中A需要分成几个部分,其中只有某些部分是可训练的,而其他部分可以具有任意值,那么最简单的方法是构建各个部分,然后将它们连接在一起

例如,我需要任意大小的权重矩阵A(对于大小4x4)如下所示(4个不同的部分2x2):

#  [[0.,   0.,   -0.2,    0.],
#   [0.,   0.,   0.,      -0.2],
#   [0.35, 0.,   train,   train],
#   [0.,   0.35, train,   train]]

执行此操作的代码:

n_neurons = 3
zero_quarter = tf.zeros((n_neurons, n_neurons))  # upper left quarter are zeros
neg_diag = tf.diag(tf.ones(n_neurons) * -0.2)  # upper right is negative diag
pos_diag = tf.diag(tf.ones(n_neurons) * 0.35)  # lower left is positive diag
# lower right quarter is trainable randomly initialized vars
train_quarter = tf.get_variable(name='TrainableWeights', shape=[n_neurons, n_neurons])

weights_row0 = tf.concat([zero_quarter, neg_diag], axis=1)
weights_row1 = tf.concat([pos_diag, train_quarter], axis=1)

weights = tf.concat([weights_row0, weights_row1], axis=0)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(weights))

结果是:

[[ 0.          0.          0.         -0.2         0.          0.        ]
 [ 0.          0.          0.          0.         -0.2         0.        ]
 [ 0.          0.          0.          0.          0.         -0.2       ]
 [ 0.35        0.          0.         -0.61401606  0.39812732  0.72078323]
 [ 0.          0.35        0.         -0.34560132  0.40494204  0.36660933]
 [ 0.          0.          0.35        0.34820676  0.5112138  -0.97605824]]

只有右下3x3部分可以训练。