# ocr_infer_module.py

import math
import os
import time
import warnings

import cv2
import numpy as np
import onnxruntime as ort
from paddleocr import PaddleOCR

from tools.infer.predict_system import TextSystem  # PaddleOCR internal module (currently unused here)


class OCRInfer:
    """OCR an image with PaddleOCR (Paddle Inference) or exported ONNX models.

    * ``use_onnx=False``: delegates to :class:`paddleocr.PaddleOCR`.
    * ``use_onnx=True``: runs ``det.onnx`` -> (optional ``cls.onnx``) ->
      ``rec.onnx`` from ``onnx_dir`` with ONNX Runtime. The pre/post-processing
      is a port of PaddleOCR's ``DetResizeForTest`` + ``NormalizeImage``,
      ``DBPostProcess``, ``get_rotate_crop_image``, ``resize_norm_img`` and
      ``CTCLabelDecode`` using only OpenCV and NumPy.

    The class attributes below are the tuning knobs of the ONNX pipeline. They
    are meant to match PaddleOCR's default inference settings for PP-OCRv3/v4
    models (verify against your model's config). They can be overridden per
    instance or in a subclass.
    """

    # Detection (DB) settings.
    det_limit_side_len = 960  # longest side of the detector input, in pixels
    det_thresh = 0.3  # pixel threshold applied to the probability map
    det_box_thresh = 0.6  # minimum mean probability inside a candidate box
    det_unclip_ratio = 1.5  # how far thresholded boxes are grown outward
    det_max_candidates = 1000  # maximum number of contours examined
    det_min_size = 3  # shortest box side (map pixels) that is kept
    det_mean = (0.485, 0.456, 0.406)
    det_std = (0.229, 0.224, 0.225)

    # Angle classifier: input (C, H, W) and confidence required to flip a crop.
    cls_image_shape = (3, 48, 192)
    cls_thresh = 0.9

    # Recognizer input (C, H, W). W is the minimum width; wider crops grow it.
    rec_image_shape = (3, 48, 320)

    # Charset used when no dict_path is given (PaddleOCR's own fallback).
    default_charset = "0123456789abcdefghijklmnopqrstuvwxyz"

    def __init__(self, use_onnx=False, onnx_dir=None, dict_path=None, ep="cpu",
                 use_cls=True, lang="ch"):
        """Create the OCR engine.

        Args:
            use_onnx: ``False`` uses PaddleOCR; ``True`` uses ONNX Runtime.
            onnx_dir: Directory containing ``det.onnx`` and ``rec.onnx``, and
                optionally ``cls.onnx``. Required when ``use_onnx`` is true.
            dict_path: Recognition dictionary with one character per line.
                It is used only by the ONNX backend. When ``None``, the
                digits + lowercase-Latin ``default_charset`` is used.
            ep: ONNX Runtime execution provider: ``"cpu"``, ``"cuda"`` or
                ``"trt"`` (case-insensitive). Unknown values fall back to CPU
                and emit a ``RuntimeWarning``.
            use_cls: Run the text-angle classifier. For ONNX, it runs only when
                ``cls.onnx`` exists in ``onnx_dir``.
            lang: Language passed to PaddleOCR (Paddle backend only).

        Raises:
            ValueError: ``use_onnx`` is true but ``onnx_dir`` is ``None``.
        """
        self.use_onnx = use_onnx
        self.onnx_dir = onnx_dir
        self.dict_path = dict_path
        self.use_cls = use_cls

        if not use_onnx:
            # Paddle Inference backend.
            self.ocr = PaddleOCR(use_angle_cls=use_cls, lang=lang)
            self.mode = "Paddle"
            return

        if onnx_dir is None:
            raise ValueError("onnx_dir is required when use_onnx=True")

        providers = {
            "cpu": ["CPUExecutionProvider"],
            "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
            "trt": ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
        }
        ep_key = str(ep).lower()
        if ep_key not in providers:
            warnings.warn(f"Unknown execution provider {ep!r}; falling back to CPU.",
                          RuntimeWarning, stacklevel=2)
        chosen_providers = providers.get(ep_key, ["CPUExecutionProvider"])

        sess_opts = ort.SessionOptions()
        sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        def _open(name):
            return ort.InferenceSession(os.path.join(onnx_dir, name),
                                        sess_options=sess_opts, providers=chosen_providers)

        self.det_sess = _open("det.onnx")
        self.rec_sess = _open("rec.onnx")
        has_cls_model = os.path.exists(os.path.join(onnx_dir, "cls.onnx"))
        self.cls_sess = _open("cls.onnx") if use_cls and has_cls_model else None

        # CTC label list; loaded eagerly so a bad dict_path fails at construction.
        self.character = self._load_charset(dict_path)
        self.mode = f"ONNX ({ep})"

    def run(self, img_path):
        """OCR a single image file.

        Args:
            img_path: Path readable by ``cv2.imread``.

        Returns:
            A tuple ``(results, elapsed_ms)``.

            * Paddle backend: ``results`` is PaddleOCR's per-line list
              ``[[box, (text, score)], ...]``. It is an empty list when
              PaddleOCR reports nothing for the image.
            * ONNX backend: ``results`` is ``[(text, score), ...]`` in reading
              order.

            ``elapsed_ms`` covers inference only, not image loading.

        Raises:
            ValueError: The image could not be read.
        """
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"이미지 로드 실패: {img_path}")

        t0 = time.perf_counter()  # monotonic, high-resolution timer
        if not self.use_onnx:
            result = self.ocr.ocr(img, cls=self.use_cls)
            # One entry per input image; recent PaddleOCR versions put None
            # there when no text is found, so normalise that to [].
            lines = result[0] if result and result[0] is not None else []
            return lines, (time.perf_counter() - t0) * 1000

        results = self._run_onnx(img)
        return results, (time.perf_counter() - t0) * 1000

    # ------------------------------------------------------------------
    # ONNX pipeline
    # ------------------------------------------------------------------
    def _run_onnx(self, img):
        """Detect text boxes, optionally fix 180-degree crops, then recognize each crop."""
        det_name = self.det_sess.get_inputs()[0].name
        det_out = self.det_sess.run(None, {det_name: self._preprocess_det(img)})[0]
        boxes = self._postprocess_det(det_out, img.shape)

        rec_name = self.rec_sess.get_inputs()[0].name
        results = []
        for box in boxes:
            crop = self._maybe_rotate_180(self._get_rotate_crop(img, box))
            rec_out = self.rec_sess.run(None, {rec_name: self._preprocess_rec(crop)})[0]
            results.append(self._postprocess_rec(rec_out))
        return results

    def _preprocess_det(self, img):
        """Build the detector input tensor ``(1, 3, H', W')`` from a BGR image.

        The longest side is capped at ``det_limit_side_len``, both sides are
        rounded to multiples of 32 (minimum 32), and the image is scaled to
        [0, 1] and then mean/std normalised. The image stays BGR, as in
        PaddleOCR's own pipeline.
        """
        h, w = img.shape[:2]
        longest = max(h, w)
        ratio = self.det_limit_side_len / float(longest) if longest > self.det_limit_side_len else 1.0
        resize_h = max(int(round(int(h * ratio) / 32) * 32), 32)
        resize_w = max(int(round(int(w * ratio) / 32) * 32), 32)
        resized = cv2.resize(img, (resize_w, resize_h)).astype(np.float32)

        mean = np.asarray(self.det_mean, dtype=np.float32)
        std = np.asarray(self.det_std, dtype=np.float32)
        normed = (resized / 255.0 - mean) / std
        return np.ascontiguousarray(normed.transpose(2, 0, 1)[None], dtype=np.float32)

    def _postprocess_det(self, det_out, shape):
        """Convert the DB probability map into quadrilaterals in original-image pixels.

        Args:
            det_out: Detector output. Leading axes are dropped down to the 2-D
                probability map of the resized input (e.g. ``(1, 1, H', W')``).
            shape: ``img.shape`` of the original image.

        Returns:
            List of ``(4, 2)`` float32 arrays ordered (tl, tr, br, bl), sorted
            top-to-bottom and left-to-right.
        """
        pred = np.asarray(det_out, dtype=np.float32)
        while pred.ndim > 2:  # take batch 0 / channel 0
            pred = pred[0]
        map_h, map_w = pred.shape
        src_h, src_w = shape[:2]

        bitmap = (pred > self.det_thresh).astype(np.uint8) * 255
        # [-2] picks the contours under both OpenCV 3 (3 returns) and 4 (2 returns).
        contours = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

        boxes = []
        for contour in contours[: self.det_max_candidates]:
            rect = cv2.minAreaRect(contour)
            if min(rect[1]) < self.det_min_size:
                continue
            if self._box_score(pred, cv2.boxPoints(rect)) < self.det_box_thresh:
                continue

            # DB "unclip": offset the box outward by area * ratio / perimeter.
            # For a rectangle, the min-area box of that offset polygon is the
            # same rectangle grown by the offset on every side, so this
            # closed form replaces pyclipper.
            (cx, cy), (rw, rh), angle = rect
            dist = rw * rh * self.det_unclip_ratio / (2.0 * (rw + rh))
            grown = ((cx, cy), (rw + 2 * dist, rh + 2 * dist), angle)
            if min(grown[1]) < self.det_min_size + 2:
                continue

            quad = cv2.boxPoints(grown)
            # Map from detector-input coordinates back to the original image.
            quad[:, 0] = np.clip(np.round(quad[:, 0] / map_w * src_w), 0, src_w)
            quad[:, 1] = np.clip(np.round(quad[:, 1] / map_h * src_h), 0, src_h)
            quad = self._order_points_clockwise(quad)
            quad[:, 0] = np.clip(quad[:, 0], 0, src_w - 1)
            quad[:, 1] = np.clip(quad[:, 1], 0, src_h - 1)

            # Drop slivers that are too thin to recognise.
            if int(np.linalg.norm(quad[0] - quad[1])) <= 3 or int(np.linalg.norm(quad[0] - quad[3])) <= 3:
                continue
            boxes.append(quad)
        return self._sorted_boxes(boxes)

    @staticmethod
    def _box_score(pred, pts):
        """Mean of ``pred`` inside polygon ``pts`` (DB's "fast" box score)."""
        h, w = pred.shape
        xmin = int(np.clip(np.floor(pts[:, 0].min()), 0, w - 1))
        xmax = int(np.clip(np.ceil(pts[:, 0].max()), 0, w - 1))
        ymin = int(np.clip(np.floor(pts[:, 1].min()), 0, h - 1))
        ymax = int(np.clip(np.ceil(pts[:, 1].max()), 0, h - 1))
        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        local = (np.asarray(pts, dtype=np.float32) - np.array([xmin, ymin], dtype=np.float32)).astype(np.int32)
        cv2.fillPoly(mask, local.reshape(1, -1, 2), 1)
        return cv2.mean(pred[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    @staticmethod
    def _order_points_clockwise(pts):
        """Return the 4x2 points ``pts`` reordered as top-left, top-right, bottom-right, bottom-left."""
        pts = np.asarray(pts, dtype=np.float32)
        sums = pts.sum(axis=1)
        tl, br = int(np.argmin(sums)), int(np.argmax(sums))
        rest = np.delete(pts, (tl, br), axis=0)
        diff = np.diff(rest, axis=1).ravel()  # y - x: smallest is top-right
        return np.array([pts[tl], rest[np.argmin(diff)], pts[br], rest[np.argmax(diff)]], dtype=np.float32)

    @staticmethod
    def _sorted_boxes(boxes):
        """Sort boxes top-to-bottom; boxes whose tops are within 10 px are ordered left-to-right."""
        ordered = sorted(boxes, key=lambda b: (b[0][1], b[0][0]))
        for i in range(len(ordered) - 1):
            for j in range(i, -1, -1):
                upper, lower = ordered[j], ordered[j + 1]
                if abs(lower[0][1] - upper[0][1]) < 10 and lower[0][0] < upper[0][0]:
                    ordered[j], ordered[j + 1] = lower, upper
                else:
                    break
        return ordered

    def _get_rotate_crop(self, img, box):
        """Perspective-crop quadrilateral ``box`` (tl, tr, br, bl) into an upright image.

        Crops at least 1.5x taller than wide are rotated 90 degrees so that
        vertical text lines become horizontal.
        """
        pts = np.asarray(box, dtype=np.float32)
        crop_w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3])))
        crop_h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2])))
        target = np.float32([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]])
        matrix = cv2.getPerspectiveTransform(pts, target)
        crop = cv2.warpPerspective(img, matrix, (crop_w, crop_h),
                                   borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)
        if crop.shape[0] >= 1.5 * crop.shape[1]:
            # rot90 returns a negative-stride view; make it contiguous for OpenCV.
            crop = np.ascontiguousarray(np.rot90(crop))
        return crop

    @staticmethod
    def _static_hw(sess):
        """Return (height, width) of ``sess``'s first NCHW input; ``None`` marks a dynamic axis."""
        shape = sess.get_inputs()[0].shape
        dims = list(shape[2:4]) if len(shape) == 4 else [None, None]
        return tuple(d if isinstance(d, int) and d > 0 else None for d in dims)

    @staticmethod
    def _resize_norm_pad(img, img_h, img_w):
        """Resize to height ``img_h`` keeping aspect (width capped at ``img_w``), scale to [-1, 1], zero-pad right.

        Returns a ``(1, C, img_h, img_w)`` float32 tensor.
        """
        h, w = img.shape[:2]
        resized_w = min(img_w, int(math.ceil(img_h * w / float(h))))
        resized = cv2.resize(img, (resized_w, img_h)).astype(np.float32)
        resized = (resized.transpose(2, 0, 1) / 255.0 - 0.5) / 0.5
        padded = np.zeros((resized.shape[0], img_h, img_w), dtype=np.float32)
        padded[:, :, :resized_w] = resized
        return padded[None]

    def _preprocess_cls(self, img):
        """Build the angle-classifier input; a fixed model input size overrides ``cls_image_shape``."""
        _, cfg_h, cfg_w = self.cls_image_shape
        fixed_h, fixed_w = self._static_hw(self.cls_sess)
        return self._resize_norm_pad(img, fixed_h or cfg_h, fixed_w or cfg_w)

    def _maybe_rotate_180(self, crop):
        """Rotate ``crop`` by 180 degrees when the classifier is confident it is upside down."""
        if self.cls_sess is None:
            return crop
        name = self.cls_sess.get_inputs()[0].name
        probs = np.asarray(self.cls_sess.run(None, {name: self._preprocess_cls(crop)})[0]).reshape(-1)
        label = int(np.argmax(probs))
        # Label order is ["0", "180"] (PaddleOCR's ClsPostProcess default).
        if label == 1 and probs[label] > self.cls_thresh:
            crop = cv2.rotate(crop, cv2.ROTATE_180)
        return crop

    def _preprocess_rec(self, img):
        """Build the recognizer input for one crop.

        Height comes from the model when it is static, else from
        ``rec_image_shape``. Width is the model's static width if it has one.
        Otherwise it grows with the crop's aspect ratio and never drops below
        the configured minimum.
        """
        _, cfg_h, cfg_w = self.rec_image_shape
        fixed_h, fixed_w = self._static_hw(self.rec_sess)
        img_h = fixed_h or cfg_h
        h, w = img.shape[:2]
        img_w = fixed_w or int(img_h * max(cfg_w / float(cfg_h), w / float(h)))
        return self._resize_norm_pad(img, img_h, img_w)

    def _postprocess_rec(self, rec_out):
        """Greedy CTC decoding of one recognizer output.

        Args:
            rec_out: Per-timestep class scores, ``(1, T, C)`` or ``(T, C)``.
                Class 0 is the CTC blank. Assumes probabilities (softmax
                already applied), which is what the confidence averages.

        Returns:
            ``(text, confidence)``. The confidence is the mean score of the
            emitted characters, or 0.0 when nothing is emitted.

        Raises:
            ValueError: The model has more classes than ``self.character``
                covers (usually a wrong or missing ``dict_path``).
        """
        probs = np.asarray(rec_out, dtype=np.float32)
        if probs.ndim == 3:
            probs = probs[0]
        if probs.shape[-1] > len(self.character):
            raise ValueError(
                f"recognizer predicts {probs.shape[-1]} classes but the charset has "
                f"{len(self.character)} entries; check dict_path"
            )
        idx = probs.argmax(axis=1)
        score = probs.max(axis=1)
        keep = idx != 0  # drop blanks
        keep[1:] &= idx[1:] != idx[:-1]  # collapse repeated labels
        text = "".join(self.character[i] for i in idx[keep])
        conf = float(score[keep].mean()) if keep.any() else 0.0
        return text, conf

    @classmethod
    def _load_charset(cls, dict_path):
        """Return the CTC label list: ``"blank"`` at index 0, then the dictionary characters.

        With a dictionary file, each line is one character and a trailing
        space character is appended. Without one, ``default_charset`` is used
        and no space is appended.
        """
        if dict_path is None:
            chars = list(cls.default_charset)
        else:
            with open(dict_path, "rb") as fh:
                chars = [line.decode("utf-8").strip("\r\n") for line in fh]
            chars.append(" ")
        return ["blank"] + chars