tf.estimator
input_fn
的签名可能看起来像这样:
def input_fn(files:list, params:dict):
dataset = tf.data.TFRecordDataset(files)
.map(lambda record: parse_record_fn(record))
if params['mode'] == 'train':
# train specific things
# ...
这样的定义允许一个人按如下方式构造其所有input_fn
:
train_fn = lambda: input_fn(files['training_set'], {**params, **{"mode": "train"}})
valid_fn = lambda: input_fn(files['validation_set'], {**params, **{"mode": "eval"}})
test_fn = lambda: input_fn(files['test_set'], {**params, **{"mode": "test"}})
train_spec = tf.estimator.TrainSpec(input_fn=train_fn, ...)
eval_spec = tf.estimator.EvalSpec(input_fn=valid_fn, ...)
我的问题是,如何更改input_fn
签名以允许基于时代的变化。我了解这可能会带来瓶颈,但是如果我可以做以下事情会很好:
def input_fn(...):
# see above
epoch = params["epoch"]
if epoch % 100 == 0:
# modify or make a new dataset
# ...
return dataset.make_one_shot_iterator().get_next()
关键是确保input_fn
仍与以下设备兼容:
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
答案 0 :(得分:1)
我不知道有任何提供epoch
数字作为参数的选项。
也就是说,根据定义,纪元是输入函数的功能,因此我们应该只能够处理输入函数中的所有内容,而不是完全可以访问训练参数。因此,我认为您只需稍微摆弄一下就可以实现所需的功能。
例如,如果我有两个数据集:ds1
和ds2
说并且想在{epoch}数字不能被100整除时使用ds1
,那么我可以创建一个新数据集通过执行类似的操作:
dataset = ds1.repeat(99).concatenate(ds2)
由于默认情况下延迟加载数据集,所以我不必担心内存的问题(我不会将100倍的数据加载到内存中)。
显然,这确实对数据集的大小有影响,因此您需要考虑在评估操作/回调等之间进行操作的策略,但是应该很容易进行调整。