我正在使用带有TFRecords的estimator API来调整我在TF 1.2中使用的代码。但是input_fn
总是只返回第一批,因此永远不会完成(或进展)。
def gen_input(filename):
def decode(line):
features = {
'x': tf.FixedLenFeature((3,), tf.float32),
'y': tf.FixedLenFeature((), tf.int64)
}
parsed = tf.parse_single_example(line, features)
return parsed['x'], parsed['y']
def input_fn():
dataset = (tf.data.TFRecordDataset([filename])).map(decode)
dataset = dataset.repeat(1)
dataset = dataset.batch(2)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
return input_fn
estimator = tf.estimator.Estimator(
model_dir=model_dir,
model_fn=model_fn,
params={})
train_input_fn = gen_input('train.tfrecord')
eval_input_fn = gen_input('eval.tfrecord')
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
我的TFRecord中只有5个项目,因此dataset.batch(2)
和dataset.repeat(1)
一起应该使模型在3个步骤后完成。此外,我正在将每个步骤中提供的features
记录到我的model_fn
进入Tensorboard。记录的值始终相同。