Multi-GPU inference in TensorFlow

Date: 2019-07-10 15:37:49

Tags: tensorflow keras

I want to perform multi-GPU inference with TensorFlow / Keras.

Here is my prediction code:

 # MODEL_DIR, COCO_MODEL_PATH, IMAGE_DIR and config (an InferenceConfig) are
 # assumed to be defined earlier, as in the Mask_RCNN demo notebook.
 import os
 import random
 import skimage.io

 import mrcnn.model as modellib
 from mrcnn import visualize

 # Create the model in inference mode
 model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

 # Load weights trained on MS-COCO
 model.load_weights(COCO_MODEL_PATH, by_name=True)

 # COCO Class names
 # Index of the class in the list is its ID. For example, to get ID of
 # the teddy bear class, use: class_names.index('teddy bear')
 class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
                'bus', 'train', 'truck', 'boat', 'traffic light',
                'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
                'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
                'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
                'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush']


 # Load a random image from the images folder
 file_names = next(os.walk(IMAGE_DIR))[2]
 image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))

 # Run detection
 results = model.detect([image], verbose=1)

 # Visualize results
 r = results[0]
 visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                             class_names, r['scores'])

Is it possible to run this model on multiple GPUs?

Thanks.

1 Answer:

Answer 0 (score: 1)

Increase GPU_COUNT according to the number of GPUs in your system, and pass the new config when creating the model with modellib.MaskRCNN:

class InferenceConfig(coco.CocoConfig):
    GPU_COUNT = 1 # increase the GPU count based on number of GPUs
    IMAGES_PER_GPU = 1

config = InferenceConfig()
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

https://github.com/matterport/Mask_RCNN/blob/master/samples/demo.ipynb
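For reference, here is a minimal end-to-end sketch, assuming two GPUs are available and that the paths (MODEL_DIR, COCO_MODEL_PATH, IMAGE_DIR) and the coco module are set up as in the demo notebook linked above. In the matterport implementation, model.detect expects exactly GPU_COUNT * IMAGES_PER_GPU images per call, so with two GPUs and IMAGES_PER_GPU = 1 you would pass two images at a time:

import os
import skimage.io

import coco
import mrcnn.model as modellib

class InferenceConfig(coco.CocoConfig):
    GPU_COUNT = 2        # assumption: two GPUs are available
    IMAGES_PER_GPU = 1   # effective batch size = GPU_COUNT * IMAGES_PER_GPU = 2

config = InferenceConfig()
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
model.load_weights(COCO_MODEL_PATH, by_name=True)

# detect() takes one full batch (2 images here) and returns one result dict per image
file_names = next(os.walk(IMAGE_DIR))[2]
images = [skimage.io.imread(os.path.join(IMAGE_DIR, f)) for f in file_names[:2]]
results = model.detect(images, verbose=1)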
