TypeError: forward() missing 1 required positional argument: 'hidden' in SentimentalRNN

Date: 2019-05-06 04:56:51

Tags: python pytorch onnx

This is a sentiment-analysis RNN (SentimentRNN) built on an LSTM. For reference, see this video: https://www.youtube.com/watch?v=OE6wssMJoag&t=1318s

Full code: https://colab.research.google.com/drive/1OoEbmik_saHOoosCghPZuv5VA_5s1U2r?authuser=1#scrollTo=mTmqmydFI4P0

    from torch.autograd import Variable
    import torch.onnx

    # Load the trained model from file
    vocab_size = len(vocab_to_int) + 1  # +1 for the 0 padding + our word tokens
    output_size = 1
    embedding_dim = 400
    hidden_dim = 256
    n_layers = 2

    net_save = SentimentRNN(vocab_size, output_size,
                            embedding_dim, hidden_dim, n_layers)
    trained_model = net_save
    trained_model.load_state_dict(torch.load('sentiment.pth'))

    # Export (convert) the trained PyTorch model to ONNX format
    dummy_input = Variable(torch.randn(1, 1, 28, 28))
    # one black and white 28 x 28 picture will be the input to the model
    torch.onnx.export(trained_model, dummy_input, "sentiment.onnx")
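The error most likely arises because torch.onnx.export runs the model as trained_model(dummy_input), passing a single argument, while the SentimentRNN forward method takes (x, hidden). A second issue is that the dummy input is a float 28 x 28 image (copied from an MNIST example), whereas an nn.Embedding layer expects integer token ids. Below is a minimal sketch, assuming the notebook's class follows the common sentiment-LSTM pattern with forward(self, x, hidden) and init_hidden(batch_size). The SentimentRNN class here is a hypothetical stand-in, since the real one lives in the Colab notebook; make_export_args, export_sentiment_rnn and the sequence length of 200 are likewise illustrative assumptions. torch.onnx.export unpacks a tuple of args as positional arguments, so passing (dummy_input, hidden) supplies the missing parameter.

```python
import torch
import torch.nn as nn


class SentimentRNN(nn.Module):
    """Hypothetical stand-in for the notebook's SentimentRNN (embedding -> LSTM -> linear -> sigmoid)."""

    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim,
                 n_layers, drop_prob=0.5):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers,
                            dropout=drop_prob, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_size)
        self.sig = nn.Sigmoid()

    def forward(self, x, hidden):
        # x: (batch, seq_len) token ids; hidden: (h0, c0), each (n_layers, batch, hidden_dim)
        embeds = self.embedding(x)
        lstm_out, hidden = self.lstm(embeds, hidden)
        # Keep only the last time step to produce one sentiment score per sequence
        out = self.sig(self.fc(lstm_out[:, -1, :]))
        return out, hidden

    def init_hidden(self, batch_size):
        # Zero-initialised hidden and cell states for the LSTM
        shape = (self.n_layers, batch_size, self.hidden_dim)
        return (torch.zeros(shape), torch.zeros(shape))


def make_export_args(model, batch_size=1, seq_len=200):
    """Build the positional arguments that forward(x, hidden) expects (hypothetical helper)."""
    # nn.Embedding needs integer token ids, not a float image tensor
    dummy_input = torch.randint(0, model.embedding.num_embeddings,
                                (batch_size, seq_len), dtype=torch.long)
    hidden = model.init_hidden(batch_size)
    return dummy_input, hidden


def export_sentiment_rnn(model, path="sentiment.onnx", seq_len=200):
    """Export with both inputs supplied (hypothetical helper)."""
    model.eval()  # so dropout is inactive during export
    args = make_export_args(model, batch_size=1, seq_len=seq_len)
    # A tuple of args is unpacked as model(*args), so forward() also receives hidden
    torch.onnx.export(model, args, path)
```

With the question's setup, the call would be export_sentiment_rnn(trained_model, "sentiment.onnx") after load_state_dict. Because the dummy input fixes the sequence length at export time, inputs of other lengths would need to be padded to the same length, or the dynamic_axes argument of torch.onnx.export used.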

0 answers