系统信息:- TensorFlow 版本:1.5、1.8、1.12(结果均相同);- 操作系统:macOS 10.14.3
使用 tf.estimator.Estimator 进行分布式训练时,训练完成后是否需要手动终止参数服务器(ps)?还是说 ps 一直挂起、永不退出本来就是预期行为?
我正尝试在本地主机上用一个简单的 MNIST 示例,让基于 Estimator 的分布式训练跑通,但没有成功。
以下是完整代码(基于 https://github.com/yu-iskw/tensorflow-serving-example/blob/master/python/train/mnist_custom_estimator.py 下载并修改而来):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import simplejson
import numpy as np
import tensorflow as tf
# Feature key under which model_fn looks up the flattened image tensor.
INPUT_FEATURE = 'image'
# Number of output classes of the classifier (one logit per class).
NUM_CLASSES = 10

# Global flag registry; the definitions below register onto it.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('steps', 100, 'The number of steps to train a model')
# Role of this process in the cluster; copied into TF_CONFIG (e.g. 'master', 'ps').
tf.app.flags.DEFINE_string('job_name', 'master', 'job_name')
# Index of this task within its job; declared as a string-valued flag.
tf.app.flags.DEFINE_string('task_index', '0', 'task_index')
tf.app.flags.DEFINE_string('model_dir', './models/ckpt/', 'Dir to save a model and checkpoints')

# Emit INFO-level TensorFlow logs (config, training progress, evaluation).
tf.logging.set_verbosity(tf.logging.INFO)
def model_fn(features, labels, mode):
    """Build a single-dense-layer softmax classifier for ``tf.estimator.Estimator``.

    Args:
        features: Dict of input tensors; the image batch is read from
            ``features[INPUT_FEATURE]``.
        labels: Integer class-id tensor; not used in PREDICT mode.
        mode: One of the ``tf.estimator.ModeKeys`` values.

    Returns:
        A ``tf.estimator.EstimatorSpec`` configured for ``mode``:
        predictions plus export outputs for PREDICT, a train op for TRAIN,
        and an accuracy metric otherwise (EVAL).
    """
    image_batch = features[INPUT_FEATURE]
    logits = tf.layers.dense(inputs=image_batch, units=NUM_CLASSES)

    # Created in the same order as before so graph op names are unchanged.
    predicted_class = tf.argmax(input=logits, axis=1)
    # Named "softmax_tensor" so the LoggingTensorHook in main() can find it.
    class_probabilities = tf.nn.softmax(logits, name="softmax_tensor")
    predictions = {
        "classes": predicted_class,
        "probabilities": class_probabilities,
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        serving_outputs = {
            'predict': tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs=serving_outputs)

    # Both TRAIN and EVAL need the cross-entropy loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    if mode != tf.estimator.ModeKeys.TRAIN:
        metrics = {
            "accuracy": tf.metrics.accuracy(labels=labels, predictions=predicted_class)
        }
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics)

    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def main(_):
    """Train and evaluate a softmax MNIST classifier with ``train_and_evaluate``.

    This process's role in the cluster (e.g. ``master`` or ``ps``) is taken from
    the ``TF_CONFIG`` environment variable, which must be set before this
    function builds the Estimator.

    NOTE: for a ``ps`` task, ``tf.estimator.train_and_evaluate`` (TF 1.x) only
    starts a TensorFlow server and joins it, so the call never returns. The
    parameter server keeps running after the master has finished. Terminate the
    ps process externally (e.g. from the launcher script) once the master exits.

    Args:
        _: Unused argv list forwarded by ``tf.app.run()``.
    """
    # Load training and eval data. Every task (including ps) runs this, because
    # the Estimator and specs are built before train_and_evaluate dispatches on
    # the task type.
    mnist = tf.contrib.learn.datasets.load_dataset("mnist")
    train_data = mnist.train.images  # Returns np.array
    train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
    eval_data = mnist.test.images  # Returns np.array
    eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

    # Flatten each image into a 28*28 vector so the dense layer in model_fn
    # receives rank-2 (batch, features) input.
    train_data = train_data.reshape(train_data.shape[0], 28 * 28)
    eval_data = eval_data.reshape(eval_data.shape[0], 28 * 28)

    # Create the Estimator; summaries and checkpoints are written every 20 steps.
    training_config = tf.estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        save_summary_steps=20,
        save_checkpoints_steps=20)
    classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.model_dir,
        config=training_config)

    # Log the values of the "softmax_tensor" op under the label "probabilities".
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50)

    # Mini-batch size for training. It is kept separate from FLAGS.steps, which
    # is the training step budget rather than a batch size.
    batch_size = 100
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={INPUT_FEATURE: train_data},
        y=train_labels,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=True)

    # Evaluation input: a single ordered pass over the test set.
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={INPUT_FEATURE: eval_data},
        y=eval_labels,
        num_epochs=1,
        shuffle=False)

    # max_steps is an absolute global-step limit. Training stops once the
    # global step (restored from model_dir if a checkpoint exists) reaches
    # FLAGS.steps, or earlier if the single epoch of input is exhausted.
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn,
        max_steps=FLAGS.steps,
        hooks=[logging_hook])
    # EvalSpec with default settings; no custom eval steps or throttling are set.
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
def make_tf_training_config(args, master_host='localhost:2222',
                            ps_hosts=('localhost:2224',)):
    """Return a TF_CONFIG dict describing this process's task and the cluster.

    The result is meant to be JSON-encoded into the ``TF_CONFIG`` environment
    variable, which is needed for distributed training.
    See https://github.com/clusterone/clusterone-tutorials/blob/master/tf-estimator/mnist.py

    Args:
        args: Optional object with ``job_name`` and ``task_index`` attributes.
            When None, both values are read from the module-level ``FLAGS``.
        master_host: ``host:port`` of the single master task.
        ps_hosts: Iterable of ``host:port`` strings, one per parameter server.

    Returns:
        dict with ``task``, ``cluster`` and ``environment`` keys.

    Raises:
        ValueError: if ``task_index`` cannot be converted to an int.
    """
    source = FLAGS if args is None else args
    return {
        'task': {
            'type': source.job_name,
            # TF_CONFIG specifies the task index as an integer, while the
            # task_index flag is declared as a string.
            'index': int(source.task_index),
        },
        'cluster': {
            'master': [master_host],
            'ps': list(ps_hosts),
        },
        'environment': 'cloud',
    }
