问题
在Tensorflow中,我经常在第一个训练时期遇到OOM错误。然而,网络的大规模性质导致第一个时期需要大约一个小时,很快就能快速测试新的超参数。
理想情况下,我希望能够对迭代器进行排序,以便我可以在最大批次上运行get_next()
一次。
我该怎么做?或者也许有更好的方法可以尽早实现失败?
迭代器的格式为:(source, tgt_in, tgt_out, key_weights, source_len, target_len)
我希望按目标长度排序。在返回之前进行填充和批量处理。
数据集是一个句子列表,有相似的长度。我想找到迭代器中最大的批处理并只运行它。
部分代码
如果初始化程序每次都没有对迭代器进行洗牌,则下面的代码将起作用,从而破坏了有关最大批次位置的信息。我不太确定如何修改它 - 只要使用get_next()
读取批处理的长度,它就已经“弹出”并且不能再用作模型的输入。 / p>
def verify_hparams():
train_sess.run(train_model.iterator.initializer)
max_index = -1
max_len = 0
for batch in itertools.count():
try:
batch_len = np.amax(train_sess.run(train_model.iterator.get_next()[-1]))
if batch_len > max_len:
max_len = batch_len
max_index = batch
except tf.errors.OutOfRangeError:
num_batches = batch + 1
break
for batch in range(-1, num_batches-1):
try:
if batch is max_index:
_, _ = loaded_train_model.train(train_sess)
else:
train_sess.run(train_model.iterator.get_next())
except tf.errors.OutOfRangeError:
break
return num_batches
答案 0 :(得分:1)
您需要的是“偷看”操作。大多数语言都有迭代器,允许您查看是否有更多数据(类似 sudo certbot --authenticator standalone --installer apache -d <yourdomain(s)> --pre-hook "apache2ctl stop" --post-hook "apache2ctl start"
)。但您要求的功能基本上类似于iterator.hasNext()
。据我所知,张量流迭代器don't have such functionality。
此外,这样的功能不太可能被添加,因为我可以想象有些生成器无法提供这样的功能,因此添加此功能会破坏向后兼容性。