如何计算Tensorflow中的R ^ 2

时间:2017-02-20 17:48:07

标签: python tensorflow regression

我想在Tensorflow中做回归。我不正确我正在正确计算R ^ 2,因为Tensorflow给出了与sklearn.metrics.r2_score不同的答案。有人可以查看下面的代码并告诉我是否正确实现了图像方程式。感谢

The formula I am attempting to create in TF

total_error = tf.square(tf.sub(y, tf.reduce_mean(y)))
unexplained_error = tf.square(tf.sub(y, prediction))
R_squared = tf.reduce_mean(tf.sub(tf.div(unexplained_error, total_error), 1.0))
R = tf.mul(tf.sign(R_squared),tf.sqrt(tf.abs(R_squared)))

6 个答案:

答案 0 :(得分:6)

你在计算什么" R ^ 2"是

R^2_{\text{wrong}} = \operatorname{mean}_i \left( \frac{(y_i-\hat y_i)^2}{(y_i-\mu)^2} - 1\right)[1]

与给定的表达式相比,您在错误的位置计算均值。在进行除法之前,你应该在计算错误时采用均值。

total_error = tf.reduce_sum(tf.square(tf.sub(y, tf.reduce_mean(y))))
unexplained_error = tf.reduce_sum(tf.square(tf.sub(y, prediction)))
R_squared = tf.sub(1, tf.div(unexplained_error, total_error))

答案 1 :(得分:1)

我强烈建议您不要使用配方来计算!我发现的示例无法产生一致的结果,尤其是只有一个目标变量时。这让我非常头疼!

正确的做法是使用tensorflow_addons.metrics.RQsquare()。 Tensorflow附加组件为on PyPi here,文档为part of Tensorflow here。您所要做的就是将y_shape设置为输出的形状,对于单个输出变量,通常将其设置为(1,)

答案 2 :(得分:0)

实际上它应该与rhs相反。不明原因的方差除以总方差

答案 3 :(得分:0)

我认为应该是这样的:

total_error = tf.reduce_sum(tf.square(tf.sub(y, tf.reduce_mean(y))))
unexplained_error = tf.reduce_sum(tf.square(tf.sub(y, prediction)))
R_squared = tf.sub(1, tf.div(unexplained_error, total_error))

答案 4 :(得分:0)

函数被赋予here

<canvas id="canvas" width="300" height="300"></canvas><br>
<button  id="loadjson">loadfromjson </button>
<script src='https://www.multicastr.com/imageeditor/assets/js/fabric.unmin.js'></script>
<script src="https://www.multicastr.com/user/js/jquery.min.js"></script>

该概念在here中进行了说明。

答案 5 :(得分:0)

所有其他解决方案都无法为多维y产生正确的R平方得分。在TensorFlow中计算R2(方差加权)的正确方法是:

unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels, axis=0)))
R2 = 1. - tf.div(unexplained_error, total_error)

此TF代码段的结果与sklearn的结果完全匹配:

from sklearn.metrics import r2_score
R2 = r2_score(labels, predictions, multioutput='variance_weighted')