如何通过更改符号来限制TensorFlow中的权重?

时间:2018-07-03 23:34:21

标签: python tensorflow

我看过几篇关于在TensorFlow权重变量上添加简单约束(即非负性)的文章,但没有关于如何防止权重改变符号的文章。例如,如果我有 W = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32) 如何添加约束,使得初始化W[i,j]后不能更改符号?我在tf.get_variable()中看不到使用“约束”选项的明确方法。

1 个答案:

答案 0 :(得分:1)

我解决此问题的方法如下。

对于每个重量,您都存储初始符号。可以使用以下代码完成

w1 = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
w1_sign = tf.zeros_like(w1)
store_sign = tf.assign(w1_sign, tf.sign(w1))

只要权重违反符号约束,就可以使用以下代码将权重设置为0。

constraint_op = tf.assign(w1, tf.where(w1_sign * w1 >= 0, w1, 0))

现在您可以按以下方式运行上面的代码

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(store_sign)
for _ in range(train_itr):
    sess.run(some_train_op)
    sess.run(constraint_op)

请注意,在以上代码中,您仅运行一次op store_sign,并且在每次运行constraint_op之后运行op train_op

相同的想法可以与constraints的{​​{1}}参数一起应用。