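I'm trying to solve a matrix factorization problem using SciPy's nonlinear conjugate gradient implementation, `scipy.optimize.fmin_cg`.

My problem tries to find A and B such that

`A @ np.transpose(B) = Y`

while minimizing:

`(1/N) * absolute_value_norm(Y - AB) + regularisation_params`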
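Written out explicitly, the objective computed by `f` below is the following. This is a sketch under stated assumptions: Ω is the set of observed entries selected by `mask`, N = |Ω| = `mask.nnz`, λ is `l`, and both penalties are squared Frobenius norms (which is what `np.square(np.linalg.norm(...))` gives). Because |r| is not differentiable at r = 0, only a subgradient can be written down; one choice, with R the residual restricted to Ω (zero elsewhere) and sign(0) = 0, is:

```latex
\begin{aligned}
E(A,B) &= \frac{1}{N}\sum_{(i,j)\in\Omega}\bigl|Y_{ij} - (AB^\top)_{ij}\bigr|
          + \lambda\lVert A\rVert_F^2 + \lambda\lVert B\rVert_F^2 \\[4pt]
R_{ij} &= \begin{cases} Y_{ij} - (AB^\top)_{ij} & (i,j)\in\Omega \\ 0 & \text{otherwise} \end{cases} \\[4pt]
\frac{\partial E}{\partial A} &= -\frac{1}{N}\,\operatorname{sign}(R)\,B + 2\lambda A, \qquad
\frac{\partial E}{\partial B} = -\frac{1}{N}\,\operatorname{sign}(R)^\top A + 2\lambda B
\end{aligned}
```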
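`fmin_cg` takes the objective function and its gradient as the `f` and `fprime` parameters, so here is my input. The objective function `f`: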
```python
import numpy as np
from scipy.sparse import csr_matrix


def f(x, *args):
    """Objective: masked mean absolute error plus L2 (Frobenius) penalties."""
    A_shape, B_shape, matrix, mask, l = args
    # Unpack the flat parameter vector: first all of A, then all of B.
    A_cutoff = int(A_shape[0] * A_shape[1])
    A = np.reshape(x[0:A_cutoff], (A_shape[0], A_shape[1]))
    B = np.reshape(x[A_cutoff:len(x)], (B_shape[0], B_shape[1]))
    # Residual Y - A @ B^T, kept only on the observed (masked) entries.
    container = csr_matrix(matrix.shape)
    product_AB = np.matmul(A, np.transpose(B))
    product_AB_csr = csr_matrix(product_AB)
    container[mask] = matrix[mask] - product_AB_csr[mask]
    error = ((1 / mask.nnz) * abs(container).sum()) + (l * np.square(np.linalg.norm(A))) + \
        (l * np.square(np.linalg.norm(B)))
    return error
```
And the gradient, passed as `fprime`:
```python
def fprime(x, *args):
    """Gradient of f with respect to the flattened (A, B) vector."""
    A_shape, B_shape, matrix, mask, l = args
    A_cutoff = (A_shape[0] * A_shape[1])
    A = np.reshape(x[0:A_cutoff], (A_shape[0], A_shape[1]))
    B = np.reshape(x[A_cutoff:len(x)], (B_shape[0], B_shape[1]))
    # Same masked residual as in f.
    container = csr_matrix(matrix.shape)
    product_AB = np.matmul(A, np.transpose(B))
    product_AB_csr = csr_matrix(product_AB)
    container[mask] = matrix[mask] - product_AB_csr[mask]
    original = - container / mask.nnz
    # get_abs_gradient is defined elsewhere (not shown here).
    grad_A, grad_B = get_abs_gradient(A, B, l, original)
    # return np.asarray((grad_A, grad_B))
    return np.concatenate((grad_A.flatten(), grad_B.flatten()))
```
To run the whole pipeline, I use:
```python
import numpy as np
from scipy import optimize, sparse
from scipy.sparse import csr_matrix

import matrix_factorizer  # my own module providing flatten_matrices (not shown)

train_matrix = sparse.load_npz("data.npz")
mask = csr_matrix(train_matrix > 0)
rank = 10
l = 0.001
A = np.random.random((train_matrix.shape[0], rank))
B = np.random.random((train_matrix.shape[1], rank))
matrices_flat = matrix_factorizer.flatten_matrices(A, B)
optimal = optimize.fmin_cg(f, matrices_flat,
                           fprime=fprime,
                           args=(A.shape, B.shape, train_matrix, mask, l),
                           full_output=True)
```
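`matrix_factorizer.flatten_matrices` isn't shown, but the slicing in `f` and `fprime` only works if it lays out all of A first (row-major), followed by all of B. A hypothetical equivalent, together with the inverse that `f` and `fprime` perform, might look like this sketch (both function names and the small shapes are illustrative assumptions):

```python
import numpy as np


def flatten_matrices(A, B):
    """Hypothetical stand-in for matrix_factorizer.flatten_matrices:
    row-major flattening of A, followed by row-major flattening of B."""
    return np.concatenate((A.ravel(), B.ravel()))


def unflatten_matrices(x, A_shape, B_shape):
    """Hypothetical inverse, mirroring the slicing/reshaping in f and fprime."""
    A_cutoff = A_shape[0] * A_shape[1]
    A = x[:A_cutoff].reshape(A_shape)
    B = x[A_cutoff:].reshape(B_shape)
    return A, B


# Round trip on small factors: A is 4x3, B is 5x3 (rank 3).
rng = np.random.default_rng(0)
A = rng.random((4, 3))
B = rng.random((5, 3))
x = flatten_matrices(A, B)
A2, B2 = unflatten_matrices(x, A.shape, B.shape)
print(x.shape)  # (27,)
print(np.array_equal(A, A2) and np.array_equal(B, B2))  # True
```

Because `np.reshape` and `ravel` both default to C (row-major) order, the round trip is exact; mixing orders would silently scramble the factors.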
However, I suspect that my gradient function for the sum of absolute values may be incorrect, because the output contains a warning about "precision loss":

```
Warning: Desired error not necessarily achieved due to precision loss.
Current function value: 0.193586
Iterations: 5
Function evaluations: 115
Gradient evaluations: 108
```
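One way to test the suspicion about the gradient is `scipy.optimize.check_grad`, which returns the 2-norm of the difference between an analytic gradient and a finite-difference estimate. The sketch below uses a small dense stand-in for `f`; the names `objective`, `subgradient` and `residual_gradient`, the shapes, and the random mask are all hypothetical, not the question's code or data. The sign-based subgradient should give a tiny discrepancy, while, for contrast, a gradient built from the raw residual (the gradient of a squared loss) should not:

```python
import numpy as np
from scipy.optimize import check_grad


def unpack(x, A_shape, B_shape):
    cut = A_shape[0] * A_shape[1]
    return x[:cut].reshape(A_shape), x[cut:].reshape(B_shape)


def objective(x, A_shape, B_shape, Y, M, l):
    # Dense analogue of f: masked mean absolute error + squared Frobenius penalties.
    A, B = unpack(x, A_shape, B_shape)
    R = M * (Y - A @ B.T)   # residual, zeroed on unobserved entries
    N = M.sum()             # number of observed entries (mask.nnz in the question)
    return np.abs(R).sum() / N + l * np.sum(A**2) + l * np.sum(B**2)


def subgradient(x, A_shape, B_shape, Y, M, l):
    # Uses sign(R), the (sub)derivative of |r|, with sign(0) = 0.
    A, B = unpack(x, A_shape, B_shape)
    R = M * (Y - A @ B.T)
    N = M.sum()
    S = np.sign(R)
    grad_A = -(S @ B) / N + 2 * l * A
    grad_B = -(S.T @ A) / N + 2 * l * B
    return np.concatenate((grad_A.ravel(), grad_B.ravel()))


def residual_gradient(x, A_shape, B_shape, Y, M, l):
    # For contrast only: uses R instead of sign(R), i.e. the gradient of a squared loss.
    A, B = unpack(x, A_shape, B_shape)
    R = M * (Y - A @ B.T)
    N = M.sum()
    grad_A = -(R @ B) / N + 2 * l * A
    grad_B = -(R.T @ A) / N + 2 * l * B
    return np.concatenate((grad_A.ravel(), grad_B.ravel()))


rng = np.random.default_rng(0)
A_shape, B_shape = (6, 2), (5, 2)
Y = rng.random((6, 5))
M = (rng.random((6, 5)) > 0.3).astype(float)  # hypothetical observation mask
x0 = rng.random(A_shape[0] * A_shape[1] + B_shape[0] * B_shape[1])
args = (A_shape, B_shape, Y, M, 0.001)

err = check_grad(objective, subgradient, x0, *args)
err_wrong = check_grad(objective, residual_gradient, x0, *args)
print(err, err_wrong)
```

Running the same check on the real `f` and `fprime` with a small slice of the data would show whether `get_abs_gradient` matches `f`. Keep in mind that the absolute value is not differentiable where a residual is exactly zero, so `fmin_cg`, which assumes a smooth objective, may still struggle even with a correct subgradient.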
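I found a very similar question here: fmin_cg: Desired error not necessarily achieved due to precision loss

However, my input data is already normalized, so the solution there doesn't apply. Can anyone guess what the problem is?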