281 lines
12 KiB
Python
281 lines
12 KiB
Python
"""
|
|
ONNX Runtime 기반 OCR 백엔드
|
|
PaddleOCR 모델을 ONNX 형식으로 변환하여 ARM에서 실행
|
|
"""
|
|
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
import logging
|
|
from typing import List, Tuple, Dict, Any
|
|
|
|
class ONNXRuntimeOCR:
|
|
"""ONNX Runtime을 사용한 PaddleOCR 호환 클래스"""
|
|
|
|
def __init__(self, use_gpu=False, use_angle_cls=True, lang='ch',
|
|
det_model_dir=None, rec_model_dir=None, cls_model_dir=None,
|
|
logger=None, **kwargs):
|
|
self.logger = logger
|
|
self.use_gpu = use_gpu
|
|
self.use_angle_cls = use_angle_cls
|
|
self.lang = lang
|
|
|
|
# ONNX 모델 세션
|
|
self.det_session = None
|
|
self.rec_session = None
|
|
self.cls_session = None
|
|
|
|
try:
|
|
import onnxruntime as ort
|
|
self.ort = ort
|
|
|
|
# ARM 최적화 설정
|
|
providers = ['CPUExecutionProvider']
|
|
if use_gpu:
|
|
# ARM에서 GPU 지원이 제한적이므로 CPU 사용 권장
|
|
providers = ['CPUExecutionProvider']
|
|
|
|
# 세션 옵션 설정 (ARM 최적화)
|
|
self.sess_options = ort.SessionOptions()
|
|
self.sess_options.inter_op_num_threads = 4 # ARM 최적화
|
|
self.sess_options.intra_op_num_threads = 4
|
|
self.sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
|
|
# 모델 로드
|
|
self._load_onnx_models(det_model_dir, rec_model_dir, cls_model_dir, providers)
|
|
|
|
if self.logger:
|
|
self.logger.log("✅ ONNX Runtime OCR 초기화 성공 (ARM 최적화)", level=logging.INFO)
|
|
|
|
except ImportError:
|
|
if self.logger:
|
|
self.logger.log("❌ ONNX Runtime 모듈을 찾을 수 없습니다", level=logging.ERROR)
|
|
raise ImportError("ONNX Runtime이 설치되지 않았습니다: pip install onnxruntime")
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"❌ ONNX Runtime OCR 초기화 실패: {e}", level=logging.ERROR)
|
|
raise
|
|
|
|
def _load_onnx_models(self, det_model_dir, rec_model_dir, cls_model_dir, providers):
|
|
"""ONNX 모델 로드"""
|
|
try:
|
|
# Detection 모델
|
|
if det_model_dir:
|
|
det_onnx_path = os.path.join(det_model_dir, "model.onnx")
|
|
if os.path.exists(det_onnx_path):
|
|
self.det_session = self.ort.InferenceSession(
|
|
det_onnx_path,
|
|
sess_options=self.sess_options,
|
|
providers=providers
|
|
)
|
|
self.det_input_name = self.det_session.get_inputs()[0].name
|
|
self.det_output_names = [output.name for output in self.det_session.get_outputs()]
|
|
|
|
# Recognition 모델
|
|
if rec_model_dir:
|
|
rec_onnx_path = os.path.join(rec_model_dir, "model.onnx")
|
|
if os.path.exists(rec_onnx_path):
|
|
self.rec_session = self.ort.InferenceSession(
|
|
rec_onnx_path,
|
|
sess_options=self.sess_options,
|
|
providers=providers
|
|
)
|
|
self.rec_input_name = self.rec_session.get_inputs()[0].name
|
|
self.rec_output_names = [output.name for output in self.rec_session.get_outputs()]
|
|
|
|
# Classification 모델
|
|
if self.use_angle_cls and cls_model_dir:
|
|
cls_onnx_path = os.path.join(cls_model_dir, "model.onnx")
|
|
if os.path.exists(cls_onnx_path):
|
|
self.cls_session = self.ort.InferenceSession(
|
|
cls_onnx_path,
|
|
sess_options=self.sess_options,
|
|
providers=providers
|
|
)
|
|
self.cls_input_name = self.cls_session.get_inputs()[0].name
|
|
self.cls_output_names = [output.name for output in self.cls_session.get_outputs()]
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"ONNX 모델 로드 실패: {e}", level=logging.ERROR)
|
|
raise
|
|
|
|
def _preprocess_det_image(self, img):
|
|
"""Detection 전처리"""
|
|
# PaddleOCR Detection 전처리 로직
|
|
h, w = img.shape[:2]
|
|
|
|
# 크기 조정 (보통 960x960 또는 640x640)
|
|
target_size = 960
|
|
ratio = target_size / max(h, w)
|
|
new_h, new_w = int(h * ratio), int(w * ratio)
|
|
|
|
img_resized = cv2.resize(img, (new_w, new_h))
|
|
|
|
# 패딩 추가
|
|
padded_img = np.zeros((target_size, target_size, 3), dtype=np.uint8)
|
|
padded_img[:new_h, :new_w] = img_resized
|
|
|
|
# 정규화 (RGB 변환 및 스케일링)
|
|
img_rgb = cv2.cvtColor(padded_img, cv2.COLOR_BGR2RGB)
|
|
img_norm = img_rgb.astype(np.float32) / 255.0
|
|
|
|
# 채널 순서 변경 (HWC -> CHW)
|
|
img_chw = np.transpose(img_norm, (2, 0, 1))
|
|
|
|
# 배치 차원 추가
|
|
img_batch = np.expand_dims(img_chw, axis=0)
|
|
|
|
return img_batch, ratio
|
|
|
|
def _postprocess_det_output(self, output, ratio, original_shape):
|
|
"""Detection 후처리"""
|
|
# ONNX 출력을 PaddleOCR 형식으로 변환
|
|
boxes = []
|
|
|
|
# 임계값 적용 및 컨투어 찾기 (간단한 구현)
|
|
# 실제로는 더 복잡한 후처리가 필요
|
|
if len(output) > 0 and len(output[0]) > 0:
|
|
heatmap = output[0][0] # 첫 번째 출력 채널
|
|
|
|
# 이진화
|
|
_, binary = cv2.threshold((heatmap * 255).astype(np.uint8), 127, 255, cv2.THRESH_BINARY)
|
|
|
|
# 컨투어 찾기
|
|
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
|
h_orig, w_orig = original_shape[:2]
|
|
|
|
for contour in contours:
|
|
if cv2.contourArea(contour) < 100: # 최소 영역 필터
|
|
continue
|
|
|
|
# 바운딩 박스 계산
|
|
rect = cv2.minAreaRect(contour)
|
|
box = cv2.boxPoints(rect)
|
|
|
|
# 좌표 스케일 복원
|
|
box[:, 0] = box[:, 0] / ratio
|
|
box[:, 1] = box[:, 1] / ratio
|
|
|
|
# 이미지 경계 클리핑
|
|
box[:, 0] = np.clip(box[:, 0], 0, w_orig)
|
|
box[:, 1] = np.clip(box[:, 1], 0, h_orig)
|
|
|
|
boxes.append(box)
|
|
|
|
return boxes
|
|
|
|
def ocr(self, img, det=True, rec=True, cls=True):
|
|
"""
|
|
PaddleOCR과 호환되는 OCR 메서드
|
|
"""
|
|
try:
|
|
# 이미지 전처리
|
|
if hasattr(img, 'save'): # PIL Image
|
|
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
|
elif isinstance(img, str): # 파일 경로
|
|
img = cv2.imread(img)
|
|
|
|
results = []
|
|
|
|
if not det and not rec:
|
|
return results
|
|
|
|
# Detection 수행
|
|
boxes = []
|
|
if det and self.det_session:
|
|
try:
|
|
# 전처리
|
|
det_input, ratio = self._preprocess_det_image(img)
|
|
|
|
# 추론
|
|
det_outputs = self.det_session.run(
|
|
self.det_output_names,
|
|
{self.det_input_name: det_input}
|
|
)
|
|
|
|
# 후처리
|
|
boxes = self._postprocess_det_output(det_outputs, ratio, img.shape)
|
|
|
|
if not rec:
|
|
# Detection만 수행하는 경우
|
|
for box in boxes:
|
|
results.append([box.tolist()])
|
|
return results
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"ONNX Detection 실패: {e}", level=logging.WARNING)
|
|
boxes = []
|
|
|
|
# Recognition 수행
|
|
if rec and self.rec_session:
|
|
if boxes:
|
|
# Detection + Recognition
|
|
for box in boxes:
|
|
try:
|
|
# 텍스트 영역 추출
|
|
x_coords = box[:, 0] if isinstance(box, np.ndarray) else [p[0] for p in box]
|
|
y_coords = box[:, 1] if isinstance(box, np.ndarray) else [p[1] for p in box]
|
|
x_min, x_max = int(min(x_coords)), int(max(x_coords))
|
|
y_min, y_max = int(min(y_coords)), int(max(y_coords))
|
|
|
|
text_region = img[y_min:y_max, x_min:x_max]
|
|
|
|
if text_region.size > 0:
|
|
# Recognition 전처리 (간단한 구현)
|
|
rec_img = cv2.resize(text_region, (100, 32)) # 표준 크기
|
|
rec_img = rec_img.astype(np.float32) / 255.0
|
|
rec_img = np.transpose(rec_img, (2, 0, 1))
|
|
rec_input = np.expand_dims(rec_img, axis=0)
|
|
|
|
# Recognition 추론
|
|
rec_outputs = self.rec_session.run(
|
|
self.rec_output_names,
|
|
{self.rec_input_name: rec_input}
|
|
)
|
|
|
|
# 간단한 후처리 (실제로는 더 복잡함)
|
|
text = "detected_text" # 실제 구현 필요
|
|
confidence = 0.8
|
|
|
|
box_list = box.tolist() if isinstance(box, np.ndarray) else box
|
|
results.append([box_list, (text, confidence)])
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"ONNX Recognition 실패: {e}", level=logging.WARNING)
|
|
continue
|
|
else:
|
|
# Recognition만 수행하는 경우
|
|
try:
|
|
# 전체 이미지 Recognition
|
|
rec_img = cv2.resize(img, (100, 32))
|
|
rec_img = rec_img.astype(np.float32) / 255.0
|
|
rec_img = np.transpose(rec_img, (2, 0, 1))
|
|
rec_input = np.expand_dims(rec_img, axis=0)
|
|
|
|
rec_outputs = self.rec_session.run(
|
|
self.rec_output_names,
|
|
{self.rec_input_name: rec_input}
|
|
)
|
|
|
|
text = "full_image_text" # 실제 구현 필요
|
|
confidence = 0.8
|
|
|
|
h, w = img.shape[:2]
|
|
bbox = [[0, 0], [w, 0], [w, h], [0, h]]
|
|
results.append([bbox, (text, confidence)])
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"ONNX 전체 이미지 Recognition 실패: {e}", level=logging.WARNING)
|
|
|
|
return results
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log(f"ONNX OCR 추론 실패: {e}", level=logging.ERROR)
|
|
return []
|