如何帮助tqdm找出自定义迭代器中的总数

时间:2018-03-12 21:45:18

标签: python tqdm

我正在实现自己的迭代器。 tqdm不显示进度条,因为它不知道列表中元素的总量。我不想使用“total =”,因为它看起来很难看。相反,我更愿意在我的迭代器中添加一些东西,tqdm可以用来计算总数。

class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration

    def __len__(self):
        return self.batches.len()

这甚至可能吗?要添加到上面的代码...

使用如下的tqdm ..

for minibatch, target in tqdm(Batches(test, target_input)):

    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)

1 个答案:

答案 0 :(得分:1)

我知道已经有一段时间了,但是我一直在寻找相同的答案,这是解决方案。而不是像这样用tqdm包装您的可迭代对象

for i in tqdm(my_iterable):
    do_something()

改为使用“ with”关闭,例如:

with tqdm(total=len(my_iterable)) as progress_bar:
    for i in my_iterable:
        do_something()
        progress_bar.update(1) # update progress

对于批次,您可以将总数设置为批次数,并更新为1(如上所述)。或者,您可以将总数设置为实际的项目总数,将更新设置为当前已处理批次的大小。