在tensorflow中,如何枚举训练数据(与pytorch相比)

时间:2018-12-23 18:13:30

标签: tensorflow pytorch

在pytorch中,这就是我枚举训练数据的方式。

for epoch in range(0, args.epoches):
    for i, batch in enumerate(train_data):
        model.update(batch)

train_data包含多个batch,并且枚举并更新了批次,这对我来说很清楚。


我认为这是张量流如何处理批次的基本示例。

for step in range(num_steps):
    batch_data, batch_labels = generate_batch(batch_size, num_skips, skip_window)
    feed_dict = {train_dataset : batch_data, train_labels : batch_labels}
    _, l = session.run([optimizer, loss], feed_dict=feed_dict)

也许这是一个非常明显的问题,但是我不清楚session.run如何在张量流中处理训练批次。我找不到批处理在代码中循环通过。我所看到的只是feed_dict,我认为它可以处理循环。

有人可以阐明这一点吗?

1 个答案:

答案 0 :(得分:1)

TensorFlow为此具有一个History对象。您从History方法中获得了model.fit()对象作为返回。

History对象及其History.history属性记录了连续时期训练损失值和度量值以及验证损失值和验证度量值(如果适用)的记录​​。

希望这就是您所需要的。