Tensorflow在多数投票中表现缓慢

时间:2018-08-23 18:25:29

标签: python tensorflow machine-learning

我在Tensorflow 1.10上实现了多数表决算法(对不同分类器进行计数预测),并且预测1000个数据集(MNIST)的速度非常慢(10个分类器需要3个小时以上)。根据我的猜测,这是因为我的代码经常调用session.run(),但是我该如何优化它呢?

def majority_voting(session, x, y):
    votes = []
    for i in range(number_of_ensemble_modules):
        # run the training
        feature_extractor = iterators[i][3]
        input, label = feature_extractor(x, y)
        transformed_x = session.run(input)
        ensemble_prediction = nn_models[0][i][0][3]
        prediction = session.run(ensemble_prediction, feed_dict={X: transformed_x, Y: y})
        votes.append(prediction[0])
    nearest_k_y, idx, vote = tf.unique_with_counts(tf.convert_to_tensor(votes, tf.int64))
    majority = tf.argmax(vote)
    predict_res = tf.gather(nearest_k_y, majority)
    return predict_res


def calculate_ensemble_accuracy():
    accuracy = 0
    for j in range(voting_iterations):
        accuracy += 0
        (features, labels) = session.run(next_element)
        vote = majority_voting(session, features, labels)
        correct_label = session.run(tf.argmax(labels, axis=1))
        if vote == correct_label[0]:
            accuracy += 1
    return accuracy

1 个答案:

答案 0 :(得分:0)

一些提示可能会解决您的问题。

1-在创建tensorflow graph之前进行特征提取。例如,如果创建TfIDF特征向量,则可以在预处理步骤中进行操作并保存numpy以供图形输入。

input, label = feature_extractor(x, y)

2-删除不必要的session.run()。例如,当您调用Optimizer时,它会自动调用x_transformed。

transformed_x = session.run(input)

3-以更好的方式使用tf.dataDataset API)。无需调用sess.run(next_element)。因为next_element是图形的一部分。