import os
if __name__ == "__main__":
    print("@@@ Version: {}".format(tf.__version__))
    # Publish this process's cluster role through TF_CONFIG so that the
    # RunConfig/Estimator built inside main() picks it up.
    os.environ['TF_CONFIG'] = simplejson.dumps(make_tf_training_config(None))
    # Parse command-line flags and invoke main().
    tf.app.run()
我按如下方式启动一个 master 作业和一个 ps 作业:
# launch master worker job
python test.py
# launch ps job
python test.py --job_name ps
以下是 master 作业日志的最后几行(该作业已成功完成并退出):
INFO:tensorflow:Finished evaluation at 2019-02-19-01:40:18
INFO:tensorflow:Saving dict for global step 3301: accuracy = 0.9918, global_step = 3301, loss = 0.025573652
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3301: ./models/ckpt/model.ckpt-3301
INFO:tensorflow:Loss for final step: 0.02294134.
以下是 ps 作业的完整日志:
@@@ Version: 1.12.0
WARNING:tensorflow:From test.py:118: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data.
WARNING:tensorflow:From /Users/hjing/miniconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /Users/hjing/miniconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /Users/hjing/miniconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /Users/hjing/miniconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /Users/hjing/miniconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /Users/hjing/miniconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
INFO:tensorflow:TF_CONFIG environment variable: {u'environment': u'cloud', u'cluster': {u'ps': [u'localhost:2224'], u'master': [u'localhost:2222']}, u'task': {u'index': u'0', u'type': u'ps'}}
INFO:tensorflow:Using config: {'_save_checkpoints_secs': None, '_session_config': device_filters: "/job:ps"
device_filters: "/job:worker"
device_filters: "/job:master"
allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_keep_checkpoint_max': 5, '_task_type': u'ps', '_train_distribute': None, '_is_chief': False, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x127a47550>, '_model_dir': './models/ckpt/', '_protocol': None, '_save_checkpoints_steps': 20, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 1, '_tf_random_seed': None, '_save_summary_steps': 20, '_device_fn': None, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_evaluation_master': '', '_eval_distribute': None, '_global_id_in_cluster': 1, '_master': u'grpc://localhost:2224'}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Start Tensorflow server.
2019-02-18 17:38:15.844761: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-02-18 17:38:15.846148: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:222] Initialize GrpcChannelCache for job master -> {0 -> localhost:2222}
2019-02-18 17:38:15.846171: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:222] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2224}
2019-02-18 17:38:15.846704: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:381] Started server with target: grpc://localhost:2224
之后 ps 进程便一直挂起,永不退出。
这是预期的行为吗?
谢谢!
我们欢迎用户贡献。您能否按照文档风格指南(doc style guide)提交一个 PR 来修正这一文档问题?