现在我使用 `tf.contrib.learn.Experiment`、`Estimator` 和 `learn_runner` 来帮助训练模型。运行 `learn_runner` 时,它会隐式创建 `tf.train.MonitoredSession` 并调用其 `run()` 函数,因此我无法把 `options` 和 `run_metadata` 参数传给 `run()` 函数。
那么,如何把 `options` 和 `run_metadata` 参数传给 `run()` 函数,并调用 `summary_writer.add_run_metadata()`?
我在网上搜索了很长时间,但没有找到解决办法。请帮忙,或者提供一些实现思路。
这是代码:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from tensorflow.contrib import slim, training, learn
# Lower TensorFlow's Python logging (tf.logging) threshold to DEBUG so that
# debug-level messages and above are printed while the script runs.
tf.logging.set_verbosity(tf.logging.DEBUG)
def variable_summaries(var):
    """Attach mean/stddev/max/min scalar summaries and a histogram for `var`.

    All summaries are created under a name scope taken from the tensor's name
    with its ":<output index>" suffix removed. Each summary op is added to the
    'variable_summaries' graph collection so it can later be merged with
    `tf.summary.merge_all('variable_summaries')`.

    Args:
      var: Tensor or Variable to summarize.
    """
    scope_name = var.name.partition(':')[0]
    with tf.name_scope(scope_name):
        mean = tf.reduce_mean(var)
        with tf.name_scope('stddev'):
            squared_dev = tf.square(var - mean)
            stddev = tf.sqrt(tf.reduce_mean(squared_dev))
        new_summaries = [
            tf.summary.scalar('mean', mean),
            tf.summary.scalar('stddev', stddev),
            tf.summary.scalar('max', tf.reduce_max(var)),
            tf.summary.scalar('min', tf.reduce_min(var)),
            tf.summary.histogram('histogram', var),
        ]
        for summary in new_summaries:
            tf.add_to_collection('variable_summaries', summary)
def model_fn(features, labels, mode, params):
    """Two-layer fully connected binary classifier for `learn.Estimator`.

    Args:
      features: float Tensor of input features (presumably [batch, 2] as built
        by the input_fns in this file -- confirm for other callers).
      labels: integer Tensor of class ids in {0, 1}; not used in INFER mode.
      mode: one of `learn.ModeKeys.TRAIN`, `learn.ModeKeys.EVAL` or
        `learn.ModeKeys.INFER`.
      params: hyper-parameters forwarded by the Estimator (unused here).

    Returns:
      A `learn.ModelFnOps` for the requested mode, or None for any other mode.
    """
    fc1 = slim.fully_connected(features, 10, tf.nn.relu, scope='fc1')
    variable_summaries(fc1)
    fc2 = slim.fully_connected(fc1, 2, None, scope='fc2')
    variable_summaries(fc2)
    for var in tf.trainable_variables():
        variable_summaries(var)

    logits = fc2
    prob = tf.nn.softmax(logits)
    predictions = tf.argmax(logits, axis=1)

    # Only the summaries registered by variable_summaries() feed the
    # scaffold's summary op.
    summary_op = tf.summary.merge_all('variable_summaries')
    scaffold = tf.train.Scaffold(summary_op=summary_op)

    if mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL):
        onehot_labels = slim.one_hot_encoding(labels, 2)
        loss = tf.losses.softmax_cross_entropy(logits=logits,
                                               onehot_labels=onehot_labels)
        train_op = None
        eval_metric_ops = None
        if mode == learn.ModeKeys.TRAIN:
            # The optimizer is only needed when training.
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
            train_op = optimizer.minimize(loss, slim.get_global_step())
        else:
            eval_metric_ops = {
                'accuracy': tf.metrics.accuracy(labels, predictions),
                # AUC sweeps thresholds over a score. The hard argmax class
                # would collapse the ROC curve to a single point, so feed
                # P(class 1) instead.
                'auc': tf.metrics.auc(labels, prob[:, 1]),
                'precision': tf.metrics.precision(labels, predictions),
                'recall': tf.metrics.recall(labels, predictions),
            }
        return learn.ModelFnOps(mode=mode,
                                predictions=predictions,
                                loss=loss,
                                train_op=train_op,
                                eval_metric_ops=eval_metric_ops,
                                scaffold=scaffold)
    elif mode == learn.ModeKeys.INFER:
        # The id batch is put into this collection by the input_fn; it is
        # only needed for the prediction output.
        id_ts = tf.get_collection('id_ts')[0]
        return learn.ModelFnOps(mode=mode, predictions={'prob': prob,
                                                        'fc1': fc1,
                                                        'fc2': fc2,
                                                        'id': id_ts})
def train_input_fn():
    """Build a shuffled training batch from 'data.csv'.

    Every CSV line is decoded into four float columns, used as: id, two feature
    values, label. Batches of 10 rows come from a shuffle queue with capacity
    1000 and min_after_dequeue 10. The id batch is added to the 'id_ts' graph
    collection so the model can return it at inference time.

    Returns:
      (features, labels): a [10, 2] float32 feature tensor and an int32 label
      tensor of length 10.
    """
    filename_queue = tf.train.string_input_producer(['data.csv'])
    line_reader = tf.TextLineReader()
    _, line = line_reader.read(filename_queue)
    columns = tf.decode_csv(line, [[0.], [0.], [0.], [0.]], field_delim=',')
    id_batch, feat_a, feat_b, label_batch = tf.train.shuffle_batch(
        columns, batch_size=10, capacity=1000, min_after_dequeue=10)
    tf.add_to_collection('id_ts', id_batch)
    # Turn the two [10] feature columns into one [10, 2] matrix.
    column_a = tf.reshape(feat_a, [-1, 1])
    column_b = tf.reshape(feat_b, [-1, 1])
    features = tf.concat([column_a, column_b], axis=1)
    labels = tf.cast(label_batch, tf.int32)
    return features, labels
