TensorFlow KMeansClustering trains on the CPU instead of the GPU

Asked: 2018-07-18 21:46:55

Tags: python tensorflow

I am currently running tensorflow-gpu 1.8 with CUDA 9.0. When I train the model below, I notice that my CPU usage is at 99% while my Tesla M60 usage is at 0%.

import tensorflow as tf

tf.test.gpu_device_name()

def train(tfidf_matrix, num_clusters):
    # Convert sparse matrix to dense matrix
    points = tfidf_matrix.todense()

    tf.device('/gpu:0')

    def input_fn():
        return tf.train.limit_epochs(tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1)

    kmeans = tf.contrib.factorization.KMeansClustering(num_clusters=num_clusters, use_mini_batch=False)
    num_iterations = 100

    kmeans.train(input_fn=input_fn, steps=num_iterations)
    print('score:', kmeans.score(input_fn))

    # map the input points to their clusters
    cluster_indices = list(kmeans.predict_cluster_index(input_fn))

    return kmeans, cluster_indices
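
One detail worth noting about the bare `tf.device('/gpu:0')` call inside `train`: in TensorFlow 1.x, `tf.device` returns a context manager, so calling it without `with` does not change device placement. The sketch below illustrates that behaviour in plain Python. The `device` and `active_device` helpers are hypothetical stand-ins written for this example, not TensorFlow APIs.

```python
from contextlib import contextmanager

# Hypothetical stand-in for a device scope; this is NOT TensorFlow code.
_current_device = ["/cpu:0"]


@contextmanager
def device(name):
    """Make `name` the active device only while inside a `with` block."""
    previous = _current_device[0]
    _current_device[0] = name
    try:
        yield
    finally:
        _current_device[0] = previous


def active_device():
    """Return the device that is currently in effect."""
    return _current_device[0]


# Bare call: the context manager object is created and then discarded,
# so the scope is never entered and nothing changes.
device("/gpu:0")
after_bare_call = active_device()

# Used with `with`: the scope is active inside the block only.
with device("/gpu:0"):
    inside_with_block = active_device()
after_with_block = active_device()

print(after_bare_call, inside_with_block, after_with_block)
```

The same pattern suggests that the bare call in `train` is likely a no-op, although wrapping code in `with tf.device(...)` does not by itself guarantee that every op ends up on the GPU.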

I can't seem to get the model to run on the GPU. I tried this:

with tf.device('/gpu:0'):
    kmeans.train(input_fn)

However, it still runs on the CPU. When I run tf.test.gpu_device_name(), I get:

2018-07-18 15:28:47.105452: I c:\users\user\source\repos\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1392] Found device 0 with properties:
name: Tesla M60 major: 5 minor: 2 memoryClockRate(GHz): 1.1775
pciBusID: 0000:00:1e.0
totalMemory: 7.44GiB freeMemory: 7.12GiB
2018-07-18 15:28:47.105846: I c:\users\user\source\repos\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1471] Adding visible gpu devices: 0
2018-07-18 15:28:47.852029: I c:\users\user\source\repos\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:952] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-07-18 15:28:47.852268: I c:\users\user\source\repos\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:958] 0
2018-07-18 15:28:47.852440: I c:\users\user\source\repos\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:971] 0:   N
2018-07-18 15:28:47.852859: I c:\users\user\source\repos\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1084] Created TensorFlow device (/device:GPU:0 with 6874 MB memory) -> physical GPU (device: 0, name: Tesla M60, pci bus id: 0000:00:1e.0, compute capability: 5.2)
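
The final log entry indicates that TensorFlow detected the Tesla M60 and registered it as `/device:GPU:0`, so the GPU is visible to the process even though training does not use it. If that check needs to happen in a script rather than by eye, a small sketch like the following could extract the fields from such a line. The `parse_created_device` helper and the regular expression are assumptions written for this example, and they only match the message format shown in this log.

```python
import re

# The "Created TensorFlow device" entry from the log, copied as a raw string
# so the Windows path backslashes are kept literally.
log_line = (
    r"2018-07-18 15:28:47.852859: I c:\users\user\source\repos\tensorflow"
    r"\tensorflow\core\common_runtime\gpu\gpu_device.cc:1084] "
    "Created TensorFlow device (/device:GPU:0 with 6874 MB memory) -> "
    "physical GPU (device: 0, name: Tesla M60, pci bus id: 0000:00:1e.0, "
    "compute capability: 5.2)"
)

# Assumed pattern for this specific message format; other TensorFlow
# versions may word the message differently.
DEVICE_RE = re.compile(
    r"Created TensorFlow device \((?P<tf_device>\S+) with (?P<memory_mb>\d+) MB memory\)"
    r" -> physical GPU \(device: (?P<index>\d+), name: (?P<name>[^,]+), "
    r"pci bus id: (?P<pci_bus_id>[^,]+), compute capability: (?P<capability>[\d.]+)\)"
)


def parse_created_device(line):
    """Return the device fields as a dict of strings, or None if no match."""
    match = DEVICE_RE.search(line)
    return match.groupdict() if match else None


info = parse_created_device(log_line)
print(info)
```

A non-`None` result only confirms that the device was registered; it says nothing about where individual ops are placed during training.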

0 answers:

No answers yet