分布式估算器永远挂起

时间:2018-06-10 15:30:32

标签: tensorflow-estimator

我用估算器运行了mnist代码,在单个进程中运行时它没问题,但在分布式模式的3个进程中运行时挂起。代码(PS)如下所示。请注意,worker应该将TF_CONFIG中的task.type更改为' worker',并且Chief应该更改为' chief'。

import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib.cm as cm

tf.logging.set_verbosity(tf.logging.INFO)

def load2():
train = pd.read_csv('/tmp/train.csv')

test = pd.read_csv('/tmp/test.csv')
labels = train['label']
images = train.iloc[:, 1:]
image_size = 28

train_ds, valid_ds, train_labels, valid_labels = train_test_split(images, labels, test_size=0.33, random_state=42)

print len(train_ds)
print len(valid_ds)
print len(train_labels)
print len(valid_labels)

data = images

train_ds, train_labels = reformat(train_ds, train_labels)
valid_ds, valid_labels = reformat(valid_ds, valid_labels)
test_ds, test_labels = reformat(test, train_labels[:len(test)])

print('Training set', train_ds.shape, train_labels.shape)
print('Validation set', valid_ds.shape, valid_labels.shape)
print('Test set', test_ds.shape, test_labels.shape)


def reformat(dataset, labels):
    num_channels = 1
    num_labels = 10
    dataset = dataset.values.reshape(
        (-1, image_size, image_size, num_channels)).astype(np.float32)
    labels = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
    return dataset, labels


def cnn_model_fn(features, labels, mode):
    input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(
        inputs=input_layer,
        filters=32,
        kernel_size=[5, 5],
        padding="same",
        activation=tf.nn.relu
    )

    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)  # 14*14*32

    conv2 = tf.layers.conv2d(
        inputs=pool1,
        filters=64,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu
    )

    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])

    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
    dropout = tf.layers.dropout(inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

    logits = tf.layers.dense(inputs=dropout, units=10)
    print '####', logits

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.TRAIN:
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step()
        )
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def main(argvs):
    mnist = tf.contrib.learn.datasets.load_dataset("mnist")
    train_data = mnist.train.images
    print train_data.shape
    train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
    eval_data = mnist.test.images
    eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
    mnist_classifier = tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")
    # Set up logging for predictions
    # Log the values in the "Softmax" tensor with label "probabilities"
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log, every_n_iter=50)
    print 'labels: ', train_labels
    # Train the model
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": train_data},
        y=train_labels,
        batch_size=100,
        num_epochs=None,
        shuffle=True)
    mnist_classifier.train(
        input_fn=train_input_fn,
        steps=20000,
        hooks=[logging_hook])
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": eval_data},
        y=eval_labels,
        num_epochs=1,
        shuffle=False)


if __name__ == '__main__':
  import os,json
  os.environ['TF_CONFIG']=json.dumps({
      "cluster": {
        "ps": [
          "127.0.0.1:34567"
        ],
        "chief": [
          "127.0.0.1:34568"
        ],
        "worker": [
          "127.0.0.1:34569"
        ]
      },
      "task": {
        "index": 0,
        "type": "ps" # optional: chief, ps, worker
      }
  })
  tf.app.run()

这是日志:

WARNING:tensorflow:From ps_cnn_mnist.py:94: load_dataset (from tensorflow.contrib.learn.python.learn.datasets) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/__init__.py:80: load_mnist (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:300: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /home/yiguang.wyg/tools/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
(55000, 784)
INFO:tensorflow:TF_CONFIG environment variable: {u'cluster': {u'ps': [u'127.0.0.1:34567'], u'chief': [u'127.0.0.1:34568'], u'worker': [u'127.0.0.1:34569']}, u'task': {u'index': 0, u'type': u'ps'}}
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_task_type': u'ps', '_train_distribute': None, '_is_chief': False, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f8b25008990>, '_evaluation_master': '', '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 1, '_tf_random_seed': None, '_master': u'grpc://127.0.0.1:34567', '_num_worker_replicas': 2, '_task_id': 0, '_log_step_count_steps': 100, '_model_dir': '/tmp/mnist_convnet_model', '_global_id_in_cluster': 2, '_save_summary_steps': 100}
labels:  [7 3 4 ... 5 6 8]
INFO:tensorflow:Calling model_fn.
#### Tensor("dense_1/BiasAdd:0", shape=(100, 10), dtype=float32, device=/job:ps/task:0)
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.

0 个答案:

没有答案