# AutoPercenty3/test/paddle2onnx/ocr_infer_module.py
# (source listing: 104 lines, 3.9 KiB, Python)

# ocr_infer_module.py
import math
import os
import time

import cv2
import numpy as np
import onnxruntime as ort
from paddleocr import PaddleOCR

from tools.infer.predict_system import TextSystem # PaddleOCR 내부 코드


class OCRInfer:
    """Run OCR on one image through Paddle Inference or ONNX Runtime.

    With ``use_onnx=False`` everything is delegated to ``paddleocr.PaddleOCR``.
    With ``use_onnx=True`` the exported ``det.onnx`` / ``rec.onnx`` models (and
    ``cls.onnx`` when present and ``use_cls`` is set) are loaded from
    ``onnx_dir`` and driven by pre/post-processing modelled on PaddleOCR's
    ``tools/infer`` scripts.

    NOTE: ``_postprocess_det`` is still a stub that returns no boxes, so the
    ONNX path currently always yields an empty result list.
    """

    # DetResizeForTest(limit_type="max"): cap the longer side at this length,
    # then round both sides to a multiple of 32.
    DET_LIMIT_SIDE_LEN = 960
    # NormalizeImage statistics; applied to the cv2 (BGR) array as-is,
    # without a channel swap.
    DET_MEAN = (0.485, 0.456, 0.406)
    DET_STD = (0.229, 0.224, 0.225)
    # Recognition input (height, minimum width). The width grows with the
    # crop's aspect ratio unless the ONNX model pins a static width.
    REC_IMAGE_HW = (48, 320)
    # Angle-classifier input (height, width).
    CLS_IMAGE_HW = (48, 192)
    # Minimum "180" score before a crop is flipped upside down.
    CLS_THRESH = 0.9

    def __init__(self, use_onnx=False, onnx_dir=None, dict_path=None, ep="cpu",
                 use_cls=True, lang="ch"):
        """Create the OCR backend.

        Args:
            use_onnx: False -> PaddleOCR; True -> ONNX Runtime sessions.
            onnx_dir: Directory holding ``det.onnx``, ``rec.onnx`` and
                optionally ``cls.onnx``. Required when ``use_onnx`` is True.
            dict_path: Character dictionary (one entry per line) used to decode
                recognition output on the ONNX path.
            ep: Execution-provider key: "cpu", "cuda" or "trt"; any other key
                falls back to the CPU provider.
            use_cls: Enable text-angle classification.
            lang: Language passed to PaddleOCR (Paddle path only).

        Raises:
            ValueError: ``use_onnx`` is True but ``onnx_dir`` is None.
        """
        self.use_onnx = use_onnx
        self.onnx_dir = onnx_dir
        self.dict_path = dict_path
        self.use_cls = use_cls
        self.characters = None        # CTC label table (ONNX path, from dict_path)
        self._rec_fixed_width = None  # static width of rec.onnx's input, if any

        if not use_onnx:
            # Paddle Inference backend.
            self.ocr = PaddleOCR(use_angle_cls=use_cls, lang=lang)
            self.mode = "Paddle"
            return

        # Fail early with a clear message instead of a TypeError from os.path.join.
        if onnx_dir is None:
            raise ValueError("onnx_dir is required when use_onnx=True")

        providers = {
            "cpu": ["CPUExecutionProvider"],
            "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
            "trt": ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
        }
        chosen_providers = providers.get(ep.lower(), ["CPUExecutionProvider"])
        sess_opts = ort.SessionOptions()
        sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        def open_session(filename):
            return ort.InferenceSession(
                os.path.join(onnx_dir, filename),
                sess_options=sess_opts,
                providers=chosen_providers,
            )

        self.det_sess = open_session("det.onnx")
        self.rec_sess = open_session("rec.onnx")
        if use_cls and os.path.exists(os.path.join(onnx_dir, "cls.onnx")):
            self.cls_sess = open_session("cls.onnx")
        else:
            self.cls_sess = None

        # Read input names from the models instead of assuming paddle2onnx's "x".
        self._det_input = self.det_sess.get_inputs()[0].name
        rec_arg = self.rec_sess.get_inputs()[0]
        self._rec_input = rec_arg.name
        self._rec_fixed_width = self._static_dim(rec_arg.shape, 3)
        self._cls_input = (
            self.cls_sess.get_inputs()[0].name if self.cls_sess is not None else None
        )
        if dict_path is not None:
            self.characters = self._load_characters(dict_path)
        self.mode = f"ONNX ({ep})"

    def run(self, img_path):
        """OCR the image at ``img_path``.

        Returns:
            ``(lines, elapsed_ms)``. On the Paddle path ``lines`` is PaddleOCR's
            result for this image, or ``[]`` when it reports nothing. On the ONNX
            path it is a list of ``(text, confidence)`` tuples. ``elapsed_ms``
            covers inference only, not image loading.

        Raises:
            ValueError: the image could not be read.
        """
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"이미지 로드 실패: {img_path}")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        if self.use_onnx:
            lines = self._run_onnx(img)
        else:
            result = self.ocr.ocr(img, cls=self.use_cls)
            # PaddleOCR may report "no text" as [None]; normalise to [] so both
            # backends return an iterable.
            lines = result[0] if result and result[0] is not None else []
        return lines, (time.perf_counter() - start) * 1000

    def _run_onnx(self, img):
        """Detect, optionally re-orient, and recognise text in a BGR image."""
        det_input = self._preprocess_det(img)
        det_out = self.det_sess.run(None, {self._det_input: det_input})[0]
        boxes = self._postprocess_det(det_out, img.shape)
        rec_results = []
        for box in boxes:
            crop = self._get_rotate_crop(img, box)
            if self.cls_sess is not None:
                crop = self._classify_angle(crop)
            rec_input = self._preprocess_rec(crop)
            rec_out = self.rec_sess.run(None, {self._rec_input: rec_input})[0]
            rec_results.append(self._postprocess_rec(rec_out))
        return rec_results

    # ---- detection -------------------------------------------------------

    def _preprocess_det(self, img):
        """Turn a BGR uint8 image into a contiguous (1, 3, H, W) float32 tensor.

        Mirrors PaddleOCR's DetResizeForTest -> NormalizeImage -> ToCHWImage.
        PaddleOCR feeds the cv2 BGR image without swapping channels, so no
        colour conversion is done here.
        """
        h, w = img.shape[:2]
        new_h, new_w = self._det_resize_shape(h, w, self.DET_LIMIT_SIDE_LEN)
        resized = cv2.resize(img, (new_w, new_h)).astype(np.float32)
        mean = np.array(self.DET_MEAN, dtype=np.float32)
        std = np.array(self.DET_STD, dtype=np.float32)
        normed = (resized * np.float32(1.0 / 255.0) - mean) / std
        return np.ascontiguousarray(normed.transpose(2, 0, 1)[None])

    @staticmethod
    def _det_resize_shape(h, w, limit_side_len):
        """Return the (height, width) the detector input is resized to.

        The longer side is scaled down to ``limit_side_len`` if it exceeds it;
        each side is then rounded to a multiple of 32 (minimum 32).
        """
        longest = max(h, w)
        ratio = float(limit_side_len) / longest if longest > limit_side_len else 1.0
        resize_h = max(int(round(int(h * ratio) / 32) * 32), 32)
        resize_w = max(int(round(int(w * ratio) / 32) * 32), 32)
        return resize_h, resize_w

    def _postprocess_det(self, det_out, shape):
        """Convert the detection output into text boxes in original-image pixels.

        TODO: port PaddleOCR's DBPostProcess (binarise, extract contours,
        unclip, rescale to ``shape``) plus predict_det's box filtering. Until
        then no boxes are returned, so ``run`` yields an empty list on the ONNX
        path. ``det_out`` is presumably the (1, 1, H', W') probability map for
        the resized input; confirm against the exported model.
        """
        return []

    def _get_rotate_crop(self, img, box):
        """Perspective-crop the quadrilateral ``box`` (4 x 2 points) out of ``img``.

        Port of PaddleOCR's ``get_rotate_crop_image``: crops that come out at
        least 1.5x taller than wide are rotated 90 degrees so the text is
        horizontal.

        Raises:
            ValueError: the box has zero width or height.
        """
        points = np.asarray(box, dtype=np.float32).reshape(4, 2)
        crop_w = int(max(np.linalg.norm(points[0] - points[1]),
                         np.linalg.norm(points[2] - points[3])))
        crop_h = int(max(np.linalg.norm(points[0] - points[3]),
                         np.linalg.norm(points[1] - points[2])))
        if crop_w <= 0 or crop_h <= 0:
            raise ValueError(f"degenerate text box: {points.tolist()}")
        target = np.float32([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]])
        matrix = cv2.getPerspectiveTransform(points, target)
        crop = cv2.warpPerspective(
            img, matrix, (crop_w, crop_h),
            borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC,
        )
        if crop.shape[0] * 1.0 / crop.shape[1] >= 1.5:
            crop = np.rot90(crop)
        return crop

    # ---- angle classification ---------------------------------------------

    def _classify_angle(self, crop):
        """Flip ``crop`` by 180 degrees when the classifier is confident it is upside down.

        Assumes the exported classifier uses PaddleOCR's default label order
        ['0', '180'] (index 1 == upside down); confirm for custom models.
        """
        cls_input = self._resize_norm_pad(crop, *self.CLS_IMAGE_HW)[None]
        probs = self.cls_sess.run(None, {self._cls_input: cls_input})[0][0]
        label = int(np.argmax(probs))
        if label == 1 and probs[label] > self.CLS_THRESH:
            crop = cv2.rotate(crop, cv2.ROTATE_180)
        return crop

    # ---- recognition -------------------------------------------------------

    def _preprocess_rec(self, img):
        """Resize, normalise and right-pad a crop into a (1, 3, 48, W) float32 tensor."""
        img_h, base_w = self.REC_IMAGE_HW
        h, w = img.shape[:2]
        canvas_w = self._rec_canvas_width(h, w, img_h, base_w, self._rec_fixed_width)
        return self._resize_norm_pad(img, img_h, canvas_w)[None]

    @staticmethod
    def _rec_canvas_width(h, w, img_h, base_w, fixed_w=None):
        """Return the recogniser input width for an ``h`` x ``w`` crop.

        A static model width (``fixed_w``) wins. Otherwise the width is at least
        ``base_w`` and grows to keep the crop's aspect ratio at height ``img_h``.
        """
        if fixed_w:
            return fixed_w
        max_wh_ratio = max(base_w / img_h, w * 1.0 / h)
        return int(img_h * max_wh_ratio)

    @staticmethod
    def _resize_norm_pad(img, height, width):
        """Resize to ``height`` (keeping aspect, capped at ``width``), scale to [-1, 1],
        and zero-pad on the right to a (C, height, width) float32 array."""
        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(width, int(math.ceil(height * ratio)))
        resized = cv2.resize(img, (resized_w, height)).astype(np.float32)
        chw = (resized.transpose(2, 0, 1) / 255.0 - 0.5) / 0.5
        padded = np.zeros((chw.shape[0], height, width), dtype=np.float32)
        padded[:, :, :resized_w] = chw
        return padded

    def _postprocess_rec(self, rec_out):
        """Decode recogniser output into ``(text, mean_confidence)``.

        Raises:
            ValueError: no character dictionary was loaded (``dict_path`` unset).
        """
        if self.characters is None:
            raise ValueError("dict_path is required to decode recognition output")
        return self._ctc_greedy_decode(rec_out, self.characters)

    @staticmethod
    def _ctc_greedy_decode(probs, characters):
        """Greedy CTC decode: argmax per step, merge repeats, drop blank (index 0).

        ``probs`` is (T, C) or a batch of one (1, T, C); it is used as-is, with no
        extra softmax. Confidence is the mean max-score of the kept steps, or 0.0
        when nothing is kept.
        """
        preds = np.asarray(probs)
        if preds.ndim == 3:
            preds = preds[0]
        indices = preds.argmax(axis=-1)
        scores = preds.max(axis=-1)
        keep = np.ones(len(indices), dtype=bool)
        keep[1:] = indices[1:] != indices[:-1]
        keep &= (indices != 0)
        text = "".join(characters[i] for i in indices[keep])
        conf = float(scores[keep].mean()) if keep.any() else 0.0
        return text, conf

    @staticmethod
    def _load_characters(dict_path, use_space_char=True):
        """Build the CTC label table: "blank", one entry per dict line, then " "."""
        with open(dict_path, "rb") as fin:
            chars = [line.decode("utf-8").strip("\n").strip("\r\n") for line in fin]
        if use_space_char:
            chars.append(" ")
        return ["blank"] + chars

    @staticmethod
    def _static_dim(shape, axis):
        """Return ``shape[axis]`` if it is a concrete positive int, else None."""
        dim = shape[axis] if len(shape) > axis else None
        return dim if isinstance(dim, int) and dim > 0 else None