我有一段创建散点图矩阵(scatter-plot matrix)的代码,想在每个子图(facet)中添加一条线性回归线。代码和当前图形如下所示。目前,数据集中前三个变量的每一对组合都有一张散点图。我希望添加回归线,使用户将鼠标悬停在这条线上时,也能看到这两个变量之间的相关系数。
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from sklearn import datasets
from typing import Tuple, List
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def load_data() -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Fetch the wine dataset bundled with scikit-learn.

    Returns:
        A ``(features, target, feature_names)`` tuple: the dataset's feature
        values, the label of each sample, and the name of each feature.
    """
    bunch = datasets.load_wine()
    # The returned Bunch exposes its keys as attributes.
    return bunch.data, bunch.target, bunch.feature_names
# Number of leading wine features shown in the scatter-plot matrix.
N_FEATURES = 3

# Points sampled along each regression line. Plotly only shows hover labels at
# a line trace's data points, so a densely sampled line lets the user hover
# anywhere along it instead of only at its two endpoints.
N_LINE_POINTS = 50


def _add_regression_lines(fig: go.Figure, frame: pd.DataFrame, columns: List[str]) -> None:
    """Overlay a least-squares regression line on every off-diagonal splom cell.

    Hovering over a line shows the Pearson correlation coefficient and the
    fitted equation for that pair of variables.

    Args:
        fig: figure holding a ``go.Splom`` trace whose dimensions are
            ``columns`` in order and whose axes use the default ids
            (``x``, ``x2``, ... and ``y``, ``y2``, ...).
        frame: data with one numeric column per name in ``columns``.
        columns: column names in splom dimension order.
    """
    for row, y_name in enumerate(columns):
        for col, x_name in enumerate(columns):
            if row == col:
                # Diagonal cells plot a variable against itself (r is always 1).
                continue
            x = frame[x_name].to_numpy(dtype=float)
            y = frame[y_name].to_numpy(dtype=float)
            valid = np.isfinite(x) & np.isfinite(y)
            x, y = x[valid], y[valid]
            if x.size < 2 or x.min() == x.max() or y.min() == y.max():
                # Fit and correlation are undefined for too few or constant values.
                continue
            slope, intercept = np.polyfit(x, y, 1)
            r = np.corrcoef(x, y)[0, 1]
            sign = '+' if intercept >= 0 else '-'
            x_line = np.linspace(x.min(), x.max(), N_LINE_POINTS)
            fig.add_trace(go.Scatter(
                x=x_line,
                y=slope * x_line + intercept,
                mode='lines',
                line=dict(color='black', width=2),
                # Splom cell (row, col) is drawn on x axis `col` and y axis `row`;
                # plotly normalizes 'x1'/'y1' to 'x'/'y'.
                xaxis=f'x{col + 1}',
                yaxis=f'y{row + 1}',
                showlegend=False,
                hovertemplate=(
                    f'{y_name} vs {x_name}<br>'
                    f'r = {r:.3f}<br>'
                    f'y = {slope:.3g}x {sign} {abs(intercept):.3g}'
                    '<extra></extra>'
                ),
            ))


features, target, feature_names = load_data()

# Features plotted in the matrix, in splom dimension order.
plot_columns = feature_names[:N_FEATURES]

Data = pd.DataFrame(features[:, :len(plot_columns)], columns=plot_columns)
Data['Target'] = target.tolist()
index_vals = Data['Target'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
    dimensions=[dict(label=name, values=Data[name]) for name in plot_columns],
    text=Data['Target'],
    marker=dict(color=index_vals, showscale=False, size=8),
))
_add_regression_lines(fig, Data, plot_columns)
fig.update_layout(
    title='Wine Dataset',
    dragmode='select',
    width=900,
    height=600,
    hovermode='closest',
)
fig.show()