我开始观看有关 PyTorch 的教程,并且正在学习逻辑回归的概念。
我使用我拥有的一些股票数据进行了尝试。我有 inputs
,它包含两个参数 trade_quantity
和 trade_value
,以及 targets
,它具有相应的股票价格。
inputs = torch.tensor([[182723838.00, 2375432.00],
[185968153.00, 2415558.00],
[181970093.00, 2369140.00],
[221676832.00, 2811589.00],
[339785916.00, 4291782.00],
[225855390.00, 2821301.00],
[151430199.00, 1889032.00],
[122645372.00, 1552998.00],
[129015052.00, 1617158.00],
[121207837.00, 1532166.00],
[139554705.00, 1789392.00]])
targets = torch.tensor([[76.90],
[76.90],
[76.90],
[80.70],
[78.95],
[79.60],
[80.05],
[78.90],
[79.40],
[78.95],
[77.80]])
我定义了模型函数,损失为均方误差,并尝试运行几次以获得一些预测。代码如下:
def model(x):
return x @ w.t() + b
def mse(t1, t2):
diff = t1 - t2
return torch.sum(diff * diff) / diff.numel()
preds = model(inputs)
loss = mse(preds, targets)
loss.backward()
with torch.no_grad():
w -= w.grad * 1e-5
b -= b.grad * 1e-5
w.grad.zero_()
b.grad.zero_()
我为此使用了 Jupyter 并运行了代码的最后一部分,之后预测如下:
tensor([[inf],
[inf],
[inf],
[inf],
[inf],
[inf],
[inf],
[inf],
[inf],
[inf],
[inf]], grad_fn=<AddBackward0>)
如果我再运行几次,预测就会变成 nan
。你能告诉我为什么会这样吗?
答案 0 :(得分:0)
对我来说,这看起来更像是线性回归而不是逻辑回归。您正在尝试将线性模型拟合到您的数据上。它与需要使用特殊类型的激活函数(例如 sigmoid)以便输出为 0
或 1
的二元分类任务不同。
在这个特定的例子中,你想解决一个二维线性问题,给定形状为 x
的输入 (batch, x1, x2)
(其中 x1
是 trade_quantity
,x2
是 { {1}}) 和目标 trade_value
((batch, y)
是 y
)。
所以目标是找到最好的stock_price
和w
矩阵(权重矩阵和偏差列),使b
尽可能接近x@w + b
,< em>根据你的标准,均方误差。
我建议对您的数据进行标准化,使其保持在 y
范围内。您可以通过测量 [0, 1]
和 inputs
的均值和标准差来实现。
targets
然后应用转换:
inputs_min, inputs_max = inputs.min(axis=0).values, inputs.max(axis=0).values
targets_min, targets_max = targets.min(axis=0).values, targets.max(axis=0).values
尝试改变你的学习率并让它运行多个时期。
x = (inputs - inputs_min)/(inputs_max - inputs_min)
y = (targets - targets_min)/(targets_max - targets_min)
我对 lr = 1e-2
for epochs in range(100):
preds = model(x)
loss = mse(preds, y)
loss.backward()
with torch.no_grad():
w -= lr*w.grad
b -= lr*b.grad
w.grad.zero_()
b.grad.zero_()
使用了一个 (1, 2)
随机初始化矩阵(对 w
使用了一个 (1,)
矩阵):
b
并且在 100 个 epoch 中得到以下训练损失:
要找到正确的超参数,最好有一个验证集。该集合将使用训练集中的 mean 和 std 进行标准化。它将用于评估每个时期结束时模型“未知”数据的性能。如果您有测试集,也同样如此。