def eval_input_fn():
    """Build an (unshuffled) evaluation batch from 'data.csv'.

    Every CSV line is decoded into four float columns, used as: id, two feature
    values, label. Rows are grouped into batches of 10 through a FIFO batching
    queue of capacity 1000. The id batch is added to the 'id_ts' graph
    collection so the model can return it at inference time.

    Returns:
      (features, labels): a [10, 2] float32 feature tensor and an int32 label
      tensor of length 10.
    """
    filename_queue = tf.train.string_input_producer(['data.csv'])
    reader = tf.TextLineReader()
    _, value = reader.read(filename_queue)
    data_ts = tf.decode_csv(value, [[0.], [0.], [0.], [0.]], field_delim=',')
    # Pass capacity by keyword: the third positional parameter of
    # tf.train.batch is num_threads, so a positional 1000 would start 1000
    # enqueue threads instead of sizing the queue.
    batch_ts = tf.train.batch(data_ts, batch_size=10, capacity=1000)
    id_ts = batch_ts[0]
    tf.add_to_collection('id_ts', id_ts)
    features_ts = tf.concat([tf.reshape(batch_ts[1], [-1, 1]),
                             tf.reshape(batch_ts[2], [-1, 1])], axis=1)
    labels_ts = tf.cast(batch_ts[3], tf.int32)
    return features_ts, labels_ts
def run_experiment(_):
    """Entry point for `tf.app.run`: train and evaluate via `learn_runner`.

    The positional argument (the argv list passed in by tf.app.run) is ignored.
    Checkpoints go to 'model_dir' every 100 steps; at most 2 are kept.
    """
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_options,
                                 log_device_placement=False)
    config = learn.RunConfig(
        model_dir='model_dir',
        session_config=sess_config,
        save_checkpoints_steps=100,
        keep_checkpoint_max=2,
    )
    hyper_params = training.HParams(train_steps=1000)
    learn.learn_runner.run(
        experiment_fn=create_experiment_fn,
        run_config=config,
        schedule='train_and_evaluate',
        hparams=hyper_params,
    )
class _RunMetadataHook(tf.train.SessionRunHook):
    """Record full-trace RunMetadata every `every_n_steps` training runs.

    learn_runner/Experiment create and drive the MonitoredSession themselves,
    so `options` and `run_metadata` cannot be passed to `run()` directly. A
    SessionRunHook can request them instead:
    - `before_run` returns `SessionRunArgs` carrying `RunOptions(FULL_TRACE)`.
    - `after_run` receives the filled-in `run_values.run_metadata`.

    The metadata is written with `add_run_metadata` to the cached FileWriter
    for `output_dir`. This is the same writer the training summaries use when
    `output_dir` is the model dir.

    NOTE(review): requires a TF 1.x release whose `SessionRunArgs` accepts
    `options` (tf.train.ProfilerHook relies on the same mechanism).
    """

    def __init__(self, output_dir, every_n_steps=100):
        self._output_dir = output_dir
        self._every_n_steps = max(1, int(every_n_steps))
        self._run_count = 0
        self._trace_this_run = False
        self._global_step = None
        self._writer = None

    def begin(self):
        self._global_step = slim.get_global_step()
        if self._global_step is None:
            raise RuntimeError(
                'A global step must exist to use _RunMetadataHook.')
        self._writer = tf.summary.FileWriterCache.get(self._output_dir)

    def before_run(self, run_context):
        self._trace_this_run = self._run_count % self._every_n_steps == 0
        options = None
        if self._trace_this_run:
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        # Fetch the global step so the metadata can be tagged with it.
        return tf.train.SessionRunArgs(self._global_step, options=options)

    def after_run(self, run_context, run_values):
        if self._trace_this_run:
            step = run_values.results
            self._writer.add_run_metadata(run_values.run_metadata,
                                          'step_%d' % step,
                                          global_step=step)
        self._run_count += 1

    def end(self, session):
        if self._writer is not None:
            self._writer.flush()


def create_experiment_fn(run_config, hparams):
    """Build the `learn.Experiment` executed by `learn_runner.run`.

    A `_RunMetadataHook` is passed as a training monitor. It records
    full-trace RunMetadata every `run_config.save_summary_steps` training runs
    (100 if unset) into `run_config.model_dir`, where TensorBoard shows it in
    the graph tab.

    Args:
      run_config: `learn.RunConfig` supplied by learn_runner.
      hparams: `HParams` providing `train_steps`.

    Returns:
      A `learn.Experiment`.
    """
    estimator = get_estimator_fn(config=run_config, params=hparams)
    trace_hook = _RunMetadataHook(
        output_dir=run_config.model_dir,
        every_n_steps=run_config.save_summary_steps or 100)
    return learn.Experiment(estimator=estimator,
                            train_input_fn=train_input_fn,
                            eval_input_fn=eval_input_fn,
                            train_steps=hparams.train_steps,
                            train_monitors=[trace_hook])
def get_estimator_fn(config, params):
    """Create the `learn.Estimator` that wraps `model_fn`.

    Args:
      config: `learn.RunConfig`. Its `model_dir` is also passed explicitly as
        the Estimator's model directory.
      params: hyper-parameters forwarded unchanged to `model_fn`.

    Returns:
      A `learn.Estimator`.
    """
    estimator_kwargs = dict(
        model_fn=model_fn,
        model_dir=config.model_dir,
        config=config,
        params=params,
    )
    return learn.Estimator(**estimator_kwargs)
if __name__ == '__main__':
    # tf.app.run invokes run_experiment with the command-line argv.
    tf.app.run(run_experiment)