141 lines
5.6 KiB
Python
141 lines
5.6 KiB
Python
import os
|
|
import logging
|
|
from typing import List, Dict, Any
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import fastdeploy as fd
|
|
|
|
|
|
class FastOCRModule:
|
|
"""FastDeploy 기반 간단 OCR 모듈 (PPOCRv3 시스템)
|
|
|
|
기존 ocr_module.OCRModule 과 동일한 입출력 구조를 유지하면서
|
|
FastDeploy 의 엔드투엔드 API(PPOCRv3)를 사용해 전처리 · 후처리를
|
|
내부에서 자동 처리합니다.
|
|
"""
|
|
|
|
def __init__(self, logger: logging.Logger | None = None, base_dir: str | None = None):
|
|
self.logger = logger or logging.getLogger(__name__)
|
|
self.base_dir = base_dir or os.getcwd()
|
|
|
|
self._system = self._initialize_system()
|
|
if self._system is None:
|
|
raise RuntimeError("FastDeploy OCR 초기화 실패")
|
|
|
|
# ------------------------------------------------------------------
|
|
# 내부 초기화 -------------------------------------------------------
|
|
# ------------------------------------------------------------------
|
|
def _initialize_system(self):
|
|
try:
|
|
det_dir = os.path.join(self.base_dir, "modules", "PP_Models", "det")
|
|
rec_dir = os.path.join(self.base_dir, "modules", "PP_Models", "rec")
|
|
cls_dir = os.path.join(self.base_dir, "modules", "PP_Models", "cls")
|
|
|
|
# Runtime (GPU)
|
|
opt = fd.RuntimeOption()
|
|
opt.use_gpu()
|
|
|
|
# Detector
|
|
det = fd.vision.ocr.DBDetector(
|
|
os.path.join(det_dir, "inference.pdmodel"),
|
|
os.path.join(det_dir, "inference.pdiparams"),
|
|
runtime_option=opt,
|
|
)
|
|
|
|
# Recognizer
|
|
label_file = None
|
|
for cand in ("ppocr_keys_v1.txt", "dict.txt"):
|
|
tmp = os.path.join(rec_dir, cand)
|
|
if os.path.isfile(tmp):
|
|
label_file = tmp
|
|
break
|
|
if label_file is None:
|
|
raise FileNotFoundError("인식 라벨 파일을 찾을 수 없습니다.")
|
|
|
|
rec = fd.vision.ocr.Recognizer(
|
|
os.path.join(rec_dir, "inference.pdmodel"),
|
|
os.path.join(rec_dir, "inference.pdiparams"),
|
|
label_file,
|
|
runtime_option=opt,
|
|
)
|
|
|
|
# Angle classifier (optional)
|
|
cls_model_file = os.path.join(cls_dir, "inference.pdmodel")
|
|
cls_params_file = os.path.join(cls_dir, "inference.pdiparams")
|
|
cls = None
|
|
if os.path.isfile(cls_model_file):
|
|
cls = fd.vision.ocr.Classifier(
|
|
cls_model_file, cls_params_file, runtime_option=opt
|
|
)
|
|
|
|
system = fd.vision.ocr.PPOCRv3(det, cls, rec) if cls else fd.vision.ocr.PPOCRv3(det, rec)
|
|
self.logger.info("✅ FastDeploy PPOCRv3 시스템 초기화 완료")
|
|
return system
|
|
except Exception as e:
|
|
self.logger.error(f"❌ PPOCRv3 초기화 실패: {e}", exc_info=True)
|
|
return None
|
|
|
|
# ------------------------------------------------------------------
|
|
# 퍼블릭 API -------------------------------------------------------
|
|
# ------------------------------------------------------------------
|
|
def detect_text(self, image_path: str) -> List[Dict[str, Any]]:
|
|
"""이미지 파일에서 텍스트 탐지 + 인식 결과를 반환.
|
|
|
|
반환 형식은 기존 ocr_module.detect_text 와 동일:
|
|
[
|
|
{
|
|
'text': str,
|
|
'confidence': float,
|
|
'polygon': list[list[int, int]], # 4개 꼭짓점
|
|
'bbox': (x, y, w, h),
|
|
},
|
|
...
|
|
]
|
|
"""
|
|
if not os.path.isfile(image_path):
|
|
self.logger.error(f"이미지 파일이 존재하지 않습니다: {image_path}")
|
|
return []
|
|
|
|
img = cv2.imread(image_path)
|
|
if img is None:
|
|
self.logger.error(f"이미지를 읽을 수 없습니다: {image_path}")
|
|
return []
|
|
|
|
try:
|
|
res = self._system.predict(img) # OCRResult object
|
|
if res is None or len(res.text) == 0:
|
|
self.logger.warning("OCR 결과가 비어있음")
|
|
return []
|
|
|
|
out: List[Dict[str, Any]] = []
|
|
for poly, txt, score in zip(res.boxes, res.text, res.rec_scores):
|
|
poly = np.array(poly).astype(int).tolist()
|
|
x, y, w, h = cv2.boundingRect(np.array(poly))
|
|
out.append(
|
|
{
|
|
"text": txt,
|
|
"confidence": float(score),
|
|
"polygon": poly,
|
|
"bbox": (int(x), int(y), int(w), int(h)),
|
|
}
|
|
)
|
|
return out
|
|
except Exception as e:
|
|
self.logger.error(f"OCR 예측 중 오류: {e}", exc_info=True)
|
|
return []
|
|
finally:
|
|
del img
|
|
|
|
# ------------------------------------------------------------------
|
|
# 후처리 유틸 -------------------------------------------------------
|
|
# ------------------------------------------------------------------
|
|
def filter_chinese_text(self, ocr_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
chinese = [r for r in ocr_results if any("\u4e00" <= ch <= "\u9fff" for ch in r["text"])]
|
|
self.logger.info(f"중국어 텍스트 {len(chinese)}개 필터링 완료")
|
|
return chinese
|
|
|
|
def filter_korean_text(self, ocr_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
korean = [r for r in ocr_results if any("\uac00" <= ch <= "\ud7a3" for ch in r["text"])]
|
|
self.logger.info(f"한글 텍스트 {len(korean)}개 필터링 완료")
|
|
return korean |