Tensorflow: Trainable Variable Masking

时间:2017-06-27 07:37:54

标签: tensorflow mask convolution

I am working on a convolutional neural net that requires some parts of the a kernel weights to be untrainable. tf.nn.conv2d(x, W) takes in a trainable variable W as weights. How can I make some of the elements of W to be untrainable?

1 个答案:

答案 0 :(得分:2)

也许您可以使用可训练的权重W1,掩码M表示可训练变量的位置,以及常量/不可重置的权重矩阵W2,并使用

W = tf.multiply(W1, tf.cast(M, dtype=W1.dtype)) + tf.multiply(W2, tf.cast(tf.logical_not(M), dtype=W2.dtype))