pytorch TypeError: '<' not supported between instances of 'Example' and 'Example' when iterating over an iterator

Time: 2019-08-22 08:33:52

Tags: python pytorch

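I am trying to classify text with my own dataset, following https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/5%20-%20Multi-class%20Sentiment%20Analysis.ipynb. My dataset is a csv of sentences, each with an associated class. There are 6 different classes: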

sent                      class
'the fox is brown'        animal
'the house is big'        object
'one water is drinkable'  water
...

When I run:

N_EPOCHS = 5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    print(start_time)
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    #print(train_loss.type())
    #print(train_acc.type())
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut5-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

I get the following error:

1566460706.977012
epoch_loss
<torchtext.data.iterator.BucketIterator object at 0x000001FABE907E80>
TypeError: '<' not supported between instances of 'Example' and 'Example'

which points to:

TypeError                                 Traceback (most recent call last)
<ipython-input-22-19e8a7eb204e> in <module>()
     10     #print(train_loss.type())
     11     #print(train_acc.type())
---> 12     valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
     13 
     14     end_time = time.time()

<ipython-input-21-83b02f99bca7> in evaluate(model, iterator, criterion)
      9         print('epoch_loss')
     10         print(iterator)
---> 11         for batch in iterator:
     12             print('batch')
     13             predictions = model(batch.text)

I am new to pytorch, so all I did was add a line to print the iterator and identify its type, which gave:

<torchtext.data.iterator.BucketIterator object at 0x000001FABE907E80>

I tried to work out which attribute is responsible by reading https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py, but without success.
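One way to narrow this down is to print the iterator's sorting-related settings rather than the object itself. The sketch below is only an illustration: `describe_iterator` is a hypothetical helper, the attribute names (`train`, `sort`, `sort_key`, `sort_within_batch`, `batch_size`) are the ones I would expect from the linked iterator.py and should be treated as assumptions, and a `SimpleNamespace` stands in for the real `valid_iterator` so the snippet runs without torchtext.

```python
from types import SimpleNamespace


def describe_iterator(iterator):
    """Collect sorting-related attributes, marking any that do not exist."""
    # Attribute names are assumptions based on the linked iterator.py.
    names = ("train", "sort", "sort_key", "sort_within_batch", "batch_size")
    return {name: getattr(iterator, name, "<missing>") for name in names}


# Stand-in object with made-up values; replace it with the real valid_iterator.
fake_valid_iterator = SimpleNamespace(
    train=False,
    sort=True,
    sort_key=None,
    sort_within_batch=False,
    batch_size=64,
)

summary = describe_iterator(fake_valid_iterator)
for name, value in summary.items():
    print(f"{name}: {value!r}")
```

If the real iterator reports that sorting is enabled while `sort_key` is `None`, that combination would be worth investigating first.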

Any suggestions are appreciated.

The code of the evaluate method (where the iterator is used) is:

def evaluate(model, iterator, criterion):

    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():
        print('epoch_loss')
        print(iterator)
        for batch in iterator:
            print('batch')
            predictions = model(batch.text)

            loss = criterion(predictions, batch.label)

            acc = categorical_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
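The error message is the one Python raises when `sorted()` has to compare two objects that define no ordering. A possible cause, which cannot be confirmed because the code that builds the iterators is not shown, is that the validation iterator sorts the dataset without a sort key. The sketch below reproduces the same message in plain Python; its `Example` class is a hypothetical stand-in for torchtext's, and the sample sentences are illustrative.

```python
# Hypothetical stand-in for torchtext's Example: no ordering methods defined.
class Example:
    def __init__(self, text, label):
        self.text = text
        self.label = label


examples = [
    Example(["the", "fox", "is", "brown"], "animal"),
    Example(["one", "water", "is", "drinkable"], "water"),
    Example(["big", "house"], "object"),
]


def try_sort(data, key=None):
    """Return the sorted list, or the TypeError message if sorting fails."""
    try:
        return sorted(data, key=key)
    except TypeError as err:
        return str(err)


# Without a key, sorted() compares the objects themselves with "<".
no_key_result = try_sort(examples)

# With a key, only the key values (token counts here) are compared.
by_length = try_sort(examples, key=lambda ex: len(ex.text))
sorted_labels = [ex.label for ex in by_length]

print(no_key_result)
print(sorted_labels)
```

If this is what happens here, passing `sort_key=lambda x: len(x.text)` (or `sort=False`) when the iterators are created with `BucketIterator.splits` may avoid the error. This is a guess based on the message alone, not a verified fix.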

0 answers:

No answers yet