How to change global_step in TensorFlow's SKCompat

Asked: 2017-09-02 11:03:25

Tags: python tensorflow scikit-learn

I'm using the SKCompat class from tensorflow.contrib.learn to classify the MNIST data:

import tensorflow as tf
from tensorflow.contrib.learn import SKCompat
from tensorflow.contrib.learn import RunConfig

config = tf.contrib.learn.RunConfig(save_summary_steps=1000,
                                    log_step_count_steps=1000,
                                    model_dir="tmp/TF/")

feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(X_train)

dnn_clf = SKCompat(tf.contrib.learn.DNNClassifier(
    hidden_units=[1200, 600, 300, 150, 75],
    n_classes=10,
    feature_columns=feature_columns,
    dropout=0.5,
    config=config))

dnn_clf.fit(x=X_train, y=y_train, batch_size=128, steps=5000)

This works as expected, with one small problem: it prints an Info message every 100 steps, and I suspect the interval is tied to the global_step variable:

INFO:tensorflow:Using config: {'_task_type': None, '_task_id': 0, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f3a663d09b0>, '_master': '', '_num_ps_replicas': 0, '_num_worker_replicas': 0, '_environment': 'local', '_is_chief': True, '_evaluation_master': '', '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1.0
}
, '_tf_random_seed': None, '_save_summary_steps': 1000, '_save_checkpoints_secs': 600, '_log_step_count_steps': 1000, '_session_config': None, '_save_checkpoints_steps': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_model_dir': 'tmp/TF/'}
WARNING:tensorflow:From /home/sergey/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py:641: scalar_summary (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2016-11-30.
Instructions for updating:
Please switch to tf.summary.scalar. Note that tf.summary.scalar uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. Also, passing a tensor or list of tags to a scalar summary op is no longer supported.
WARNING:tensorflow:From /home/sergey/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/dnn.py:192: get_global_step (from tensorflow.contrib.framework.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Please switch to tf.train.get_global_step
WARNING:tensorflow:From /home/sergey/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/layers/python/layers/optimizers.py:161: assert_global_step (from tensorflow.contrib.framework.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Please switch to tf.train.assert_global_step
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into tmp/TF/model.ckpt.
INFO:tensorflow:loss = 2.50853, step = 1
INFO:tensorflow:global_step/sec: 60.7314
INFO:tensorflow:loss = 1.46667, step = 101 (1.648 sec)
INFO:tensorflow:global_step/sec: 53.2696
INFO:tensorflow:loss = 1.00616, step = 201 (1.877 sec)
INFO:tensorflow:global_step/sec: 57.6543
INFO:tensorflow:loss = 0.625174, step = 301 (1.734 sec)
INFO:tensorflow:global_step/sec: 53.0346
INFO:tensorflow:loss = 0.692355, step = 401 (1.885 sec)
...

Question:

Is there a way to reduce the Info output to, say, once every 1,000 steps?

Loading the data*:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
X_train = mnist.train.images
X_test = mnist.test.images
y_train = mnist.train.labels.astype("int")
y_test = mnist.test.labels.astype("int")

*Source: Hands-On Machine Learning with Scikit-Learn and TensorFlow

1 answer:

Answer 0 (score: 1)

You can use RunConfig to set the run configuration. For your case, set log_step_count_steps=1000:

config = tf.contrib.learn.RunConfig(master=None, num_cores=0,
    log_device_placement=False,
    gpu_memory_fraction=1,
    tf_random_seed=None,
    save_summary_steps=100,
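    # _USE_DEFAULT is a private sentinel inside TensorFlow's run_config module
    # and is not defined in user code; 600 is the effective value shown in the
    # config log in the question.
    save_checkpoints_secs=600,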
    save_checkpoints_steps=None,
    keep_checkpoint_max=5,
    keep_checkpoint_every_n_hours=10000,
    log_step_count_steps=1000,
    evaluation_master='',
    model_dir=None,
    session_config=None)

dnn_clf = SKCompat(tf.contrib.learn.DNNClassifier(
    hidden_units=[1200, 600, 300, 150, 75],
    n_classes=10,
    feature_columns=feature_columns,
    dropout=0.5,
    config=config))
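
Note that the config log in the question already reports '_log_step_count_steps': 1000, yet the messages still appear every 100 steps, so this setting may not be enough on that TensorFlow version. One possible workaround, which is not part of the original answer, is to throttle the messages on the logging side. The sketch below assumes that TensorFlow 1.x emits these lines through the standard Python logger named "tensorflow" and that the messages look like the ones in the question's log; StepThrottleFilter and its parameters are hypothetical names introduced here for illustration.

```python
import logging
import re

# Matches the "step = N" part of lines such as
# "loss = 1.46667, step = 101 (1.648 sec)".
_STEP_RE = re.compile(r"\bstep = (\d+)")


class StepThrottleFilter(logging.Filter):
    """Let 'loss = ..., step = N' records through only every `every_n` steps.

    Hypothetical helper: the message formats are assumed from the log
    shown in the question, not taken from a documented TensorFlow contract.
    """

    def __init__(self, every_n=1000, drop_rate_lines=True):
        super().__init__()
        self.every_n = every_n
        self.drop_rate_lines = drop_rate_lines
        self._last_step = None  # step number of the last record let through

    def filter(self, record):
        message = record.getMessage()

        # "global_step/sec: ..." lines carry no step number, so they are
        # either all kept or all dropped.
        if message.startswith("global_step/sec"):
            return not self.drop_rate_lines

        match = _STEP_RE.search(message)
        if match is None:
            return True  # unrelated message (warnings, checkpoints, ...)

        step = int(match.group(1))
        first = self._last_step is None
        restarted = not first and step < self._last_step
        due = not first and step - self._last_step >= self.every_n
        if first or restarted or due:
            self._last_step = step
            return True
        return False


# Attach the filter to the logger TensorFlow 1.x is assumed to use.
logging.getLogger("tensorflow").addFilter(StepThrottleFilter(every_n=1000))
```

The filter has to be installed before dnn_clf.fit(...) is called. It only hides log lines; it does not change how often summaries or checkpoints are written, which RunConfig still controls.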