Python:在散点图中查找多个线性趋势线

时间:2018-09-11 17:07:14

标签: python linear-regression

我有以下熊猫数据框-

    Atomic Number      R         C
0             2.0   49.0  0.040306
1             3.0  205.0  0.209556
2             4.0  140.0  0.107296
3             5.0  117.0  0.124688
4             6.0   92.0  0.100020
5             7.0   75.0  0.068493
6             8.0   66.0  0.082244
7             9.0   57.0  0.071332
8            10.0   51.0  0.045725
9            11.0  223.0  0.217770
10           12.0  172.0  0.130719
11           13.0  182.0  0.179953
12           14.0  148.0  0.147929
13           15.0  123.0  0.102669
14           16.0  110.0  0.120729
15           17.0   98.0  0.106872
16           18.0   88.0  0.061996
17           19.0  277.0  0.260485
18           20.0  223.0  0.164312
19           33.0  133.0  0.111359
20           36.0  103.0  0.069348
21           37.0  298.0  0.270709
22           38.0  245.0  0.177368
23           54.0  124.0  0.079491

r和C之间的趋势通常是线性的。如果可能的话,我想做的是找到一个包含3个或更多点的所有可能组合的详尽列表,以及scipy.stats.linregress的趋势,以便我找到线性最佳的点组。

Which would ideally look something like this for the data, (Source),但我也在寻找所有其他可能的趋势。

所以问题是,如何将3点或更多点的所有16776915可能的组合(sum_(i = 3)^ 24 binomial(24,i))喂到lingress,并且即使没有大量代码也可以做到吗?

1 个答案:

答案 0 :(得分:3)

我下面的解决方案建议基于RANSAC算法。这是一种将数学模型(例如一条线)拟合到具有大量异常值的数据的方法。

RANSAC是robust regression领域中的一种特定方法。

我下面的解决方案首先适合RANSAC。然后,从数据集中删除靠近这条线的数据点(这与保留异常值相同),再次拟合RANSAC,删除数据等,直到只剩下很少的点为止。

这种方法总是具有取决于数据的参数(例如,噪声水平或线路的接近度)。在以下解决方案中,MIN_SAMPLESresidual_threshold是可能需要对数据结构进行一些调整的参数:

import matplotlib.pyplot as plt
import numpy as np
from sklearn import linear_model

MIN_SAMPLES = 3

x = np.linspace(0, 2, 100)

xs, ys = [], []

# generate points for thee lines described by a and b,
# we also add some noise:
for a, b in [(1.0, 2), (0.5, 1), (1.2, -1)]:
    xs.extend(x)
    ys.extend(a * x + b + .1 * np.random.randn(len(x)))

xs = np.array(xs)
ys = np.array(ys)
plt.plot(xs, ys, "r.")

colors = "rgbky"
idx = 0

while len(xs) > MIN_SAMPLES:

    # build design matrix for linear regressor
    X = np.ones((len(xs), 2))
    X[:, 1] = xs

    ransac = linear_model.RANSACRegressor(
        residual_threshold=.3, min_samples=MIN_SAMPLES
    )

    res = ransac.fit(X, ys)

    # vector of boolean values, describes which points belong
    # to the fitted line:
    inlier_mask = ransac.inlier_mask_

    # plot point cloud:
    xinlier = xs[inlier_mask]
    yinlier = ys[inlier_mask]

    # circle through colors:
    color = colors[idx % len(colors)]
    idx += 1
    plt.plot(xinlier, yinlier, color + "*")

    # only keep the outliers:
    xs = xs[~inlier_mask]
    ys = ys[~inlier_mask]

plt.show()

在以下显示为星号的绘图点中,属于我的代码检测到的星团。您还可以看到一些圆点,它们是迭代后剩余的点。少量的黑色星星组成一个簇,您可以通过增加MIN_SAMPLES和/或residual_threshold来消除它们。

enter image description here