如何使用Tensorflow自定义错误消息

时间:2017-06-06 15:02:09

标签: python tensorflow

我正在尝试使用TensorFlow矩阵显示自定义消息错误,当矩阵的行列式等于0时,无法计算逆,但无法使用我的函数显示消息错误。我的代码结构如下:

import tensorflow as tf
def inversematricx(arg):
    args = tf.convert_to_tensor(arg, dtype=tf.float32)
    try:
        return tf.matrix_inverse(args)
    except:
        raise ValueError("Determinant is 0. Input is not invertible")

mat1=tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) # Determinant is 0 for mat1
mat2=tf.constant([[1.0, 2.0, 4.0], [3.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

inverse=inversematricx(mat1)

with tf.Session() as sess: 

    result = sess.run(inverse)    
    print(result)

mat2

的结果
  

[[0.17647055 -0.82352936 0.47058824] [-0.88235289 1.11764693   -0.35294116] [0.64705878 -0.35294113 0.05882351]]

但对于决定因素等于0的mat1,我想强制输出 对于ValueError消息,而不是生成的错误:

InvalidArgumentError: Input is not invertible.
     [[Node: MatrixInverse_21 = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/cpu:0"](Const_69)]]

Caused by op 'MatrixInverse_21', defined at:
    File "D:\WinPython-64bit-3.5.3.0Qt5\python-3.5.3.amd64\lib\site-packages\spyder\utils\ipython\start_kernel.py", line 227, in <module>
main()
....
InvalidArgumentError (see above for traceback): Input is not invertible.
 [[Node: MatrixInverse_21 = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/cpu:0"](Const_69)]]

1 个答案:

答案 0 :(得分:0)

我在自定义函数上找到了tf.Print的解决方案,如下所示:

sess = tf.InteractiveSession()
def checkMatrixInverse(arg):
    f=tf.matrix_determinant(arg).eval() #get determinant value
    args = tf.convert_to_tensor(arg, dtype=tf.float32)
    inv=tf.matrix_inverse(args) 
    err='Input is not invertible:'    
    if(f==0):
        return tf.Print(err,[err], name="NotInvertible")
    else:
        return tf.Print(inv, [inv], name="Inverse")

noninverse=checkMatrixInverse(mat1) #output b'Input is not invertible:' 

inverse=checkMatrixInverse(mat2) 
#output:
#[[ 0.17647055 -0.82352936  0.47058824]
# [-0.88235289  1.11764693 -0.35294116]
# [ 0.64705878 -0.35294113  0.05882351]]