为QR更新实现伪代码-我的代码的哪一部分是错误的?

时间:2018-12-31 17:56:44

标签: python numpy linear qr-decomposition

我正在实现一些伪代码来更新QR分解-我想从矩阵A中删除一行,并相应地更新QR。 我的结果与scipy的qr_delete()产生的结果不同(请参见提供的代码段中的第21行),Q和R都不匹配。

我所指的更新的伪代码可以在here for algorithm 2.1here for algorithm 2.2中找到。

虽然我在代码中找不到错误。我猜我在将伪代码转换为Python的某个地方犯了一个错误-你们能够发现它吗?我们非常感谢您的帮助。

import numpy as np
from typing import Tuple
import math
import scipy.linalg._decomp_update as scipy_qr_update


# Generate a random matrix A and compute QR decomposition.
m, n = (15, 10)
A = np.random.rand(m, n)
x = np.ones((n, 1))
b = np.dot(A, x)
Q, R = scipy.linalg.qr(A)

# Compute update with scipy's qr_delete().
Q_tilde_corr, R_tilde_corr = scipy_qr_update.qr_delete(Q, R, k=0, p=m - m_tilde, which="row")

# Compute update with own implementation.
Q_tilde, R_tilde, b_tilde, residual = qr_delete_row(np.copy(Q), np.copy(R), np.copy(b), k=0)

# Check if scipy's and our results are equal.
print(np.allclose(Q_tilde_corr, Q_tilde), np.allclose(R_tilde_corr, R_tilde))

def qr_delete_row(Q: np.ndarray, R: np.ndarray, b: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.float]:
    """
    Updates Q and R after deleting one single row with index k.
    :param Q:
    :param R:
    :param b: b in Ax = b.
    :param k:
    :return:
    """

    m, n = Q.shape[0], R.shape[1]

    ###################################
    # Algorithm 2.1 - compute R_tilde.
    ###################################

    q_t = Q[k, :]
    q = q_t.T
    cs_values = np.zeros((m, 2))
    cs_values[0] = givens(q[0], q[1])

    if k != 0:
        b[1:k + 1] = b[0:k]

    d = Q.T @ b

    for j in np.arange(start=m - 2, step=-1, stop=-1):
        c, s = givens(q[j], q[j + 1])
        cs_values[j] = [c, s]
        cs_matrix = np.asarray([[c, s], [-s, c]])

        q[j] = c * q[j] - s * q[j + 1]

        if j <= n:
            R[j:j + 2, j:] = cs_matrix.T @ R[j:j + 2, j:]

        d[j: j + 2] = cs_matrix.T @ d[j:j + 2]

    R_tilde = R[1:, :]
    d_tilde = d[1:]
    resid = np.linalg.norm(d_tilde[n + 1:m], ord=2)

    ###################################
    # Algorithm 2.2 - compute Q_tilde.
    ###################################

    if k != 0:
        Q[1:k + 1, 1:] = Q[0:k, 0:]

    for j in np.arange(start=m - 2, step=-1, stop=-1):
        c, s = cs_values[j]
        cs_matrix = np.asarray([[c, s], [-s, c]])

        Q[1:, j:j + 2] = Q[1:m, j:j + 2] @ cs_matrix

    Q[1:, 1] = cs_values[0, 1] * Q[1:, 0] + cs_values[0, 1] * Q[1:, 1]

    return Q[1:, 1:], R_tilde, b[1:], resid

0 个答案:

没有答案