I know this is a recurring question, and I know some of the solutions, especially for when it happens during training (reduce the batch size, set the GPU option allow_growth=True), but I am facing this problem when using an already trained model to predict results.
So I was able to train the model (a Faster RCNN from the TensorFlow Object Detection Model Zoo, with batch size 1, otherwise I got OOM errors during training).
To apply the trained model, I load it with the following code:
import numpy as np
import tensorflow as tf


class Model:
    def __init__(self, conf):
        self.threshold = conf["threshold"]
        self.gpu = conf["gpu"]
        self.scope = conf["scope"]
        self.frozen_inference_graph = conf["frozen_inference_graph"]
        # class_mapping (class id -> label) is assumed to be supplied as well;
        # it is used in detect() below but is not part of the conf shown here
        self.class_mapping = conf.get("class_mapping", {})
        self.detection_graph = self.load_model(self.gpu, self.scope, self.frozen_inference_graph)
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=conf["gpu_usage"],
                                    allocator_type='BFC')
        self.session = tf.Session(graph=self.detection_graph,
                                  config=tf.ConfigProto(gpu_options=gpu_options,
                                                        log_device_placement=True,
                                                        allow_soft_placement=True))

    def load_model(self, gpu, scope, frozen_inference_graph):
        # Load the frozen graph and import it under the given scope,
        # pinning its ops to the configured GPU device.
        detection_graph = tf.Graph()
        with detection_graph.device(gpu):
            with detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(frozen_inference_graph, 'rb') as fid:
                    serialized_graph = fid.read()
                    od_graph_def.ParseFromString(serialized_graph)
                    tf.import_graph_def(od_graph_def, name=scope)
        return detection_graph
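Since the graph is imported with name=scope, every tensor ends up prefixed with the scope from the config. A quick sanity check once the model is constructed (a sketch using the standard graph-inspection API; model is the instance created further below):

# Sketch: confirm the imported input tensor carries the scope prefix.
graph = model.detection_graph
print([op.name for op in graph.get_operations() if op.name.endswith("image_tensor")])
# With scope "testing" this should print: ['testing/image_tensor']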
My conf file looks like this:
{
    "frozen_inference_graph": "/config/plate_detector/faster_rcnn/frozen_inference_graph.pb",
    "gpu": "/gpu:0",
    "gpu_usage": 0.05,
    "threshold": 0.5,
    "scope": "testing"
}
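The config is plain JSON, so constructing the model is just loading the file and passing the dict (a minimal sketch; the config path here is a hypothetical placeholder):

import json

# Sketch: read the JSON config and hand the resulting dict to Model.
with open("/config/plate_detector/conf.json") as f:  # hypothetical path
    conf = json.load(f)
model = Model(conf)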
To apply the model, the Model class has the following method:
def detect(self, image):
    # Fetch the input/output tensors by name, prefixed with the import scope
    image_tensor = self.detection_graph.get_tensor_by_name(self.scope + "/image_tensor:0")
    detection_boxes = self.detection_graph.get_tensor_by_name(self.scope + "/detection_boxes:0")
    detection_scores = self.detection_graph.get_tensor_by_name(self.scope + "/detection_scores:0")
    detection_classes = self.detection_graph.get_tensor_by_name(self.scope + "/detection_classes:0")
    num_detections = self.detection_graph.get_tensor_by_name(self.scope + "/num_detections:0")

    height, width = image.shape[:2]
    image_np = image.copy()
    # Expand dimensions since the model expects images to have shape [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)

    # Actual detection
    boxes, scores, classes, num = self.session.run(
        [detection_boxes, detection_scores, detection_classes, num_detections],
        feed_dict={image_tensor: image_np_expanded}
    )

    output_dict = dict()
    output_dict["num_dets"] = int(num[0])
    output_dict["classes"] = classes[0].astype(np.uint8)
    output_dict["bboxes"] = boxes[0]
    output_dict["confidences"] = scores[0]

    detections = []
    for i, box in enumerate(output_dict["bboxes"]):
        if output_dict["confidences"][i] >= self.threshold:
            # Boxes are [ymin, xmin, ymax, xmax] in normalized coordinates
            x1 = int(box[1] * width)
            y1 = int(box[0] * height)
            x2 = int(box[3] * width)
            y2 = int(box[2] * height)
            label = self.class_mapping.get(output_dict["classes"][i].astype(str))
            confidence = output_dict["confidences"][i]
            detections.append(detection.Detection([x1, y1, x2, y2, confidence, label]))

    # Sort by x1 to return detections in left-to-right order
    detections = sorted(detections, key=lambda d: d.x1)
    return detections
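detection.Detection is a small container class from my own code; roughly, it just stores the box corners, confidence, and label (a sketch of its assumed shape, not the exact implementation):

class Detection:
    # Minimal container matching how detect() builds and sorts detections.
    def __init__(self, values):
        self.x1, self.y1, self.x2, self.y2, self.confidence, self.label = values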
In my main.py file, I simply do (pseudocode):
model = Model(*path to config*)
images = fetch_images()
for image in images:
    detections = model.detect(image)
    # processing over the detections
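fetch_images() stands in for my real image source; a minimal stand-in, assuming OpenCV and a hypothetical image directory:

import glob

import cv2

def fetch_images(pattern="/data/images/*.jpg"):  # hypothetical directory
    # Yield decoded BGR images, skipping files OpenCV cannot read.
    for path in sorted(glob.glob(pattern)):
        image = cv2.imread(path)
        if image is not None:
            yield image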
But sometimes, in the middle of the process, I get a message like:
2019-10-16 21:42:32.394227: W tensorflow/core/common_runtime/bfc_allocator.cc:314] Allocator (GPU_0_bfc) ran out of memory trying to allocate 76.56MiB (rounded to 80281600). Current allocation summary follows.
This does not happen all the time; it seems to occur randomly, and I do not understand why.
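To check whether GPU memory actually creeps up between detect() calls, one diagnostic I can run is tf.contrib.memory_stats from TF 1.x (a sketch; adding the op to the imported graph after session creation is an assumption that works here because the graph is not finalized):

import tensorflow as tf

# Sketch: add a memory-stats op to the detection graph and evaluate it
# in the existing session to watch peak BFC allocator usage over time.
with model.detection_graph.as_default():
    with tf.device("/gpu:0"):
        max_bytes = tf.contrib.memory_stats.MaxBytesInUse()

for image in images:
    detections = model.detect(image)
    print("peak GPU bytes in use:", model.session.run(max_bytes))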
I tried using allow_growth=True (which throws an error when I load the model) as well as allocator_type='BFC' (which did not help; the OOM still occurs randomly).
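For reference, the allow_growth variant I tried looked roughly like this (a sketch; it replaced the per_process_gpu_memory_fraction setting in __init__):

import tensorflow as tf

# Sketch: the allow_growth configuration that errors out at model load time.
gpu_options = tf.GPUOptions(allow_growth=True)
session = tf.Session(graph=detection_graph,
                     config=tf.ConfigProto(gpu_options=gpu_options,
                                           allow_soft_placement=True))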
The output of watch -n 1 nvidia-smi looks like this:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 435.21       Driver Version: 435.21       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro RTX 4000     Off  | 00000000:01:00.0  On |                  N/A |
| 30%   40C    P0    96W / 125W |   1589MiB /  7979MiB |     83%      Default |
+-------------------------------+----------------------+----------------------+