我正在实现自己的迭代器。 tqdm不显示进度条,因为它不知道列表中元素的总量。我不想使用“total =”,因为它看起来很难看。相反,我更愿意在我的迭代器中添加一些东西,tqdm可以用来计算总数。
class Batches:
def __init__(self, batches, target_input):
self.batches = batches
self.pos = 0
self.target_input = target_input
def __iter__(self):
return self
def __next__(self):
if self.pos < len(self.batches):
minibatch = self.batches[self.pos]
target = minibatch[:, :, self.target_input]
self.pos += 1
return minibatch, target
else:
raise StopIteration
def __len__(self):
return self.batches.len()
这甚至可能吗?要添加到上面的代码...
使用如下的tqdm ..
for minibatch, target in tqdm(Batches(test, target_input)):
output = lstm(minibatch)
loss = criterion(output, target)
writer.add_scalar('loss', loss, tensorboard_step)
答案 0 :(得分:1)
我知道已经有一段时间了,但是我一直在寻找相同的答案,这是解决方案。而不是像这样用tqdm包装您的可迭代对象
for i in tqdm(my_iterable):
do_something()
改为使用“ with”关闭,例如:
with tqdm(total=len(my_iterable)) as progress_bar:
for i in my_iterable:
do_something()
progress_bar.update(1) # update progress
对于批次,您可以将总数设置为批次数,并更新为1(如上所述)。或者,您可以将总数设置为实际的项目总数,将更新设置为当前已处理批次的大小。