我看过几篇关于在TensorFlow权重变量上添加简单约束(即非负性)的文章,但没有关于如何防止权重改变符号的文章。例如,如果我有
W = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
如何添加约束,使得初始化W[i,j]
后不能更改符号?我在tf.get_variable()中看不到使用“约束”选项的明确方法。
答案 0 :(得分:1)
我解决此问题的方法如下。
对于每个重量,您都存储初始符号。可以使用以下代码完成
w1 = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
w1_sign = tf.zeros_like(w1)
store_sign = tf.assign(w1_sign, tf.sign(w1))
只要权重违反符号约束,就可以使用以下代码将权重设置为0。
constraint_op = tf.assign(w1, tf.where(w1_sign * w1 >= 0, w1, 0))
现在您可以按以下方式运行上面的代码
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(store_sign)
for _ in range(train_itr):
sess.run(some_train_op)
sess.run(constraint_op)
请注意,在以上代码中,您仅运行一次op store_sign
,并且在每次运行constraint_op
之后运行op train_op
。
相同的想法可以与constraints
的{{1}}参数一起应用。