TRN-Pytorch model - RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Time: 2019-03-15 06:05:26

Tags: github pytorch pre-trained-model torchvision

I am using the TRN-Pytorch model in Colab with PyTorch 0.4.1. When training the model, I get RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1).

Here is the training command:

!python3 main.py something RGB \
                     --arch BNInception --num_segments 3 \
                     --consensus_type TRN --batch-size 2

I get this error:

storing name: TRN_something_RGB_BNInception_TRN_segment3

    Initializing TSN with base model: BNInception.
    TSN Configurations:
        input_modality:     RGB
        num_segments:       3
        new_length:         1
        consensus_module:   TRN
        dropout_ratio:      0.8
        img_feature_dim:    256

/content/drive/My Drive/TRN-pytorch/models.py:87: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  normal(self.new_fc.weight, 0, std)
/content/drive/My Drive/TRN-pytorch/models.py:88: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
  constant(self.new_fc.bias, 0)
video number:4
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py:208: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  "please use transforms.Resize instead.")
video number:1
group: first_conv_weight has 1 params, lr_mult: 1, decay_mult: 1
group: first_conv_bias has 1 params, lr_mult: 2, decay_mult: 0
group: normal_weight has 71 params, lr_mult: 1, decay_mult: 1
group: normal_bias has 71 params, lr_mult: 2, decay_mult: 0
group: BN scale/shift has 2 params, lr_mult: 1, decay_mult: 0
Freezing BatchNorm2D except the first one.
Traceback (most recent call last):
  File "main.py", line 324, in <module>
    main()
  File "main.py", line 128, in main
    train(train_loader, model, criterion, optimizer, epoch, log_training)
  File "main.py", line 175, in train
    prec1, prec5 = accuracy(output.data, target, topk=(1,5))
  File "main.py", line 301, in accuracy
    batch_size = target.size(1)
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
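The last frame of the traceback shows accuracy() calling target.size(1). The allowed range [-1, 0] in the message means target is a 1-D tensor (one class index per sample), which has no dimension 1, so the batch size would be target.size(0). Below is a minimal sketch, assuming accuracy() follows the widely used top-k helper from the PyTorch ImageNet example; the linked main.py is not reproduced here, so everything except the target.size(...) call is an assumption, and the output scores are made-up values for illustration.

```python
import torch


def accuracy(output, target, topk=(1,)):
    """Return top-k precision (in percent) for one batch.

    Assumes `output` has shape (batch_size, num_classes) and `target` is a
    1-D tensor of class indices with shape (batch_size,).
    """
    maxk = max(topk)
    # target is 1-D, so only dims 0 and -1 exist; target.size(1) is what
    # raises "Dimension out of range (expected to be in range of [-1, 0] ...)".
    batch_size = target.size(0)

    # Indices of the maxk highest scores per sample, shape (batch_size, maxk),
    # transposed to (maxk, batch_size) so row i holds every sample's i-th guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): correct[:k] is non-contiguous after the transpose.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k * (100.0 / batch_size))
    return res


# Reproduce the failing call with a batch of 2 labels (as with --batch-size 2).
target = torch.tensor([2, 0])
try:
    target.size(1)
except (RuntimeError, IndexError) as err:  # 0.4.1 raises RuntimeError, recent releases IndexError
    print(type(err).__name__, err)

# Hypothetical scores for 6 classes: sample 0 is right at top-1,
# sample 1 is only right within the top 5.
output = torch.tensor([[0.10, 0.20, 0.70, 0.00, 0.00, 0.00],
                       [0.05, 0.90, 0.05, 0.00, 0.00, 0.00]])
prec1, prec5 = accuracy(output, target, topk=(1, 5))
print(prec1.item(), prec5.item())  # 50.0 100.0
```

If the rest of accuracy() in main.py matches this pattern, changing target.size(1) to target.size(0) on line 301 is the one-line change this sketch points to; it has not been tested against the TRN-pytorch repository itself.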

The main.py file is here.

Please help me solve this problem.
