我正在使用tf.estimator.Estimator管理代码的培训和测试。我正在调整一些超参数,因此我需要确保使用相同的随机种子初始化权重。无论如何,对于tf.estimator创建的会话,set_random_seed是否存在?
答案 0 :(得分:3)
您应在传递给估算器的配置中定义随机种子:
seed = 2018
config = tf.estimator.RunConfig(model_dir=model_dir, tf_random_seed=seed)
estimator = tf.estimator.Estimator(model_fn, config=config, params=params)
这是RunConfig
的文档。
要注意的一件事是,每次运行estimator.train(train_input_fn)
时,都会创建一个新图来训练模型(通过调用train_input_fn
创建输入管道并调用model_fn
在train_input_fn
的输出上)。
一个问题是,每次使用相同的随机种子设置此新图。
让我举例说明。假设您在输入管道中执行数据扩充,并在每个时期评估模型。这会给你这样的东西:
def train_input_fn():
features = tf.random_uniform([])
labels = tf.random_uniform([])
dataset = tf.data.Dataset.from_tensors((features, labels))
return dataset
def model_fn(features, labels, mode, params):
loss = features
global_step = tf.train.get_global_step()
train_op = global_step.assign_add(1)
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
seed = 2018
config = tf.estimator.RunConfig(model_dir="test", tf_random_seed=seed)
estimator = tf.estimator.Estimator(model_fn, config=config)
num_epochs = 10
for epoch in range(num_epochs):
estimator.train(train_input_fn, steps=1)
estimator.evaluate(train_input_fn, steps=1)
输入函数创建随机特征(和标签)。发生的情况是,在每个时期创建的功能都将完全相同。输出如下:
INFO:tensorflow:loss = 0.17983198, step = 1
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 0.006007552
INFO:tensorflow:loss = 0.17983198, step = 2
INFO:tensorflow:Saving dict for global step 2: global_step = 2, loss = 0.006007552
INFO:tensorflow:loss = 0.17983198, step = 3
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 0.006007552
...
您会看到每个时期的损失(等于输入要素)是相同的,这意味着每个时期都使用相同的随机种子。
如果您想在每个时期进行评估并执行数据扩充,则会遇到此问题,因为您最终在每个时期都会得到相同的数据扩充。
一种快速的解决方法是删除随机种子。但是,这会阻止您进行可重复的实验。
另一个更好的解决方案是在每个时期使用相同的model_fn
但随机种子不同的方式创建一个新的估算器:
seed = 2018
num_epochs = 10
for epoch in range(num_epochs):
config = tf.estimator.RunConfig(model_dir="test", tf_random_seed=seed + epoch)
estimator = tf.estimator.Estimator(model_fn, config=config)
estimator.train(train_input_fn, steps=1)
estimator.evaluate(train_input_fn, steps=1)
功能将在每个时期正确更改:
INFO:tensorflow:loss = 0.17983198, step = 1
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 0.006007552
INFO:tensorflow:loss = 0.22154999, step = 2
INFO:tensorflow:Saving dict for global step 2: global_step = 2, loss = 0.70446754
INFO:tensorflow:loss = 0.48594844, step = 3