如何向散点图矩阵中的每个散点图添加线性回归线?

时间:2020-09-16 20:11:15

标签: python plotly

我有创建散点图矩阵的代码,我想向每个构面添加线性回归线。代码和当前图形如下所示。我目前对数据集中前五个变量的每个变量组合都有一个散点图。我想添加回归线,以便当个人将鼠标悬停在这条线上时,他们也可以看到相关性。

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from sklearn import datasets
from typing import Tuple, List
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def load_data() -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Load the wine dataset

    Returns:
        features: the dataset features
        target: the labels of the dataset
        feature_names: names of each feature
    """
    wine = datasets.load_wine()
    features = wine['data']
    target = wine['target']
    feature_names = wine['feature_names']
    return features, target, feature_names

features, target, feature_names = load_data()
Data = {
    feature_names[0]:features[:,0].tolist(),
    feature_names[1]:features[:,1].tolist(),
    feature_names[2]:features[:,2].tolist(),
    'Target': target.tolist()
}
Data = pd.DataFrame(data = Data)

index_vals = Data['Target'].astype('category').cat.codes

fig = go.Figure(data = go.Splom(dimensions = [
    dict(label = feature_names[0],values = Data[feature_names[0]]),
    dict(label = feature_names[1],values = Data[feature_names[1]]),
    dict(label = feature_names[2],values = Data[feature_names[2]])],
   text = Data['Target'],
    marker = dict(color = index_vals,showscale = False,size = 8)
))

fig.update_layout(
    title='Wine Dataset',
    dragmode='select',
    width=900,
    height=600,
    hovermode='closest',
)

fig.show()

0 个答案:

没有答案