如何在使用tf.contrib.learn模块时创建tf.RunMetadata并将其添加到writer

时间:2017-10-17 09:03:06

标签: tensorflow

现在我使用 tf.contrib.learn.Experiment、Estimator 和 learn_runner 来帮助训练模型。运行 learn_runner 时,它会隐式创建 tf.train.MonitoredSession 并调用其 run() 函数,因此我无法向 run() 函数传入 options 和 run_metadata 参数。

那么,如何将 options 和 run_metadata 参数传给 run() 函数,并调用 summary_writer.add_run_metadata()?

我在网上搜索了很长时间,但没有找到答案。请帮忙,或者提供一些实现思路。

这是代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from tensorflow.contrib import slim, training, learn

# Emit TensorFlow log messages at DEBUG level and above for this script.
tf.logging.set_verbosity(tf.logging.DEBUG)


def variable_summaries(var):
    """Attach mean/stddev/max/min scalar summaries and a histogram to ``var``.

    Each summary op created here is also added to the graph collection
    ``'variable_summaries'`` so that ``model_fn`` can merge exactly this set.

    Args:
        var: tensor or variable to summarize. Its name without the ``:N``
            output suffix is used as the enclosing name scope.
    """
    scope_name = var.name.partition(':')[0]
    with tf.name_scope(scope_name):
        mean = tf.reduce_mean(var)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        # Tuple elements are evaluated left to right, so ops are created in
        # the same order as calling each summary constructor one by one.
        summaries = (
            tf.summary.scalar('mean', mean),
            tf.summary.scalar('stddev', stddev),
            tf.summary.scalar('max', tf.reduce_max(var)),
            tf.summary.scalar('min', tf.reduce_min(var)),
            tf.summary.histogram('histogram', var),
        )
        for summary in summaries:
            tf.add_to_collection('variable_summaries', summary)


def model_fn(features, labels, mode, params):
    """Build a two-layer softmax classifier for ``learn.Estimator``.

    Args:
        features: float feature tensor. The input_fns in this file produce a
            [batch, 2] tensor (assumed for other callers -- TODO confirm).
        labels: integer class ids, one-hot encoded to depth 2 for the loss.
            Not used in INFER mode.
        mode: one of ``learn.ModeKeys.TRAIN``, ``EVAL`` or ``INFER``.
        params: hyper-parameters forwarded by the Estimator; unused here.

    Returns:
        ``learn.ModelFnOps`` for ``mode``.

    Raises:
        ValueError: if ``mode`` is not TRAIN, EVAL or INFER.
    """
    fc1 = slim.fully_connected(features, 10, tf.nn.relu, scope='fc1')
    variable_summaries(fc1)
    fc2 = slim.fully_connected(fc1, 2, None, scope='fc2')
    variable_summaries(fc2)

    for var in tf.trainable_variables():
        variable_summaries(var)

    logits = fc2
    prob = tf.nn.softmax(logits)
    predictions = tf.argmax(logits, axis=1)

    # Merge only the summaries registered by variable_summaries().
    summary_op = tf.summary.merge_all('variable_summaries')
    scaffold = tf.train.Scaffold(summary_op=summary_op)

    if mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL):
        onehot_labels = slim.one_hot_encoding(labels, 2)
        loss = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=onehot_labels)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        train_op = optimizer.minimize(loss, slim.get_global_step())

        # AUC sweeps a decision threshold over a score, so it must be given
        # the positive-class probability; hard argmax class ids collapse the
        # ROC curve to a single operating point.
        positive_prob = prob[:, 1]

        eval_metric_ops = {
            'accuracy': tf.metrics.accuracy(labels, predictions),
            'auc': tf.metrics.auc(labels, positive_prob),
            'precision': tf.metrics.precision(labels, predictions),
            'recall': tf.metrics.recall(labels, predictions),
        }

        return learn.ModelFnOps(mode=mode,
                                predictions=predictions,
                                loss=loss,
                                train_op=train_op,
                                eval_metric_ops=eval_metric_ops,
                                scaffold=scaffold)

    if mode == learn.ModeKeys.INFER:
        infer_outputs = {'prob': prob, 'fc1': fc1, 'fc2': fc2}
        # The 'id_ts' collection is filled by the input_fn. Only INFER uses
        # it, so look it up here and omit the key when it was not registered
        # rather than failing with IndexError.
        id_tensors = tf.get_collection('id_ts')
        if id_tensors:
            infer_outputs['id'] = id_tensors[0]
        return learn.ModelFnOps(mode=mode, predictions=infer_outputs)

    raise ValueError('Unsupported mode: %r' % (mode,))


