我按照TensorFlow seq2seq模块的教程进行了操作,具体是其中的神经机器翻译(NMT)部分:https://google.github.io/seq2seq/nmt/
程序可以正常运行,只是它只使用了两块可用GPU中的一块(集群分配给我的是GPU 5和GPU 6,而它只使用了GPU 6)。这是nvidia-smi命令输出的快照。
上述教程的"Training"一节中有如下说明:
借助tf.learn,开箱即支持分布式训练(Distributed Training)。可以使用TF_CONFIG环境变量指定集群配置,该变量由RunConfig解析。有关详细信息,请参阅“分布式TensorFlow指南”。
我查阅了很多资料,但没有一处清楚地说明我需要在代码中修改什么才能用上这一功能。请指教。
我的问题主要是:
(1)TF_CONFIG的值应该设为什么?我该如何确定这些值? (2)我应该如何修改train.py?
train.py:
#! /usr/bin/env python
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main script to run training and evaluation of models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import tempfile
import yaml
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow import gfile
from seq2seq import models
from seq2seq.contrib.experiment import Experiment as PatchedExperiment
from seq2seq.configurable import _maybe_load_yaml, _create_from_dict
from seq2seq.configurable import _deep_merge_dict
from seq2seq.data import input_pipeline
from seq2seq.metrics import metric_specs
from seq2seq.training import hooks
from seq2seq.training import utils as training_utils
# Command-line flags for the training script. main() also overlays values
# loaded from the YAML files named in --config_paths onto these flags.

