133 lines
5.3 KiB
Python
133 lines
5.3 KiB
Python
"""
|
|
OCR + Mask + Simple-Lama Inpainting 파이프라인 예제
|
|
====================================================
|
|
1) OCRModule 로 텍스트 탐지 (GPU FastDeploy)
|
|
2) MaskModule 로 탐지된 영역 마스킹 (옵션: 기본)
|
|
3) simple-lama-inpainting 모델로 인페인팅 수행
|
|
- CUDA 사용 가능 시 GPU, 없으면 CPU
|
|
4) 원본, 마스크, 인페인팅 이미지를 모두 저장
|
|
|
|
사용법 (프로젝트 루트 기준):
|
|
python ocr_inpaint_pipeline.py --image img/1.jpg \
|
|
--mask-out img/1_mask.png --inpaint-out img/1_inpaint.jpg
|
|
|
|
필요 패키지:
|
|
pip install simple-lama-inpainting shapely opencv-python fastdeploy-python --extra-index-url https://www.paddlepaddle.org.cn/whl/mkl/avx/stable.html
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from ocr_module import OCRModule
|
|
from mask_module import MaskModule
|
|
|
|
try:
|
|
from simple_lama_inpainting.models.model import SimpleLama
|
|
except ImportError as e:
|
|
raise ImportError("simple-lama-inpainting 패키지가 설치되어 있지 않습니다. 'pip install simple-lama-inpainting' 명령으로 설치해 주세요.") from e
|
|
|
|
|
|
class SimpleLogger:
|
|
"""OCRModule / MaskModule 과 동일한 인터페이스를 갖는 간단 로거"""
|
|
|
|
def log(self, msg, level=logging.INFO, exc_info: bool = False):
|
|
if level == logging.ERROR:
|
|
print("[ERROR]", msg)
|
|
elif level == logging.WARNING:
|
|
print("[WARN]", msg)
|
|
else:
|
|
print("[INFO]", msg)
|
|
if exc_info:
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="OCR → Mask → Simple-Lama Inpainting 파이프라인")
|
|
parser.add_argument("--image", "-i", required=True, help="원본 이미지 경로")
|
|
parser.add_argument("--method", default="polygon", choices=[
|
|
"polygon", "bbox", "expanded_bbox", "rotated_bbox", "contour"
|
|
], help="OCR 탐지 방식")
|
|
parser.add_argument("--mask-out", default=None, help="생성된 마스크 저장 경로 (png)")
|
|
parser.add_argument("--inpaint-out", default=None, help="인페인팅 결과 저장 경로 (jpg/png)")
|
|
parser.add_argument("--filter", choices=["all", "chinese", "korean"], default="all",
|
|
help="OCR 결과 필터링 옵션")
|
|
parser.add_argument("--mask-option", default="basic", choices=["basic", "processed"],
|
|
help="MaskModule 에 전달할 mask_option 값")
|
|
parser.add_argument("--expansion", type=int, default=0, help="Mask 확장 정도 (process_mask 사용 시)")
|
|
parser.add_argument("--blur", type=int, default=0, help="Mask 블러 크기 (process_mask 사용 시)")
|
|
args = parser.parse_args()
|
|
|
|
img_path = Path(args.image)
|
|
if not img_path.is_file():
|
|
raise FileNotFoundError(f"이미지 파일이 존재하지 않습니다: {img_path}")
|
|
|
|
# 저장 경로 기본값 설정
|
|
if args.mask_out is None:
|
|
args.mask_out = str(img_path.with_stem(img_path.stem + "_mask").with_suffix(".png"))
|
|
if args.inpaint_out is None:
|
|
args.inpaint_out = str(img_path.with_stem(img_path.stem + "_inpaint").with_suffix(".jpg"))
|
|
|
|
logger = SimpleLogger()
|
|
base_dir = os.getcwd() # 프로젝트 루트
|
|
|
|
# 1. OCR 수행 ------------------------------------------------------------
|
|
ocr_module = OCRModule(logger=logger, base_dir=base_dir)
|
|
ocr_results = ocr_module.detect_text(str(img_path), method=args.method)
|
|
|
|
if args.filter == "chinese":
|
|
ocr_results = ocr_module.filter_chinese_text(ocr_results)
|
|
elif args.filter == "korean":
|
|
ocr_results = ocr_module.filter_korean_text(ocr_results)
|
|
logger.log(f"OCR 결과 최종 개수: {len(ocr_results)}", level=logging.INFO)
|
|
|
|
if len(ocr_results) == 0:
|
|
logger.log("경고: OCR 결과가 없어 인페인팅을 건너뜁니다.", level=logging.WARNING)
|
|
return
|
|
|
|
# 2. 마스크 생성 ----------------------------------------------------------
|
|
mask_module = MaskModule(logger=logger, base_dir=base_dir)
|
|
if args.mask_option == "basic":
|
|
mask = mask_module.create_masks(str(img_path), ocr_results, mask_option="basic")
|
|
else:
|
|
mask = mask_module.create_masks(
|
|
str(img_path), ocr_results, expansion_size=args.expansion, blur_size=args.blur, mask_option="processed"
|
|
)
|
|
|
|
if mask is None:
|
|
logger.log("마스크 생성 실패", level=logging.ERROR)
|
|
return
|
|
|
|
# 저장 (255: 인페인팅 대상)
|
|
cv2.imwrite(args.mask_out, mask)
|
|
logger.log(f"마스크 이미지 저장 완료 → {args.mask_out}", level=logging.INFO)
|
|
|
|
# 3. Simple-Lama Inpainting --------------------------------------------
|
|
# PIL 이미지 준비 (RGB, L)
|
|
original_pil = Image.open(img_path).convert("RGB")
|
|
|
|
mask_pil = Image.fromarray(mask).convert("L")
|
|
|
|
overlay = cv2.imread("img/1.jpg")
|
|
poly = np.where(mask==255)
|
|
overlay[poly] = (0,0,255) # 마스크 영역 표시
|
|
cv2.imwrite("img/1_overlay.jpg", overlay)
|
|
|
|
|
|
lama_model = SimpleLama()
|
|
result_pil = lama_model(original_pil, mask_pil)
|
|
|
|
# 결과 저장
|
|
Path(args.inpaint_out).parent.mkdir(parents=True, exist_ok=True)
|
|
result_pil.save(args.inpaint_out)
|
|
logger.log(f"인페인팅 이미지 저장 완료 → {args.inpaint_out}", level=logging.INFO)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |