我如何使用tqdm使用tf.data.Dataset API可视化训练步骤的进度?

时间:2018-12-28 12:02:53

标签: tensorflow tqdm

我想使用tqdm可视化我的cnn网络培训步骤。
如何使用tqdm api实现tf.data.Dataset()
您能给我看示例代码吗?谢谢!

2 个答案:

答案 0 :(得分:0)

注意:虽然这是一个明显标记的问题,但我仍然认为这是一个有效的问题(我自己有过)并发布了解决方案。

一种可能性是先前获取数据集的基数,并将其用作 tqdm 中的 @mr.melon 状态。

cardinality = np.sum([1 for i in dataset.batch(batch_size)])

其中 dataset 属于 tf.data.Dataset 类并且您还没有完成完整的准备管道(我指的是交错、混洗、批处理和级长等)。

然后你可以

for input, label in tqdm(dataset, total=cardinality):
...

答案 1 :(得分:-1)

这很简单:

  1. 导出数据集中的样本数量,
  2. 然后,将数字转换为一些可迭代的python结构。

    for _ in tqdm(iterable=xxx, total=num_samples):
         batch_data = sess.run(ele_derived_from_tf_dataset)