Tensorflow 1.10+:将纪元传递给估计器input_fn?

时间:2019-06-14 08:38:16

标签: python tensorflow tensorflow-estimator

tf.estimator input_fn的签名可能看起来像这样:

def input_fn(files:list, params:dict):
    dataset = tf.data.TFRecordDataset(files)
                .map(lambda record: parse_record_fn(record))

    if params['mode'] == 'train':
        # train specific things
    # ...

这样的定义允许一个人按如下方式构造其所有input_fn

train_fn = lambda: input_fn(files['training_set'], {**params, **{"mode": "train"}})
valid_fn = lambda: input_fn(files['validation_set'], {**params, **{"mode": "eval"}})
test_fn  = lambda: input_fn(files['test_set'],  {**params, **{"mode": "test"}})


train_spec = tf.estimator.TrainSpec(input_fn=train_fn, ...)
eval_spec  = tf.estimator.EvalSpec(input_fn=valid_fn,  ...)  

我的问题是,如何更改input_fn签名以允许基于时代的变化。我了解这可能会带来瓶颈,但是如果我可以做以下事情会很好:


def input_fn(...):
    # see above

    epoch = params["epoch"]
    if epoch % 100 == 0:
        # modify or make a new dataset

    # ...
    return dataset.make_one_shot_iterator().get_next()

关键是确保input_fn仍与以下设备兼容:

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

1 个答案:

答案 0 :(得分:1)

我不知道有任何提供epoch数字作为参数的选项。

也就是说,根据定义,纪元是输入函数的功能,因此我们应该只能够处理输入函数中的所有内容,而不是完全可以访问训练参数。因此,我认为您只需稍微摆弄一下就可以实现所需的功能。

例如,如果我有两个数据集:ds1ds2说并且想在{epoch}数字不能被100整除时使用ds1,那么我可以创建一个新数据集通过执行类似的操作:

dataset = ds1.repeat(99).concatenate(ds2)

由于默认情况下延迟加载数据集,所以我不必担心内存的问题(我不会将100倍的数据加载到内存中)。

显然,这确实对数据集的大小有影响,因此您需要考虑在评估操作/回调等之间进行操作的策略,但是应该很容易进行调整。