高斯之间的交集

时间:2016-12-28 19:40:39

标签: python gaussian

我只是试图绘制两个高斯人并找到交叉点。我有以下代码。它没有绘制确切的交叉点,但我真的无法弄清楚原因。它只是勉强稍微关闭但是如果我们采用减去的高斯的记录并且看起来它应该是正确的,我就完成了派生的解决方案。有人可以帮忙吗?非常感谢你!

import numpy as np 
import matplotlib.pyplot as plt 

def plot_normal(x, mean = 0, sigma = 1):
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

# found online
def solve_gasussians(m1, s1, m2, s2):
  a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2)
  b = m2/(s2**2) - m1/(s1**2)
  c = m1**2 /(2*s1**2) - m2**2 / (2.0*s2**2) - np.log(s2/s1)
  return np.roots([a,b,c])

s1 = np.linspace(0, 10,300)
s2 = np.linspace(0, 14, 300)

solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0)
print solved_val
solved_val = solved_val[0]
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1')
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2')
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo')
plt.legend()
plt.show()

3 个答案:

答案 0 :(得分:0)

我不知道你的代码中的错误在哪里。但我想我找到了你借来的代码并参与了你需要的调整。

import numpy as np 
import matplotlib.pyplot as plt 
from scipy.stats import norm

def solve(m1,m2,std1,std2):
  a = 1/(2*std1**2) - 1/(2*std2**2)
  b = m2/(std2**2) - m1/(std1**2)
  c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
  return np.roots([a,b,c])

m1 = 5
std1 = 0.5
m2 = 7
std2 = 1

result = solve(m1,m2,std1,std2)

x = np.linspace(-5,9,10000)
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x])
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x])
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o')

plt.show()

我会提供两条不请自来的建议,这些建议可能会让你的生活更轻松(就像他们为我做的那样):

  • 当您调整代码时,尝试进行小的增量更改,并检查代码是否仍然适用于每一步。
  • 寻找现有的免费图书馆。在这种情况下,来自scipy的 norm 可以替代原始代码中使用的内容。

答案 1 :(得分:0)

错误就在这里。这一行:

def plot_normal(x, mean = 0, sigma = 1):
  return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

应该是这样的:

def plot_normal(x, mean = 0, sigma = 1):
  return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

您忘记了sqrt

如果可以使用预先存在的正常pdf,那将更明智,例如:

import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
  return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)

也可以准确地解决交叉路口。 This answer提供了高斯根的二次方程式。交叉点。使用maxima求解x得到以下表达式。虽然复杂,但它不依赖于迭代方法,可以从更简单的表达式自动生成。

def solve_gaussians(m1,s1,m2,s2):
  x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
  x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
  return x1,x2

完全放弃:

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.stats

def plot_normal(x, mean = 0, sigma = 1):
  return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)

#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x
def solve_gaussians(m1,s1,m2,s2):
  x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
  x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
  return x1,x2

s = np.linspace(0, 14,300)
x = solve_gaussians(5.0,0.5,7.0,1.0)

plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1')
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2')
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo')
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo')
plt.legend()
plt.show()

,并提供:

Intersection of Gaussians

答案 2 :(得分:0)

plot_normal函数中有一个小错误 - 你在分母中缺少平方根。正确版本:

def plot_normal(x, mean = 0, sigma = 1):
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

给出了预期的结果: enter image description here

还有两个评论。

  1. 请记住,一般来说,你可以有2个方程根(两个交点),这就是你提供的参数的情况。
  2. 据我所知np.roots为您提供了近似结果,但您可以轻松获得准确结果,将solve_gasussians函数重写为:

    def solve_gasussians(m1, s1, m2, s2):
        # coefficients of quadratic equation ax^2 + bx + c = 0
        a = (s1**2.0) - (s2**2.0)
        b = 2 * (m1 * s2**2.0 - m2 * s1**2.0)
        c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2)
        x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
        x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
        return x1, x2