Tensorflow总结了给定的索引

时间:2018-03-16 06:27:35

标签: python tensorflow machine-learning

我试图在两个矩阵中找到差异,以便我可以定义我的损失函数。模型非常简单,我有一个输入矩阵和一个输出矩阵。

将X定义为输入矩阵

将Y定义为输出矩阵

通常我会做一个tf.reduce_mean(tf.abs(X-Y)),但这是不可能的,因为矩阵X包含纳米值

所以我想要做的是在False给出tf.is_nan(X)的地方加上X,然后我会在我添加X的相同索引处加起来Y.然后我将定义我的损失loss = tf.abs(reduce_nan_sum(X)-reduce_nan_sum(Y))

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(tf.is_nan(X), feed_dict={X: vals}))

[[False, False,  True],
 [False,  True, False]]

样本X和Y值

X = [[0.,  1.,  nan],
    [2.,  nan, 0.5]]

Y = [[0.002,  0.967,  0,2],
    [1.956,  0.3, 0.487]] 

1 个答案:

答案 0 :(得分:0)

NaNs可以用零替换以进行操作。使用与Replace nan values in tensorflow tensor

相同的解决方案
sess.run(tf.reduce_mean(tf.abs(tf.where(tf.is_nan(x), tf.zeros_like(x), x)-y)))