所以我目前正在使用matplotlib中的许多x
和y
绘制散点图:
plt.scatter(x, y)
我想在这个散点图上绘制一条穿过整个图形的线(即击中两个'边框')我知道渐变和截距 - 方程中的m
和c
y = mx +c
。
我已经考虑过获取绘图的4个点(计算最小和最大散点x
和y
s)并从中计算线的最小和最大坐标然后绘制但是这看起来很复杂。有没有更好的方法来记住这条线甚至可能不在“情节”内?
如图中所示,四个边界坐标是粗糙的:
-1,-2
-1,2
6,-2
6,2
我现在有一条线,我需要绘制,不得超过这些边界,但如果它进入绘图必须触及两个边界点。
所以我可以检查当x = -1时y等于什么,然后检查该值是否介于-1和6之间,如果该行必须越过左边界,那么绘制它,依此类推第四个。
理想情况下,我会创建一条从-infinity到infinity的线,然后裁剪它以适应情节。
答案 0 :(得分:4)
这里的想法是在图中绘制一些等式y=m*x+y0
的线。这可以通过将最初在轴坐标中给出的水平线转换为数据坐标,根据线方程应用Affine2D变换并转换回屏幕坐标来实现。
这里的优点是您根本不需要知道轴限制。你也可以随意缩放或平移你的情节;该线将始终保持在轴边界内。因此,它有效地实现了从-infinity到+ inifinty的范围。
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
def axaline(m,y0, ax=None, **kwargs):
if not ax:
ax = plt.gca()
tr = mtransforms.BboxTransformTo(
mtransforms.TransformedBbox(ax.viewLim, ax.transScale)) + \
ax.transScale.inverted()
aff = mtransforms.Affine2D.from_values(1,m,0,0,0,y0)
trinv = ax.transData
line = plt.Line2D([0,1],[0,0],transform=tr+aff+trinv, **kwargs)
ax.add_line(line)
x = np.random.rand(20)*6-0.7
y = (np.random.rand(20)-.5)*4
c = (x > 3).astype(int)
fig, ax = plt.subplots()
ax.scatter(x,y, c=c, cmap="bwr")
# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
axaline(m,y0, ax=ax, color="limegreen", linewidth=5)
plt.show()
虽然这个解决方案在第一眼看上去有点复杂,但人们不需要完全理解它。只需将axaline
函数复制到您的代码中,然后按原样使用它。
<小时/> 为了在没有变换的情况下使自动更新工作,可以添加回调,这将在每次图中的某些变化时重置变换。
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import transforms
class axaline():
def __init__(self, m,y0, ax=None, **kwargs):
if not ax: ax = plt.gca()
self.ax = ax
self.aff = transforms.Affine2D.from_values(1,m,0,0,0,y0)
self.line = plt.Line2D([0,1],[0,0], **kwargs)
self.update()
self.ax.add_line(self.line)
self.ax.callbacks.connect('xlim_changed', self.update)
self.ax.callbacks.connect('ylim_changed', self.update)
def update(self, evt=None):
tr = ax.transAxes - ax.transData
trinv = ax.transData
self.line.set_transform(tr+self.aff+trinv)
x = np.random.rand(20)*6-0.7
y = (np.random.rand(20)-.5)*4
c = (x > 3).astype(int)
fig, ax = plt.subplots()
ax.scatter(x,y, c=c, cmap="bwr")
# draw y=m*x+y0 into the plot
m = 0.4; y0 = -1
al = axaline(m,y0, ax=ax, color="limegreen", linewidth=5)
plt.show()
答案 1 :(得分:2)
您可以尝试:
import matplotlib.pyplot as plt
import numpy as np
m=3
c=-2
x1Data= np.random.normal(scale=2, loc=.4, size=25)
y1Data= np.random.normal(scale=3, loc=1.2, size=25)
x2Data= np.random.normal(scale=1, loc=3.4, size=25)
y2Data= np.random.normal(scale=.65, loc=-.2, size=25)
fig = plt.figure()
ax = fig.add_subplot( 1, 1, 1 )
ax.scatter(x1Data, y1Data)
ax.scatter(x2Data, y2Data)
ylim = ax.get_ylim()
xlim = ax.get_xlim()
ax.plot( xlim, [ m * x + c for x in xlim ], 'r:' )
ax.set_ylim( ylim )
ax.set_xlim( xlim )
plt.show()
给出: