Adjusting the learning rate in TensorFlow

Date: 2017-09-07 16:02:04

Tags: python tensorflow

import tensorflow as tf

slim = tf.contrib.slim


def create_learning_rate(curr_step, lr_config):
    base_lr = lr_config.get('base_lr', 0.1)
    decay_steps = lr_config.get('decay_steps', [])
    decay_rate = lr_config.get('decay_rate', 0.1)

    scale_rates = [
        lambda: tf.constant(decay_rate**i, dtype=tf.float32)
        for i in range(len(decay_steps) + 1)
    ]

    conds = []
    prev = -1
    for decay_step in decay_steps:
        conds.append(tf.logical_and(curr_step > prev, curr_step <= decay_step))
        prev = decay_step
    conds.append(curr_step > decay_steps[-1])

    learning_rate_scale = tf.case(
        list(zip(conds, scale_rates)), lambda: 0.0, exclusive=True)
    return learning_rate_scale * base_lr


global_step = slim.create_global_step()
train_op = tf.assign_add(global_step, 1)
lr = create_learning_rate(
    global_step, {"base_lr": 0.1,
                  "decay_steps": [10, 20],
                  "decay_rate": 0.1})

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for i in range(30):
        curr_lr, step, _ = sess.run([lr, global_step, train_op])
        print(curr_lr, step)

I want to lower the learning rate at certain steps. However, it always prints 0.001. Any ideas? Or is there a better way to adjust the learning rate?

Thanks for your help.

1 Answer:

Answer 0 (score: 0)

This is because lambda functions capture variables by reference, not by value. By the time `tf.case` evaluates the branch functions, the list comprehension has finished and every lambda sees the final value of `i` (here 2), so every branch returns `decay_rate**2 = 0.01` and the learning rate is stuck at `0.1 * 0.01 = 0.001`.
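Here is a minimal pure-Python sketch of the same pitfall (no TensorFlow needed), along with the default-argument trick that the fix below relies on:

# All three lambdas close over the same variable `i`,
# which is 2 once the loop has finished.
fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # [2, 2, 2]

# A default argument is evaluated at definition time,
# so each lambda keeps its own copy of the value.
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])  # [0, 1, 2]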

So the correct way is:

def create_learning_rate(global_step, lr_config):
    base_lr = lr_config.get('base_lr', 0.1)
    decay_steps = lr_config.get('decay_steps', [])
    decay_rate = lr_config.get('decay_rate', 0.1)

    prev = -1
    scale_rate = 1.0

    cases = []
    for decay_step in decay_steps:
        # Bind the current scale as a default argument so each lambda
        # captures its value now, not the loop variable itself.
        cases.append((tf.logical_and(global_step > prev,
                                     global_step <= decay_step),
                      lambda v=scale_rate: tf.constant(v, dtype=tf.float32)))
        scale_rate *= decay_rate
        prev = decay_step
    # Past the last boundary, keep the final (smallest) scale.
    cases.append((global_step > decay_steps[-1],
                  lambda v=scale_rate: tf.constant(v, dtype=tf.float32)))
    learning_rate_scale = tf.case(
        cases, default=lambda: tf.constant(0.0), exclusive=True)
    return learning_rate_scale * base_lr
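As for a better way: TensorFlow 1.x already ships a helper for exactly this kind of step schedule, `tf.train.piecewise_constant`, so you don't have to build the `tf.case` yourself. A minimal sketch that mirrors the question's config (base_lr 0.1, decayed by 0.1 at steps 10 and 20); note that some TF versions require `boundaries` to have the same dtype as the global step (int64):

import tensorflow as tf

slim = tf.contrib.slim

global_step = slim.create_global_step()
# 0.1 while step <= 10, 0.01 while 10 < step <= 20, 0.001 afterwards.
lr = tf.train.piecewise_constant(
    global_step, boundaries=[10, 20], values=[0.1, 0.01, 0.001])

This keeps the boundary/value bookkeeping in one place and avoids the closure issue entirely.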