我正在尝试在PyTorch中实现一个简单的线性模型,该模型可以提供x数据和y数据,然后经过训练可以识别方程y = mx + b。但是,每当我尝试在训练后测试模型时,它都会认为方程为y = mx + 2b。我将展示我的代码,希望有人能够发现问题。预先感谢您的帮助。
import torch
D_in = 500
D_out = 500
batch=200
model=torch.nn.Sequential(
torch.nn.Linear(D_in,D_out),
)
接下来,我创建一些数据并设置规则。让我们做3x + 4。
x_data=torch.rand(batch,D_in)
y_data=torch.randn(batch,D_out)
for i in range(batch):
for j in range(D_in):
y_data[i][j]=3*x_data[i][j]+5 # model thinks y=mx+c -> y=mx+2c?
loss_fn=torch.nn.MSELoss(size_average=False)
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
现在要培训...
for epoch in range(500):
y_pred=model(x_data)
loss=loss_fn(y_pred,y_data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
然后,我使用仅为1的张量/矩阵测试模型。
test_data=torch.ones(batch,D_in)
y_pred=model(test_data)
现在,我希望得到3 * 1 + 4 = 7,但是我的模型却认为是11 *。
[[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387,
10.7516],
[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387,
10.7516],
[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387,
10.7516],
...,
[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387,
10.7516],
[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387,
10.7516],
[ 10.7286, 11.0499, 10.9448, ..., 11.0812, 10.9387,
10.7516]])
类似地,如果我将规则更改为y = 3x + 8,我的模型将猜测为19。因此,我不确定发生了什么。为什么常数要加两次?顺便说一句,如果我只是将规则设置为y = 3x,则我的模型可以正确推断3,而对于y = mx,通常我的模型可以正确推断m。由于某种原因,常数项将其抛弃。非常感谢您为解决该问题提供的帮助。谢谢!
答案 0 :(得分:2)
您的网络学习时间不够长。 它具有一个具有500个特征的向量来描述单个基准。
您的网络必须将500个要素的大输入映射到包含500个值的输出。 您的训练数据是随机创建的,不像您的简单示例那样,因此我认为您只需要训练更长的时间以适合您的体重,即可将此函数从R ^ 500近似为R ^ 500。
如果我减小输入和输出尺寸并增加批次大小,学习率和培训步骤,我将获得预期的结果:
import torch
D_in = 100
D_out = 100
batch = 512
model=torch.nn.Sequential(
torch.nn.Linear(D_in,D_out),
)
x_data=torch.rand(batch,D_in)
y_data=torch.randn(batch,D_out)
for i in range(batch):
for j in range(D_in):
y_data[i][j]=3*x_data[i][j]+4 # model thinks y=mx+c -> y=mx+2c?
loss_fn=torch.nn.MSELoss(size_average=False)
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
for epoch in range(10000):
y_pred=model(x_data)
loss=loss_fn(y_pred,y_data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
test_data=torch.ones(batch,D_in)
y_pred=model(test_data)
print(y_pred)
如果您只想用一个输入来近似f(x) = 3x + 4
,那么也可以将D_in
和D_out
设置为1。