Tensorflow 1.4.1估计器未从input_fn获取下一个值

时间:2018-01-31 20:36:32

标签: tensorflow machine-learning tensorflow-estimator tfrecord

我正在使用带有TFRecords的estimator API来调整我在TF 1.2中使用的代码。但是input_fn总是只返回第一批,因此永远不会完成(或进展)。

def gen_input(filename):
  def decode(line):
    features = {
      'x': tf.FixedLenFeature((3,), tf.float32),
      'y': tf.FixedLenFeature((), tf.int64)
    }
    parsed = tf.parse_single_example(line, features)
    return parsed['x'], parsed['y']

  def input_fn():
    dataset = (tf.data.TFRecordDataset([filename])).map(decode)
    dataset = dataset.repeat(1)
    dataset = dataset.batch(2)

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

  return input_fn


estimator = tf.estimator.Estimator(
    model_dir=model_dir,
    model_fn=model_fn,
    params={})

train_input_fn = gen_input('train.tfrecord')
eval_input_fn = gen_input('eval.tfrecord')

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

我的TFRecord中只有5个项目,因此dataset.batch(2)dataset.repeat(1)一起应该使模型在3个步骤后完成。此外,我正在将每个步骤中提供的features记录到我的model_fn进入Tensorboard。记录的值始终相同。

我在这里做错了什么?

enter image description here

0 个答案:

没有答案