def train_input_fn():
    """Produce shuffled training batches read from ``data.csv``.

    Each CSV row is parsed as four float columns. Column 0 is an id (also
    registered in the ``'id_ts'`` collection), columns 1-2 are the features
    and column 3 is the label.

    Returns:
        (features, labels): a [batch, 2] feature tensor built from columns 1
        and 2, and column 3 cast to int32.
    """
    filename_queue = tf.train.string_input_producer(['data.csv'])
    line_reader = tf.TextLineReader()
    _, line = line_reader.read(filename_queue)
    columns = tf.decode_csv(line, [[0.], [0.], [0.], [0.]], field_delim=',')
    batch = tf.train.shuffle_batch(columns, 10, 1000, 10)

    row_ids = batch[0]
    tf.add_to_collection('id_ts', row_ids)

    # Turn each feature column into a [batch, 1] column, then join them side by side.
    feature_columns = [tf.reshape(batch[col], [-1, 1]) for col in (1, 2)]
    features = tf.concat(feature_columns, axis=1)
    labels = tf.cast(batch[3], tf.int32)
    return features, labels


def eval_input_fn():
    """Produce evaluation batches read from ``data.csv`` without shuffling rows.

    Each CSV row is parsed as four float columns. Column 0 is an id (also
    registered in the ``'id_ts'`` collection), columns 1-2 are the features
    and column 3 is the label.

    Returns:
        (features, labels): a [batch, 2] feature tensor built from columns 1
        and 2, and column 3 cast to int32.
    """
    fn = tf.train.string_input_producer(['data.csv'])
    reader = tf.TextLineReader()
    _, value = reader.read(fn)
    data_ts = tf.decode_csv(value, [[0.], [0.], [0.], [0.]], field_delim=',')
    # tf.train.batch's third positional parameter is num_threads, not
    # capacity (unlike shuffle_batch). Passing 1000 positionally started 1000
    # enqueue threads against a 32-element queue. Name the intended
    # arguments explicitly.
    batch_ts = tf.train.batch(data_ts, batch_size=10, capacity=1000)
    id_ts = batch_ts[0]
    tf.add_to_collection('id_ts', id_ts)
    features_ts = tf.concat([tf.reshape(batch_ts[1], [-1, 1]), tf.reshape(batch_ts[2], [-1, 1])], axis=1)
    labels_ts = tf.cast(batch_ts[3], tf.int32)
    return features_ts, labels_ts


def run_experiment(_):
    """Entry point for ``tf.app.run``: train and evaluate via ``learn_runner``.

    The single positional argument supplied by ``tf.app.run`` is ignored.
    Checkpoints go to ``model_dir`` every 100 steps, and at most 2 are kept.
    """
    gpu_options = tf.GPUOptions(allow_growth=True)
    session_config = tf.ConfigProto(gpu_options=gpu_options,
                                    log_device_placement=False)

    run_config = learn.RunConfig(
        model_dir='model_dir',
        save_checkpoints_steps=100,
        keep_checkpoint_max=2,
        session_config=session_config,
    )
    hparams = training.HParams(train_steps=1000)

    learn.learn_runner.run(
        experiment_fn=create_experiment_fn,
        schedule='train_and_evaluate',
        run_config=run_config,
        hparams=hparams,
    )


