In my model (built with the TensorFlow Estimator API), I would like the data feeding to be more dynamic, e.g. feeding different data during training (at different training steps, the model is fed different data).
An example similar to the code below: get_input_fn provides the input_fn, and the features are processed by the _parse function. The actual processing happens in _py_process_line_pair inside _parse. However, I am not sure how to pass global_step (or a related parameter) into _py_process_line_pair.
def _parse(self, features):
  def _py_process_line_pair(src_wds, trg_wds, cur_training_steps):
    ...  # some processing that depends on cur_training_steps
    return np.array(src_ids, np.int32), np.array(trg_ids, np.int32)

  src_wds, trg_wds = features['src_wds'], features['trg_wds']
  src_ids, trg_ids = tf.py_func(
      _py_process_line_pair,
      [src_wds, trg_wds],
      [tf.int32, tf.int32])
  src_ids.set_shape([self.flags.max_src_len])
  trg_ids.set_shape([self.flags.max_trg_len])
  output = {
      'src_ids': src_ids,
      'trg_ids': trg_ids,
  }
  return output
def get_input_fn(self, is_training, input_files, num_cpu_threads):
  def input_fn(params):
    batch_size = params['batch_size']
    if is_training:
      d = tf.data.Dataset.from_tensor_slices(
          tf.constant(tf.gfile.Glob(input_files)))
      d = d.repeat()
      d = d.shuffle(buffer_size=len(input_files))
      cycle_length = min(num_cpu_threads, len(input_files))
      d = d.apply(
          tf.data.experimental.parallel_interleave(
              tf.data.TFRecordDataset,
              sloppy=is_training,
              cycle_length=cycle_length))
      d = d.shuffle(buffer_size=100)
    else:
      d = tf.data.TFRecordDataset(input_files)
    d = d.apply(
        tf.data.experimental.map_and_batch(
            lambda record: self._parse(
                tf.parse_single_example(record, self.feature_set)),
            batch_size=batch_size,
            num_parallel_batches=num_cpu_threads,
            drop_remainder=is_training))
    return d
  return input_fn
Answer 0 (score: 0)
This is quite simple: inside the _parse function, you just need to fetch the global_step tensor from the graph with tf.train.get_or_create_global_step().
Here is a working example:
import tensorflow as tf
import numpy as np

# Synth dataset with 10 values
x = np.arange(10)

# This function replaces 'x' by the current step
def step_dependant_preprocessing(x):
    global_step = tf.train.get_or_create_global_step()
    return global_step

# Maps step_dependant_preprocessing over the dataset
def input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((x))
    dataset = dataset.map(step_dependant_preprocessing)
    return dataset

def model_fn(features, labels, mode, params=None):
    # Get the global step
    global_step = tf.train.get_or_create_global_step()
    # Since this example doesn't use an optimizer, we need to increment
    # the global step manually.
    increment_global_step = tf.assign_add(global_step, 1)
    # Logging hook to verify that the global step inside the input fn has
    # the same value as the one here.
    logging_hook = tf.train.LoggingTensorHook(
        {"true_global_step": global_step,
         "input_fn_global_step": features},
        every_n_iter=1)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf.constant(0.0),  # Needed to use estimator.train()
        training_hooks=[logging_hook],
        train_op=increment_global_step)

estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn)
...
# Output
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmppuwe9hxh/model.ckpt.
INFO:tensorflow:loss = 0.0, step = 1
INFO:tensorflow:input_fn_global_step = 1, true_global_step = 1
INFO:tensorflow:input_fn_global_step = 2, true_global_step = 2 (0.007 sec)
INFO:tensorflow:input_fn_global_step = 3, true_global_step = 3 (0.002 sec)
INFO:tensorflow:input_fn_global_step = 4, true_global_step = 4 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 5, true_global_step = 5 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 6, true_global_step = 6 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 7, true_global_step = 7 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 8, true_global_step = 8 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 9, true_global_step = 9 (0.001 sec)
INFO:tensorflow:input_fn_global_step = 10, true_global_step = 10 (0.001 sec)
INFO:tensorflow:Saving checkpoints for 11 into /tmp/tmppuwe9hxh/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0.
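To connect this back to the question, one way to wire the same idea into the question's _parse is to pass the global step tensor as an extra input to tf.py_func, so that _py_process_line_pair receives its current value as a plain numpy integer. The following is only a minimal sketch under the question's assumptions (self.flags, the 'src_wds'/'trg_wds' feature names); the body of _py_process_line_pair is a placeholder, not the real processing:

import numpy as np
import tensorflow as tf

def _parse(self, features):
    def _py_process_line_pair(src_wds, trg_wds, cur_training_steps):
        # cur_training_steps arrives here as a numpy int64 scalar.
        # ... step-dependent processing would go here (placeholder) ...
        src_ids = np.zeros(self.flags.max_src_len, np.int32)  # placeholder
        trg_ids = np.zeros(self.flags.max_trg_len, np.int32)  # placeholder
        return np.array(src_ids, np.int32), np.array(trg_ids, np.int32)

    src_wds, trg_wds = features['src_wds'], features['trg_wds']
    # Fetch the global step tensor from the graph, as in the answer above.
    global_step = tf.train.get_or_create_global_step()
    src_ids, trg_ids = tf.py_func(
        _py_process_line_pair,
        [src_wds, trg_wds, global_step],  # extra input: current step
        [tf.int32, tf.int32])
    src_ids.set_shape([self.flags.max_src_len])
    trg_ids.set_shape([self.flags.max_trg_len])
    return {'src_ids': src_ids, 'trg_ids': trg_ids}

One caveat: because the input pipeline preprocesses elements ahead of time (parallel map_and_batch, prefetching), the step value observed inside the py_func may lag slightly behind the step at which the batch is actually consumed by the model.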