我搜索了StackOverflow,并访问了其他网站以寻求帮助,但是找不到解决我问题的方法。 我将保留整个代码,以使您可以理解。用PyTorch编写的大约110行。
每次,我都会编译并计算一个预测,该错误代码将显示:
Traceback (most recent call last):
File "/Users/MacBookPro/Dropbox/01 GST h_da Privat/BA/06_KNN/PyTorchV1/BesucherV5.py", line 108, in <module>
result = Network(test_exp).data[0][0].item()
TypeError: __init__() takes 1 positional argument but 2 were given
我知道,其他用户也有这种情况,但是他们的解决方案都没有帮助我。我猜错误是在我的类“网络”中或在变量“结果”中。 我希望你们中有人遇到这个问题,并且知道如何解决它,或者可以通过其他方式帮助我。
有关数据集的简短信息:
我的数据集有10列,并分为两组。 X和Y。X有9列,Y只有一列。然后将它们用于训练网络。
提前谢谢!
亲切的问候 克里斯蒂安·里希特(Christian Richter)
我的代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import pandas as pd
### Dataset ###
dataset = pd.read_csv('./data/train_data_csv.csv')
x_temp = dataset.iloc[:, :-1].values
print(x_temp)
print()
print(x_temp.size)
print()
y_temp = dataset.iloc[:, 9:].values
print(y_temp)
print()
print(y_temp.size)
print()
x_train_tensor = torch.FloatTensor(x_temp)
y_train_tensor = torch.FloatTensor(y_temp)
### Network Architecture ###
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.linear1 = nn.Linear(9, 9) #10 Input-Neurons, 10 Output-Neurons, Linearer Layer
self.linear2 = nn.Linear(9, 1)
def forward(self, x):
pax_predict = F.relu(self.linear1(x))
pax_predict = self.linear2(x)
return pax_predict
def num_flat_features(self, pax_predict):
size = pax_predict.size()[1:]
num = 1
for i in size:
num *= i
return num
network = Network()
print(network)
criterion = nn.MSELoss()
target = Variable(y_train_tensor)
optimizer = torch.optim.SGD(network.parameters(), lr=0.0001)
### Training
for epoch in range(200):
input = Variable(x_train_tensor)
y_pred = network(input)
loss = criterion(y_pred, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
test_exp = torch.Tensor([[40116]])
result = Network(test_exp).data[0][0].item()
print('Result is: ', result)
答案 0 :(得分:0)
我想问题很简单,就在这一行:
result = Network(test_exp).data[0][0].item()
在这里,您应该使用network
(对象)而不是Network
(类)。根据您的定义,Network
仅接受1个参数(即self
),但您传递的是2:self
和test_exp
。
也许您为对象选择了其他名称(例如net
),则可以更轻松地发现此错误。考虑到这一点:)
而且,请始终发布完整的追溯。