Tensorflow捕获奇异矩阵误差(tf.matrix_solve)

时间:2017-07-03 14:36:54

标签: tensorflow exception-handling

我有一个场景,我可以获得一个封闭形式的解决方案,其中涉及tf.matrix_solve。 Tensorflow非常智能,可以检查矩阵是否是单数,如果是这样的话会抛出异常。我怎么处理这个?以下示例不起作用。

import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()

def closed_form_solution(A,b):
    try:
        x = tf.matrix_solve(A, b)
    except Exception, e:
        x = tf.convert_to_tensor([[np.nan]] * 2)
    return x

# A and b are actually filled at each training step,
A = tf.convert_to_tensor(np.array([[1., 1.], [2., 2.]]))
b = tf.convert_to_tensor(np.array([[0.], [1.]]))
x = closed_form_solution(A,b)
x.eval()  # Throws InvalidArgumentError

以下功能似乎在大部分时间都有效,但由于我没有抓到任何东西,它仍然可以杀死我的训练。

def closed_form_solution(A,b):
    return tf.cond(tf.less_equal(tf.abs(tf.matrix_determinant(A)), 1e-5),
                   lambda : tf.convert_to_tensor([[np.nan]] * 2, dtype=tf.float64),
                   lambda : tf.matrix_solve(A, b))

有没有更好的解决方案来避免异常?

0 个答案:

没有答案