import tensorflow as tf

slim = tf.contrib.slim

def create_learning_rate(curr_step, lr_config):
    base_lr = lr_config.get('base_lr', 0.1)
    decay_steps = lr_config.get('decay_steps', [])
    decay_rate = lr_config.get('decay_rate', 0.1)
    # One scale factor per interval: decay_rate**0, decay_rate**1, ...
    scale_rates = [
        lambda: tf.constant(decay_rate**i, dtype=tf.float32)
        for i in range(len(decay_steps) + 1)
    ]
    # One condition per interval: prev < curr_step <= decay_step
    conds = []
    prev = -1
    for decay_step in decay_steps:
        conds.append(tf.logical_and(curr_step > prev, curr_step <= decay_step))
        prev = decay_step
    conds.append(curr_step > decay_steps[-1])
    learning_rate_scale = tf.case(
        list(zip(conds, scale_rates)), lambda: 0.0, exclusive=True)
    return learning_rate_scale * base_lr

global_step = slim.create_global_step()
train_op = tf.assign_add(global_step, 1)
lr = create_learning_rate(
    global_step, {"base_lr": 0.1,
                  "decay_steps": [10, 20],
                  "decay_rate": 0.1})

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for i in range(30):
        curr_lr, step, _ = sess.run([lr, global_step, train_op])
        print(curr_lr, step)
I want to lower the learning rate at certain steps, but it is always 0.001. Any ideas? Or is there a better way to adjust the learning rate?
Thanks for your help.
Answer 0 (score: 0)
This is because the lambda functions capture the variable by reference, not by value: every lambda in scale_rates looks up i when it is called, which is only after the list comprehension has finished with i == len(decay_steps). With two decay steps, every branch therefore returns decay_rate**2, and the learning rate is 0.1 * 0.1**2 = 0.001 from the very first step.
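A plain-Python sketch of the pitfall (the names here are illustrative, not part of the original code):

    decay_rate = 0.1
    # Late binding: each lambda reads i at call time, after the loop ends.
    fns = [lambda: decay_rate**i for i in range(3)]
    print([f() for f in fns])   # [0.01, 0.01, 0.01] -- all see i == 2
    # Binding i as a default argument freezes its value at definition time.
    fns = [lambda i=i: decay_rate**i for i in range(3)]
    print([f() for f in fns])   # [1.0, 0.1, 0.01]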
So the correct way is to bind the current value as a default argument:
def create_learning_rate(global_step, lr_config):
    base_lr = lr_config.get('base_lr', 0.1)
    decay_steps = lr_config.get('decay_steps', [])
    decay_rate = lr_config.get('decay_rate', 0.1)
    prev = -1
    scale_rate = 1.0
    cases = []
    for decay_step in decay_steps:
        # The default argument v=scale_rate binds the *current* value of
        # scale_rate at definition time instead of a late-bound reference.
        cases.append((tf.logical_and(global_step > prev,
                                     global_step <= decay_step),
                      lambda v=scale_rate: v))
        scale_rate *= decay_rate
        prev = decay_step
    # Final interval: every step past the last boundary (prev holds the
    # last decay step here, and stays -1 if decay_steps is empty).
    cases.append((global_step > prev, lambda v=scale_rate: v))
    learning_rate_scale = tf.case(cases, lambda: 0.0, exclusive=True)
    return learning_rate_scale * base_lr
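As for a better way to adjust the learning rate: TF 1.x already ships tf.train.piecewise_constant, which implements exactly this kind of step schedule. A minimal sketch reproducing the configuration above (the values list is just base_lr * decay_rate**k):

    boundaries = [10, 20]
    values = [0.1, 0.01, 0.001]  # base_lr * decay_rate**k for k = 0, 1, 2
    # Returns values[0] while global_step <= 10, values[1] up to 20, then values[2].
    lr = tf.train.piecewise_constant(global_step, boundaries, values)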