我正在尝试使用1个特征进行简单的线性回归。这是一个简单的“根据多年经验预测薪水”问题。
NN会训练年经验(X)和薪水(Y)。
由于某种原因,损失激增,最终返回inf
或nan
这是我的代码:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
dataset = pd.read_csv('./salaries.csv')
x_temp = dataset.iloc[:, :-1].values
y_temp = dataset.iloc[:, 1:].values
X_train = torch.FloatTensor(x_temp)
Y_train = torch.FloatTensor(y_temp)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(1,1)
def forward(self, x):
y_pred = self.linear(x)
return y_pred
model = Model()
loss_func = torch.nn.MSELoss(size_average=False)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
#training
for epoch in range(200):
#calculate y_pred
y_pred = model(X_train)
#calculate loss
loss = loss_func(y_pred, Y_train)
print(epoch, "{:.2f}".format(loss.data))
#backward pass + update weights
optim.zero_grad()
loss.backward()
optim.step()
test_exp = torch.FloatTensor([[8.0]])
print("8 years experience --> ", model(test_exp).data[0][0].item())
正如我提到的,一旦开始训练,损失就会变得非常大,最终在第10个时期之后显示inf
。
我怀疑这可能与我如何加载数据有关?这就是salaries.csv
文件中的内容:
Years Salary
1.1 39343
1.3 46205
1.5 37731
2 43525
2.2 39891
2.9 56642
3 60150
3.2 54445
3.2 64445
3.7 57189
3.9 63218
4 55794
4 56957
4.1 57081
4.5 61111
4.9 67938
5.1 66029
5.3 83088
谢谢您的帮助
答案 0 :(得分:5)
一旦经过一定的损失后,损失变为inf,则模型在反向传播后会损坏。这可能是因为“薪水”列中的值太大。尝试使工资标准化。
或者,您可以尝试手动初始化参数(而不是让它随机初始化),让偏差项为薪水的平均值,而直线的斜率为0(例如)。这样,初始模型将足够接近最佳解,因此损失不会增加。
答案 1 :(得分:0)
以下是所有情况的示例。您可以尝试运行该程序,该程序基本上表示r-深层网络。
import torch
import math
import matplotlib.pyplot as plt
def stat(t, p=True):
m = t.mean()
s = t.std()
if p==True:
print(f"MEAN: {m}, STD: {s}")
return(m,s)
_m = []
_s = []
c = 100
r = 50# repeat steps
x = torch.randn(c)
m = torch.randn(c,c)#/math.sqrt(n)
stat(x)
for _ in range (0,r):
x = m@x
_1, _2 = stat(x, False)
_m.append(_1)
_s.append(_2)
stat(x)
plt.plot(_m)
plt.plot(_s)
plt.legend(["mean","std"])
plt.show()
答案 2 :(得分:0)
请将学习率“lr”降低到 0.001 或 0.0001。较大的 lr 值会使梯度爆炸并导致 inf。我已经尝试过 lr=0.001 和 lr=0.0001 它对我来说很好用。请尝试一次并告诉我。
答案 3 :(得分:0)
获得 import pyspark.sql.functions as F
df1 = df.withColumn(
'CONV_ID',
F.split('URL', '(?<=conversations/)')[1] # just using 'conversations/' should also be enough
)
df1.show()
+---+--------------------+----------+
| ID| URL| CONV_ID|
+---+--------------------+----------+
| 1|https://app.xyz.c...|2686735685|
| 2|https://app.xyz.c...|2938415796|
| 3|https://app.drift...|2938419189|
+---+--------------------+----------+
损失的另一种可能性是包含 nan
值的模型的输入张量。尝试从模型输入中过滤 nan
值。