I am working on a convolutional neural net that requires some parts of the a kernel weights to be untrainable. tf.nn.conv2d(x, W) takes in a trainable variable W as weights. How can I make some of the elements of W to be untrainable?
答案 0 :(得分:2)
也许您可以使用可训练的权重W1
,掩码M
表示可训练变量的位置,以及常量/不可重置的权重矩阵W2
,并使用
W = tf.multiply(W1, tf.cast(M, dtype=W1.dtype)) + tf.multiply(W2, tf.cast(tf.logical_not(M), dtype=W2.dtype))