将tf.set_random_seed与tf.estimator.Estimator一起使用

时间:2018-06-27 23:26:53

标签: tensorflow tensorflow-datasets tensorflow-estimator

我正在使用tf.estimator.Estimator管理代码的培训和测试。我正在调整一些超参数,因此我需要确保使用相同的随机种子初始化权重。无论如何,对于tf.estimator创建的会话,set_random_seed是否存在?

1 个答案:

答案 0 :(得分:3)

您应在传递给估算器的配置中定义随机种子:

seed = 2018
config = tf.estimator.RunConfig(model_dir=model_dir, tf_random_seed=seed)

estimator = tf.estimator.Estimator(model_fn, config=config, params=params)

这是RunConfig的文档。


要注意的一件事是,每次运行estimator.train(train_input_fn)时,都会创建一个新图来训练模型(通过调用train_input_fn创建输入管道并调用model_fntrain_input_fn的输出上)。

一个问题是,每次使用相同的随机种子设置此新图。


示例

让我举例说明。假设您在输入管道中执行数据扩充,并在每个时期评估模型。这会给你这样的东西:

def train_input_fn():
    features = tf.random_uniform([])
    labels = tf.random_uniform([])
    dataset = tf.data.Dataset.from_tensors((features, labels))
    return dataset


def model_fn(features, labels, mode, params):
    loss = features
    global_step = tf.train.get_global_step()
    train_op = global_step.assign_add(1)
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


seed = 2018
config = tf.estimator.RunConfig(model_dir="test", tf_random_seed=seed)
estimator = tf.estimator.Estimator(model_fn, config=config)

num_epochs = 10
for epoch in range(num_epochs):
    estimator.train(train_input_fn, steps=1)
    estimator.evaluate(train_input_fn, steps=1)

输入函数创建随机特征(和标签)。发生的情况是,在每个时期创建的功能都将完全相同。输出如下:

INFO:tensorflow:loss = 0.17983198, step = 1
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 0.006007552
INFO:tensorflow:loss = 0.17983198, step = 2
INFO:tensorflow:Saving dict for global step 2: global_step = 2, loss = 0.006007552
INFO:tensorflow:loss = 0.17983198, step = 3
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 0.006007552
...

您会看到每个时期的损失(等于输入要素)是相同的,这意味着每个时期都使用相同的随机种子。

如果您想在每个时期进行评估并执行数据扩充,则会遇到此问题,因为您最终在每个时期都会得到相同的数据扩充


解决方案

一种快速的解决方法是删除随机种子。但是,这会阻止您进行可重复的实验。

另一个更好的解决方案是在每个时期使用相同的model_fn但随机种子不同的方式创建一个新的估算器:

seed = 2018

num_epochs = 10
for epoch in range(num_epochs):
    config = tf.estimator.RunConfig(model_dir="test", tf_random_seed=seed + epoch)
    estimator = tf.estimator.Estimator(model_fn, config=config)

    estimator.train(train_input_fn, steps=1)
    estimator.evaluate(train_input_fn, steps=1)

功能将在每个时期正确更改:

INFO:tensorflow:loss = 0.17983198, step = 1
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 0.006007552
INFO:tensorflow:loss = 0.22154999, step = 2
INFO:tensorflow:Saving dict for global step 2: global_step = 2, loss = 0.70446754
INFO:tensorflow:loss = 0.48594844, step = 3