Tensorflow相当于Numpy的add.at()方法

时间:2017-04-26 15:56:36

标签: python numpy tensorflow

我计算多标签分类任务的混淆矩阵。代码如下所示:

def compute_key_metrics(sess,X_data,y_data):
    confusion_matrix = np.zeros([n_classes,n_classes],np.int32)
    n_examples = X_data.shape[0]

    for offset in range(0,n_examples,batch_size):
    X_batch = X_data[offset:offset+batch_size]
    y_batch = y_data[offset:offset+batch_size]

    y_predicted = sess.run(tf.arg_max(logits,1),feed_dict={X:X_batch, y:y_batch})
    is_prediction_correct = sess.run(correct_predictions,feed_dict={X:X_batch, y:y_batch})
    np.add.at(confusion_matrix,[y_batch,y_predicted],is_prediction_correct.astype(np.int32))

return confusion_matrix

我使用numpy.add.at()更新混淆矩阵的条目,而不显式循环预测标签。 python实现很慢。 Tensorflow中是否有类似的方法?

0 个答案:

没有答案