我用估算器运行了mnist代码,在单个进程中运行时它没问题,但在分布式模式的3个进程中运行时挂起。代码(PS)如下所示。请注意,worker应该将TF_CONFIG中的task.type更改为' worker',并且Chief应该更改为' chief'。
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib.cm as cm
tf.logging.set_verbosity(tf.logging.INFO)
def load2():
train = pd.read_csv('/tmp/train.csv')
test = pd.read_csv('/tmp/test.csv')
labels = train['label']
images = train.iloc[:, 1:]
image_size = 28
train_ds, valid_ds, train_labels, valid_labels = train_test_split(images, labels, test_size=0.33, random_state=42)
print len(train_ds)
print len(valid_ds)
print len(train_labels)
print len(valid_labels)
data = images
train_ds, train_labels = reformat(train_ds, train_labels)
valid_ds, valid_labels = reformat(valid_ds, valid_labels)
test_ds, test_labels = reformat(test, train_labels[:len(test)])
print('Training set', train_ds.shape, train_labels.shape)
print('Validation set', valid_ds.shape, valid_labels.shape)
print('Test set', test_ds.shape, test_labels.shape)
def reformat(dataset, labels):
num_channels = 1
num_labels = 10
dataset = dataset.values.reshape(
(-1, image_size, image_size, num_channels)).astype(np.float32)
labels = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
return dataset, labels
def cnn_model_fn(features, labels, mode):
input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
conv1 = tf.layers.conv2d(
inputs=input_layer,
filters=32,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu
)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) # 14*14*32
conv2 = tf.layers.conv2d(
inputs=pool1,
filters=64,
kernel_size=[5, 5],
padding='same',
activation=tf.nn.relu
)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
dropout = tf.layers.dropout(inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
logits = tf.layers.dense(inputs=dropout, units=10)
print '####', logits
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
if mode == tf.estimator.ModeKeys.TRAIN:
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(
loss=loss,
global_step=tf.train.get_global_step()
)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def main(argvs):
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images
print train_data.shape
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
mnist_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")
# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=50)
print 'labels: ', train_labels
# Train the model
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
y=train_labels,
batch_size=100,
num_epochs=None,
shuffle=True)
mnist_classifier.train(
input_fn=train_input_fn,
steps=20000,
hooks=[logging_hook])
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": eval_data},
y=eval_labels,
num_epochs=1,
shuffle=False)
if __name__ == '__main__':
import os,json
os.environ['TF_CONFIG']=json.dumps({
"cluster": {
"ps": [
"127.0.0.1:34567"
],
"chief": [
"127.0.0.1:34568"
],
"worker": [
"127.0.0.1:34569"
]
},
"task": {
"index": 0,
"type": "ps" # optional: chief, ps, worker
}
})
tf.app.run()
这是日志:
WARNING:tensorflow:From ps_cnn_mnist.py:94: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
(55000, 784)
INFO:tensorflow:TF_CONFIG environment variable: {u'cluster': {u'ps': [u'127.0.0.1:34567'], u'chief': [u'127.0.0.1:34568'], u'worker': [u'127.0.0.1:34569']}, u'task': {u'index': 0, u'type': u'ps'}}
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_task_type': u'ps', '_train_distribute': None, '_is_chief': False, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f8b25008990>, '_evaluation_master': '', '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 1, '_tf_random_seed': None, '_master': u'grpc://127.0.0.1:34567', '_num_worker_replicas': 2, '_task_id': 0, '_log_step_count_steps': 100, '_model_dir': '/tmp/mnist_convnet_model', '_global_id_in_cluster': 2, '_save_summary_steps': 100}
labels: [7 3 4 ... 5 6 8]
INFO:tensorflow:Calling model_fn.
#### Tensor("dense_1/BiasAdd:0", shape=(100, 10), dtype=float32, device=/job:ps/task:0)
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.