GPU performance during feature extraction (Tesla K80)

Time: 2018-07-16 08:04:28

Tags: python-3.x keras gpu feature-extraction tesla

I am using the following code to extract features from about 4,000 images in about 30 classes.

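    import glob

    import numpy as np
    from keras.preprocessing import image

    # Defined earlier (not shown): model, preprocess_input from the model's
    # keras.applications module, train_path, train_labels, image_size.
    features = []
    labels = []

    for i, label in enumerate(train_labels):
        cur_path = train_path + "/" + label
        count = 1
        for image_path in glob.glob(cur_path + "/*.jpg"):
            # Load one image and turn it into a batch of size 1.
            img = image.load_img(image_path, target_size=image_size)
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)
            # Run the network and keep the flattened output as the feature vector.
            feature = model.predict(x)
            flat = feature.flatten()
            features.append(flat)
            labels.append(label)
            print("[INFO] processed - " + str(count))
            count += 1
        print("[INFO] completed label - " + label)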

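However, my full dataset is much larger, up to 80,000 images. Looking at GPU memory, this code works in Keras (2.1.2) for the 4,000 images, but it takes up almost all of the 5 GB of video RAM on my Tesla K80. I would like to know whether changing batch_size would improve performance, or whether the way this code works is too heavy for my GPU and I should rewrite it.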

Thanks!
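Regarding the batch_size question: in the loop above, each call to model.predict receives a single image (np.expand_dims builds a batch of one), so Keras's batch_size argument has only one sample to group. The sketch below is an illustration, not part of the original post: it stacks several images into one array per predict call. The helper extract_features_batched and its load_fn/predict_fn parameters are hypothetical names, and the demo uses dummy callables so it runs without Keras or real images.

```python
import numpy as np


def extract_features_batched(paths, load_fn, predict_fn, batch_size=32):
    """Extract one flattened feature vector per path, predicting in batches.

    load_fn(path) -> array of shape (H, W, C) for a single image (hypothetical).
    predict_fn(batch) -> model output whose first axis is the batch axis.
    """
    chunks = []
    for start in range(0, len(paths), batch_size):
        batch_paths = paths[start:start + batch_size]
        # Stack single images into one (n, H, W, C) array so the model sees
        # up to batch_size images per call instead of one.
        batch = np.stack([load_fn(p) for p in batch_paths])
        output = predict_fn(batch)
        # Flatten everything except the batch axis: one row per image.
        chunks.append(np.asarray(output).reshape(len(batch_paths), -1))
    if not chunks:
        return np.empty((0, 0))
    return np.concatenate(chunks, axis=0)


if __name__ == "__main__":
    # Dummy stand-ins so the sketch runs without Keras or real images.
    demo_paths = ["img_%d.jpg" % i for i in range(10)]
    feats = extract_features_batched(
        demo_paths,
        load_fn=lambda p: np.ones((4, 4, 3), dtype=np.float32),
        predict_fn=lambda b: b.sum(axis=(1, 2)),
        batch_size=4,
    )
    print(feats.shape)  # (10, 3)
```

With the question's setup, load_fn could be `lambda p: image.img_to_array(image.load_img(p, target_size=image_size))` and predict_fn `lambda b: model.predict(preprocess_input(b))`, assuming the keras.applications preprocess_input is used (it accepts a 4-D batch). Batching is expected to help mainly with throughput; larger batches need more host and GPU memory for the stacked inputs and activations, so batch_size trades speed against memory.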

1 answer:

Answer 0 (score: 1)

There are two possible solutions.

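1) I assume you are storing the images in NumPy array format. That is very memory-hungry. Instead, store them as a plain list and convert it to a NumPy array only when your application needs it. In my case this cut memory consumption by a factor of 10. If you are already storing them as a list, solution 2 may solve your problem.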

2) Store the results on disk in chunks, and use a generator when feeding them into another model.

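    import glob
    import pickle

    import numpy as np
    from keras.preprocessing import image

    # Defined earlier (not shown): model, preprocess_input, train_path,
    # train_labels, image_size.
    CHUNK_SIZE = 4000

    chunk_of_features = []
    chunk_of_labels = []
    chunk_index = 0  # counter that keeps chunk file names from colliding


    def save_chunk(features, labels, index):
        # Features and labels go to *separate* files; these names are examples.
        with open("features_chunk_%d.pkl" % index, "wb") as output_file:
            pickle.dump(features, output_file)
        with open("labels_chunk_%d.pkl" % index, "wb") as output_file:
            pickle.dump(labels, output_file)


    for label in train_labels:
        cur_path = train_path + "/" + label
        count = 1
        for image_path in glob.glob(cur_path + "/*.jpg"):
            img = image.load_img(image_path, target_size=image_size)
            x = image.img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)
            feature = model.predict(x)
            chunk_of_features.append(feature.flatten())
            chunk_of_labels.append(label)
            if len(chunk_of_features) == CHUNK_SIZE:
                # Write the full chunk to disk and start a fresh one.
                save_chunk(chunk_of_features, chunk_of_labels, chunk_index)
                chunk_index += 1
                chunk_of_features = []
                chunk_of_labels = []
            print("[INFO] processed - " + str(count))
            count += 1
        print("[INFO] completed label - " + label)

    # Save the last, partially filled chunk.
    if chunk_of_features:
        save_chunk(chunk_of_features, chunk_of_labels, chunk_index)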
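The chunked loop covers writing the features to disk but not the "use a generator" half of suggestion 2. A minimal sketch of the reading side, assuming each chunk was pickled as two separate files (one list of feature vectors, one list of labels) and that you pass in the list of file-name pairs yourself. chunk_generator and its loop flag are hypothetical names, not a Keras API. In the spirit of suggestion 1, each chunk stays a plain list on disk and becomes a NumPy array only when it is loaded.

```python
import pickle

import numpy as np


def chunk_generator(chunk_files, batch_size=32, loop=False):
    """Yield (features, labels) batches from pickled chunk files.

    chunk_files: list of (features_path, labels_path) pairs, each file holding
    one pickled list. Only one chunk is held in memory at a time.
    loop: if True, cycle over the chunks forever (hypothetical flag).
    """
    while True:
        for features_path, labels_path in chunk_files:
            with open(features_path, "rb") as f:
                # Convert the stored list to an array only now, when needed.
                features = np.asarray(pickle.load(f))
            with open(labels_path, "rb") as f:
                labels = np.asarray(pickle.load(f))
            # Slice the loaded chunk into batches of at most batch_size rows.
            for start in range(0, len(features), batch_size):
                yield (features[start:start + batch_size],
                       labels[start:start + batch_size])
        if not loop:
            return
```

Keras 2.x fit_generator expects a generator that runs indefinitely together with steps_per_epoch, hence loop=True in that case. The string labels would still need encoding (for example to one-hot vectors) before training.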