最小化具有代数约束和边界的最小二乘

时间:2018-08-29 23:25:03

标签: python python-3.x mathematical-optimization

我正在尝试根据某些矢量求和来最小化最小二乘之和。简而言之,我正在创建一个方程,该方程采用理想矢量,用确定的系数对它们进行加权,然后对加权后的矢量求和。一旦将该和与观察到的实际矢量测量值进行比较,就会出现最小二乘之和。

举个例子:

# Observation A has the following measurements:
A = [0, 4.1, 5.6, 8.9, 4.3]

# How similar is A to ideal groups identified by the following:
group1 = [1, 3, 5, 10, 3]
group2 = [6, 3, 2, 1, 10]
group3 = [3, 3, 4, 2, 1]

# Let y be the predicted measurement for A with coefficients s1, s2, and s3:
y = s1 * group1 + s2 * group2 + s3 * group3

# y will be some vector of length 5, similar to A
# Now find the sum of least squares between y and A
sum((y_i - A_i)** 2 for y_i in y for A_i in A)
  

必要的界限和约束

     

0 <= s1,s2,s3 <= 1

     

s1 + s2 + s3 = 1

     

y = s1 *组1 + s2 *组2 + s3 *组3

我想最小化y和A的最小平方和,以获得系数s1,s2,s3,但是我很难确定scipy.optimize的正确选择。似乎那里的最小二乘和最小化功能不能处理代数变量约束。我正在使用的数据是这些矢量化测量的数千个观测值。任何想法或想法将不胜感激!

1 个答案:

答案 0 :(得分:1)

对于您的情况,您可以像这样从scipy.optimize使用minimize()

minimize(fun=obj_fun, args=argtpl x0=xinit, bounds=bnds, constraints=cons)

其中obj_fun(x, *args)是目标函数,argtpl是目标函数的(可选)自变量的元组,xinit是起始点,bnds是目标函数的元组列表变量的边界和cons约束的字典列表。

import numpy as np
from scipy.optimize import minimize

# Observation A has the following measurements:
A = np.array([0, 4.1, 5.6, 8.9, 4.3])
# How similar is A to ideal groups identified by the following:
group1 = np.array([1, 3, 5, 10, 3])
group2 = np.array([6, 3, 2, 1, 10])
group3 = np.array([3, 3, 4, 2, 1])

# Define the objective function
# x is the array containing your wanted coefficients
def obj_fun(x, A, g1, g2, g3):
    y = x[0] * g1 + x[1] * g2 + x[2] * g3
    return np.sum((y-A)**2)

# Bounds for the coefficients
bnds = [(0, 1), (0, 1), (0, 1)]
# Constraint: x[0] + x[1] + x[2] - 1 = 0
cons = [{"type": "eq", "fun": lambda x: x[0] + x[1] + x[2] - 1}]

# Initial guess
xinit = np.array([1, 1, 1])
res = minimize(fun=obj_fun, args=(A, group1, group2, group3), x0=xinit, bounds=bnds, constraints=cons)
print(res.x)

您的示例解决方案:

array([9.25609756e-01, 7.43902439e-02, 6.24242179e-12])