我计算多标签分类任务的混淆矩阵。代码如下所示:
def compute_key_metrics(sess,X_data,y_data):
confusion_matrix = np.zeros([n_classes,n_classes],np.int32)
n_examples = X_data.shape[0]
for offset in range(0,n_examples,batch_size):
X_batch = X_data[offset:offset+batch_size]
y_batch = y_data[offset:offset+batch_size]
y_predicted = sess.run(tf.arg_max(logits,1),feed_dict={X:X_batch, y:y_batch})
is_prediction_correct = sess.run(correct_predictions,feed_dict={X:X_batch, y:y_batch})
np.add.at(confusion_matrix,[y_batch,y_predicted],is_prediction_correct.astype(np.int32))
return confusion_matrix
我使用numpy.add.at()更新混淆矩阵的条目,而不显式循环预测标签。 python实现很慢。 Tensorflow中是否有类似的方法?