# Configuration files
tf.flags.DEFINE_string("config_paths", "",
"""Path to a YAML configuration files defining FLAG
values. Multiple files can be separated by commas.
Files are merged recursively. Setting a key in these
files is equivalent to setting the FLAG value with
the same name.""")
# Model, hooks, metrics and input pipelines. These flags hold YAML strings;
# main() converts them with _maybe_load_yaml before they are used.
tf.flags.DEFINE_string("hooks", "[]",
"""YAML configuration string for the
training hooks to use.""")
tf.flags.DEFINE_string("metrics", "[]",
"""YAML configuration string for the
training metrics to use.""")
tf.flags.DEFINE_string("model", "",
"""Name of the model class.
Can be either a fully-qualified name, or the name
of a class defined in `seq2seq.models`.""")
tf.flags.DEFINE_string("model_params", "{}",
"""YAML configuration string for the model
parameters.""")
tf.flags.DEFINE_string("input_pipeline_train", "{}",
"""YAML configuration string for the training
data input pipeline.""")
tf.flags.DEFINE_string("input_pipeline_dev", "{}",
"""YAML configuration string for the development
data input pipeline.""")
# Batching: create_experiment() splits --buckets on commas and converts the
# pieces to ints; --batch_size is shared by the train and eval input functions.
tf.flags.DEFINE_string("buckets", None,
"""Buckets input sequences according to these length.
A comma-separated list of sequence length buckets, e.g.
"10,20,30" would result in 4 buckets:
<10, 10-20, 20-30, >30. None disabled bucketing. """)
tf.flags.DEFINE_integer("batch_size", 16,
"""Batch size used for training and evaluation.""")
# When --output_dir is unset, main() falls back to tempfile.mkdtemp().
tf.flags.DEFINE_string("output_dir", None,
"""The directory to write model checkpoints and summaries
to. If None, a local temporary directory is created.""")
# Training parameters: --schedule is handed to learn_runner.run(), while
# --train_steps and --eval_every_n_steps are passed to the Experiment.
tf.flags.DEFINE_string("schedule", "continuous_train_and_eval",
"""Estimator function to call, defaults to
continuous_train_and_eval for local run""")
tf.flags.DEFINE_integer("train_steps", None,
"""Maximum number of training steps to run.
If None, train forever.""")
tf.flags.DEFINE_integer("eval_every_n_steps", 1000,
"Run evaluation on validation data every N steps.")
# RunConfig flags: forwarded to run_config.RunConfig in create_experiment().
# If neither save_checkpoints_* flag is given, main() sets
# save_checkpoints_secs to 600.
tf.flags.DEFINE_integer("tf_random_seed", None,
"""Random seed for TensorFlow initializers. Setting
this value allows consistency between reruns.""")
tf.flags.DEFINE_integer("save_checkpoints_secs", None,
"""Save checkpoints every this many seconds.
Can not be specified with save_checkpoints_steps.""")
tf.flags.DEFINE_integer("save_checkpoints_steps", None,
"""Save checkpoints every this many steps.
Can not be specified with save_checkpoints_secs.""")
tf.flags.DEFINE_integer("keep_checkpoint_max", 5,
"""Maximum number of recent checkpoint files to keep.
As new files are created, older files are deleted.
If None or 0, all checkpoint files are kept.""")
tf.flags.DEFINE_integer("keep_checkpoint_every_n_hours", 4,
"""In addition to keeping the most recent checkpoint
files, keep one checkpoint file for every N hours of
training.""")
tf.flags.DEFINE_float("gpu_memory_fraction", 1.0,
"""Fraction of GPU memory used by the process on
each GPU uniformly on the same machine.""")
# The next two flags are copied onto config.tf_config in create_experiment().
tf.flags.DEFINE_boolean("gpu_allow_growth", False,
"""Allow GPU memory allocation to grow
dynamically.""")
tf.flags.DEFINE_boolean("log_device_placement", False,
"""Log the op placement to devices""")
# Module-level handle through which the rest of the script reads the flags.
FLAGS = tf.flags.FLAGS
def create_experiment(output_dir):
    """Assembles the Experiment used for training and evaluation.

    All settings other than `output_dir` are read from the module-level
    FLAGS object.

    Args:
      output_dir: Output directory for model checkpoints and summaries.

    Returns:
      A `PatchedExperiment` holding the estimator, the train and dev input
      functions, the training hooks and the evaluation metrics.
    """
    run_cfg = run_config.RunConfig(
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_secs=FLAGS.save_checkpoints_secs,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
        gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    run_cfg.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
    run_cfg.tf_config.log_device_placement = FLAGS.log_device_placement

    train_options = training_utils.TrainOptions(
        model_class=FLAGS.model,
        model_params=FLAGS.model_params)

    # Only the chief worker writes the training options to disk.
    if run_cfg.is_chief:
        gfile.MakeDirs(output_dir)
        train_options.dump(output_dir)

    # "10,20,30" -> [10, 20, 30]; an unset/empty flag leaves bucketing off.
    bucket_boundaries = (
        [int(boundary) for boundary in FLAGS.buckets.split(",")]
        if FLAGS.buckets else None)

    # Training data: pipeline definition, then the input function around it.
    train_pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=FLAGS.input_pipeline_train,
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    train_input_fn = training_utils.create_input_fn(
        pipeline=train_pipeline,
        batch_size=FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries,
        scope="train_input_fn")

    # Development data: a single unshuffled epoch, and the last batch may be
    # smaller than batch_size.
    dev_pipeline = input_pipeline.make_input_pipeline_from_def(
        def_dict=FLAGS.input_pipeline_dev,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        shuffle=False, num_epochs=1)
    eval_input_fn = training_utils.create_input_fn(
        pipeline=dev_pipeline,
        batch_size=FLAGS.batch_size,
        allow_smaller_final_batch=True,
        scope="dev_input_fn")

    def model_fn(features, labels, params, mode):
        """Instantiates the configured model class and applies it."""
        model_spec = {
            "class": train_options.model_class,
            "params": train_options.model_params,
        }
        model = _create_from_dict(model_spec, models, mode=mode)
        return model(features, labels, params)

    estimator = tf.contrib.learn.Estimator(
        model_fn=model_fn,
        model_dir=output_dir,
        config=run_cfg,
        params=FLAGS.model_params)

    # One hook per entry of the (already YAML-parsed) hooks flag.
    train_hooks = [
        _create_from_dict(
            hook_def, hooks,
            model_dir=estimator.model_dir,
            run_config=run_cfg)
        for hook_def in FLAGS.hooks
    ]

    # Metrics are keyed by their own `name` attribute.
    eval_metrics = {
        metric.name: metric
        for metric in (
            _create_from_dict(metric_def, metric_specs)
            for metric_def in FLAGS.metrics)
    }

    return PatchedExperiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        min_eval_frequency=FLAGS.eval_every_n_steps,
        train_steps=FLAGS.train_steps,
        eval_steps=None,
        eval_metrics=eval_metrics,
        train_monitors=train_hooks)
