如何在开始培训之前查找并运行数据集中的最大批次

时间:2018-02-02 01:09:09

标签: tensorflow tensorflow-datasets hyperparameters neural-mt

问题

在Tensorflow中,我经常在第一个训练时期遇到OOM错误。然而,网络的大规模性质导致第一个时期需要大约一个小时,很快就能快速测试新的超参数。

理想情况下,我希望能够对迭代器进行排序,以便我可以在最大批次上运行get_next()一次。

我该怎么做?或者也许有更好的方法可以尽早实现失败?

迭代器的格式为:(source, tgt_in, tgt_out, key_weights, source_len, target_len)我希望按目标长度排序。在返回之前进行填充和批量处理。

数据集是一个句子列表,有相似的长度。我想找到迭代器中最大的批处理并只运行它。

部分代码

如果初始化程序每次都没有对迭代器进行洗牌,则下面的代码将起作用,从而破坏了有关最大批次位置的信息。我不太确定如何修改它 - 只要使用get_next()读取批处理的长度,它就已经“弹出”并且不能再用作模型的输入。 / p>

def verify_hparams():
    train_sess.run(train_model.iterator.initializer)
    max_index = -1
    max_len = 0
    for batch in itertools.count():
        try:
            batch_len = np.amax(train_sess.run(train_model.iterator.get_next()[-1]))
            if batch_len > max_len:
                max_len = batch_len
                max_index = batch

        except tf.errors.OutOfRangeError:
            num_batches = batch + 1
            break

    for batch in range(-1, num_batches-1):
        try:
            if batch is max_index:
                _, _ = loaded_train_model.train(train_sess)
            else:
                train_sess.run(train_model.iterator.get_next())

        except tf.errors.OutOfRangeError:
            break

    return num_batches

1 个答案:

答案 0 :(得分:1)

您需要的是“偷看”操作。大多数语言都有迭代器,允许您查看是否有更多数据(类似 sudo certbot --authenticator standalone --installer apache -d <yourdomain(s)> --pre-hook "apache2ctl stop" --post-hook "apache2ctl start" )。但您要求的功能基本上类似于iterator.hasNext()。据我所知,张量流迭代器don't have such functionality

此外,这样的功能不太可能被添加,因为我可以想象有些生成器无法提供这样的功能,因此添加此功能会破坏向后兼容性。