我有一个场景,我可以获得一个封闭形式的解决方案,其中涉及tf.matrix_solve。 Tensorflow非常智能,可以检查矩阵是否是单数,如果是这样的话会抛出异常。我怎么处理这个?以下示例不起作用。
import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
def closed_form_solution(A,b):
try:
x = tf.matrix_solve(A, b)
except Exception, e:
x = tf.convert_to_tensor([[np.nan]] * 2)
return x
# A and b are actually filled at each training step,
A = tf.convert_to_tensor(np.array([[1., 1.], [2., 2.]]))
b = tf.convert_to_tensor(np.array([[0.], [1.]]))
x = closed_form_solution(A,b)
x.eval() # Throws InvalidArgumentError
以下功能似乎在大部分时间都有效,但由于我没有抓到任何东西,它仍然可以杀死我的训练。
def closed_form_solution(A,b):
return tf.cond(tf.less_equal(tf.abs(tf.matrix_determinant(A)), 1e-5),
lambda : tf.convert_to_tensor([[np.nan]] * 2, dtype=tf.float64),
lambda : tf.matrix_solve(A, b))
有没有更好的解决方案来避免异常?