Tensorflow 1.8.0:宽和深模型结果不稳定。随机种子不起作用

时间:2018-05-16 13:26:43

标签: python tensorflow machine-learning deep-learning

我使用Tensorflow 1.8.0培训了一个广泛而深入的模型。我的测试和训练数据集是先前拆分的单独文件。我在tf.set_random_seed(1234)之前使用tf.contrib.learn.DNNLinearCombinedClassifier,如下所示 -

tf.set_random_seed(1234)

import tempfile

model_dir = tempfile.mkdtemp()
m = tf.contrib.learn.DNNLinearCombinedClassifier(model_dir=model_dir,
                                                 linear_feature_columns=wide_columns,
                                                 dnn_feature_columns=deep_columns,
                                                 dnn_hidden_units=[100, 50])

显示以下日志 -

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_task_type': None, '_task_id': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f394b585c18>, '_master': '', '_num_ps_replicas': 0, '_num_worker_replicas': 0, '_environment': 'local', '_is_chief': True, '_evaluation_master': '', '_train_distribute': None, '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1.0
}
, '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_log_step_count_steps': 100, '_session_config': None, '_save_checkpoints_steps': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_model_dir': '/tmp/tmpxka6vy6t'}

从日志中,我可以看到未应用随机种子。

每当我运行脚本时,我都会得到不同的准确度结果。

如何使结果稳定?为什么没有应用随机种子?

1 个答案:

答案 0 :(得分:1)

经过这么多的斗争,终于,我找到了解决方案。需要在tf_random_seed内设置DNNLinearCombinedClassifier作为config的参数。包括行config=tf.contrib.learn.RunConfig(tf_random_seed=123)解决了问题。它设置随机种子并使结果可重复。

以下是代码的外观 -

# Combining Wide and Deep Models into One
model_dir = tempfile.mkdtemp()
m = tf.contrib.learn.DNNLinearCombinedClassifier(model_dir=model_dir,
                                                 linear_feature_columns=wide_columns,
                                                 dnn_feature_columns=deep_columns,
                                                 dnn_hidden_units=[100, 50],
                                                 config=tf.contrib.learn.RunConfig(tf_random_seed=123))