我只是试图绘制两个高斯人并找到交叉点。我有以下代码。它没有绘制确切的交叉点,但我真的无法弄清楚原因。它只是勉强稍微关闭但是如果我们采用减去的高斯的记录并且看起来它应该是正确的,我就完成了派生的解决方案。有人可以帮忙吗?非常感谢你!
import numpy as np
import matplotlib.pyplot as plt
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
# found online
def solve_gasussians(m1, s1, m2, s2):
a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2)
b = m2/(s2**2) - m1/(s1**2)
c = m1**2 /(2*s1**2) - m2**2 / (2.0*s2**2) - np.log(s2/s1)
return np.roots([a,b,c])
s1 = np.linspace(0, 10,300)
s2 = np.linspace(0, 14, 300)
solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0)
print solved_val
solved_val = solved_val[0]
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1')
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2')
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo')
plt.legend()
plt.show()
答案 0 :(得分:0)
我不知道你的代码中的错误在哪里。但我想我找到了你借来的代码并参与了你需要的调整。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
def solve(m1,m2,std1,std2):
a = 1/(2*std1**2) - 1/(2*std2**2)
b = m2/(std2**2) - m1/(std1**2)
c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
return np.roots([a,b,c])
m1 = 5
std1 = 0.5
m2 = 7
std2 = 1
result = solve(m1,m2,std1,std2)
x = np.linspace(-5,9,10000)
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x])
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x])
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o')
plt.show()
我会提供两条不请自来的建议,这些建议可能会让你的生活更轻松(就像他们为我做的那样):
答案 1 :(得分:0)
错误就在这里。这一行:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
应该是这样的:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
您忘记了sqrt
。
如果可以使用预先存在的正常pdf,那将更明智,例如:
import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)
也可以准确地解决交叉路口。 This answer提供了高斯根的二次方程式。交叉点。使用maxima求解x得到以下表达式。虽然复杂,但它不依赖于迭代方法,可以从更简单的表达式自动生成。
def solve_gaussians(m1,s1,m2,s2):
x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
return x1,x2
完全放弃:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)
#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x
def solve_gaussians(m1,s1,m2,s2):
x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
return x1,x2
s = np.linspace(0, 14,300)
x = solve_gaussians(5.0,0.5,7.0,1.0)
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1')
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2')
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo')
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo')
plt.legend()
plt.show()
,并提供:
答案 2 :(得分:0)
plot_normal
函数中有一个小错误 - 你在分母中缺少平方根。正确版本:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
还有两个评论。
据我所知np.roots
为您提供了近似结果,但您可以轻松获得准确结果,将solve_gasussians
函数重写为:
def solve_gasussians(m1, s1, m2, s2):
# coefficients of quadratic equation ax^2 + bx + c = 0
a = (s1**2.0) - (s2**2.0)
b = 2 * (m1 * s2**2.0 - m2 * s1**2.0)
c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2)
x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
return x1, x2