def main(_argv):
    """The entrypoint for the script.

    Parses the YAML-valued flags, overlays values from the files listed in
    `config_paths`, fills in defaults for checkpointing and the output
    directory, checks that both input pipelines are set, and then passes
    `create_experiment` to `learn_runner.run`.

    Args:
      _argv: Unused.

    Raises:
      ValueError: If `input_pipeline_train` or `input_pipeline_dev` is empty
        after flags and config files have been merged.
    """
    # These flags arrive as YAML strings; replace each with its parsed value.
    yaml_flag_names = (
        "hooks",
        "metrics",
        "model_params",
        "input_pipeline_train",
        "input_pipeline_dev",
    )
    for flag_name in yaml_flag_names:
        setattr(FLAGS, flag_name, _maybe_load_yaml(getattr(FLAGS, flag_name)))

    # Merge every listed config file, in order, into one dictionary.
    final_config = {}
    listed_paths = FLAGS.config_paths.split(",") if FLAGS.config_paths else []
    for listed_path in listed_paths:
        trimmed_path = listed_path.strip()
        if not trimmed_path:
            continue
        config_path = os.path.abspath(trimmed_path)
        tf.logging.info("Loading config from %s", config_path)
        with gfile.GFile(config_path.strip()) as config_file:
            # NOTE(review): yaml.load without an explicit Loader can build
            # arbitrary Python objects, and newer PyYAML releases appear to
            # require a Loader argument -- confirm against the pinned PyYAML
            # version and consider yaml.safe_load if configs may be untrusted.
            config_flags = yaml.load(config_file)
            final_config = _deep_merge_dict(final_config, config_flags)

    tf.logging.info("Final Config:\n%s", yaml.dump(final_config))

    # Overlay the merged config onto FLAGS. Dict-valued flags are deep-merged
    # with the current flag value; any other known flag is overwritten.
    for flag_key, flag_value in final_config.items():
        if not hasattr(FLAGS, flag_key):
            tf.logging.warning("Ignoring config flag: %s", flag_key)
            continue
        if isinstance(getattr(FLAGS, flag_key), dict):
            flag_value = _deep_merge_dict(flag_value, getattr(FLAGS, flag_key))
        setattr(FLAGS, flag_key, flag_value)

    # With no checkpoint cadence configured, checkpoint every 600 seconds.
    if (FLAGS.save_checkpoints_secs is None
            and FLAGS.save_checkpoints_steps is None):
        FLAGS.save_checkpoints_secs = 600
        tf.logging.info("Setting save_checkpoints_secs to %d",
                        FLAGS.save_checkpoints_secs)

    if not FLAGS.output_dir:
        FLAGS.output_dir = tempfile.mkdtemp()

    if not FLAGS.input_pipeline_train:
        raise ValueError("You must specify input_pipeline_train")

    if not FLAGS.input_pipeline_dev:
        raise ValueError("You must specify input_pipeline_dev")

    learn_runner.run(
        experiment_fn=create_experiment,
        output_dir=FLAGS.output_dir,
        schedule=FLAGS.schedule)
if __name__ == "__main__":
    # Emit INFO-level log messages (e.g. the config-loading lines in main()).
    tf.logging.set_verbosity(tf.logging.INFO)
    # Hand control to TensorFlow's app runner, which is expected to parse
    # the command-line flags and then call main().
    tf.app.run()