我遇到的问题是,无论我传递什么标签/预测TF.Metrics.Mean_Squared_Error,它总是返回0值。
以下是重复问题的代码:
a = tf.constant([0,0,0,0])
b = tf.constant([1,1,1,1])
mse, update = tf.metrics.mean_squared_error(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
mse.eval(session=sess)
%%返回0.0
答案 0 :(得分:1)
我真的不知道为什么会这样,但实际上你需要在mse的内部状态考虑你的数据之前运行update
:
a = tf.constant([0,0,0,0])
b = tf.constant([1,1,1,1])
mse, update = tf.metrics.mean_squared_error(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
mse.eval(session=sess) # Gives 0.0, the initial MSE value
update.eval(session=sess) # Gives 1.0, the value of the update for the mse
mse.eval(session=sess) # Gives 1.0, which is 0.0+1.0, the updated mse value
例如, tf.metrics.mean_squared_error()
用于计算整个数据集上的MSE,因此如果您希望独立批处理结果,则不应使用它。为此,例如使用tf.losses.mean_squared_error(a, b, loss_collection=None)
。