class _RunMetadataHook(tf.train.SessionRunHook):
    """Write ``tf.RunMetadata`` for one training step out of every N.

    ``Estimator.fit`` (driven by ``learn_runner``) calls
    ``MonitoredSession.run()`` internally, so ``options``/``run_metadata``
    cannot be passed to it directly. A ``SessionRunHook`` can request them
    instead. ``before_run`` returns ``SessionRunArgs`` carrying FULL_TRACE
    ``RunOptions``, and ``after_run`` receives the filled-in metadata as
    ``run_values.run_metadata``.

    NOTE(review): this relies on ``SessionRunArgs`` accepting ``options`` and
    on ``SessionRunValues.run_metadata`` (later TF 1.x releases). FULL_TRACE
    on GPU also needs CUPTI to be loadable. Confirm both against the
    installed TensorFlow.

    Args:
        output_dir: directory whose event file receives the metadata.
        every_n_steps: positive int; trace one step out of this many.
    """

    def __init__(self, output_dir, every_n_steps):
        if every_n_steps <= 0:
            raise ValueError(
                'every_n_steps must be positive, got %r' % (every_n_steps,))
        self._output_dir = output_dir
        self._every_n_steps = every_n_steps
        self._writer = None
        self._global_step_tensor = None
        self._last_traced_step = None
        self._next_step = None
        self._trace_this_run = False

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError(
                'A global step tensor is required to use _RunMetadataHook.')
        # Reuse the cached writer for this directory instead of opening a
        # second FileWriter on the same event directory.
        self._writer = tf.summary.FileWriterCache.get(self._output_dir)

    def before_run(self, run_context):
        # Trace the first step, then once every _every_n_steps steps.
        self._trace_this_run = (
            self._last_traced_step is None
            or self._next_step - self._last_traced_step >= self._every_n_steps)
        options = None
        if self._trace_this_run:
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        return tf.train.SessionRunArgs(self._global_step_tensor, options=options)

    def after_run(self, run_context, run_values):
        # The global step is fetched in the same run as the train_op and may
        # be read before its increment, so assume it is one behind.
        global_step = run_values.results + 1
        if self._trace_this_run:
            # Read the exact value so each metadata tag is unique.
            global_step = run_context.session.run(self._global_step_tensor)
            self._writer.add_run_metadata(
                run_values.run_metadata, 'step_%d' % global_step, global_step)
            self._last_traced_step = global_step
        self._next_step = global_step + 1

    def end(self, session):
        if self._writer is not None:
            self._writer.flush()


def create_experiment_fn(run_config, hparams):
    """Build the ``learn.Experiment`` that ``learn_runner.run`` executes.

    Args:
        run_config: ``learn.RunConfig`` supplied by learn_runner.
        hparams: ``HParams`` providing ``train_steps``. An optional
            ``trace_every_n_steps`` sets how often a training step's
            ``tf.RunMetadata`` is written to the model directory's event
            file (default 100; ``0`` or ``None`` disables tracing).

    Returns:
        ``learn.Experiment`` wired to ``train_input_fn`` / ``eval_input_fn``.
    """
    estimator = get_estimator_fn(config=run_config, params=hparams)

    train_monitors = []
    trace_every_n_steps = getattr(hparams, 'trace_every_n_steps', 100)
    if trace_every_n_steps:
        # MonitoredSession.run() is invoked inside Estimator.fit, so request
        # options/run_metadata through a hook rather than run() arguments.
        train_monitors.append(
            _RunMetadataHook(estimator.model_dir, trace_every_n_steps))

    return learn.Experiment(estimator=estimator,
                            train_input_fn=train_input_fn,
                            eval_input_fn=eval_input_fn,
                            train_steps=hparams.train_steps,
                            train_monitors=train_monitors)


def get_estimator_fn(config, params):
    """Create a contrib ``learn.Estimator`` around ``model_fn``.

    Args:
        config: run configuration; its ``model_dir`` is also passed as the
            estimator's model directory.
        params: hyper-parameters forwarded to ``model_fn``.

    Returns:
        A ``learn.Estimator`` instance.
    """
    estimator_kwargs = dict(
        model_fn=model_fn,
        model_dir=config.model_dir,
        config=config,
        params=params,
    )
    return learn.Estimator(**estimator_kwargs)


if __name__ == '__main__':
    # Hand control to tf.app.run, which invokes run_experiment as main.
    tf.app.run(run_experiment)

0 个答案:

没有答案