103 lines
4.4 KiB
Python
103 lines
4.4 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
import logging
|
|
|
|
class PPMattingBackgroundRemovalModule:
|
|
"""
|
|
PaddleHub ppmatting 기반 배경제거 모듈
|
|
- CPU/GPU 모두 지원
|
|
- 최초 1회 모델 자동 다운로드
|
|
- 인물, 사물 등 모두 고품질 분리 가능
|
|
"""
|
|
|
|
SUPPORTED_MODELS = {
|
|
"modnet_mobilenetv2_matting": "가장 가볍고 빠름 | 매우 빠름(0.5~1s/이미지) | 실시간 썸네일, 대량 처리",
|
|
"modnet_resnet50vd_matting": "균형형, 성능 우수 | 빠름~보통(1~2s/이미지) | 상품 이미지, 경계 중요",
|
|
"modnet_hrnet18_matting": "품질 우수, 중간 속도 | 보통(2~3s/이미지) | 인물/상품 정밀 분리",
|
|
"gfm_resnet34_matting": "균형 맞춘 품질+속도 | 빠름~보통 | 일반 상품/사물 이미지",
|
|
"dim_vgg16_matting": "중간 품질, 중간 속도 | 보통 | 예제/비교 테스트용",
|
|
}
|
|
def __init__(self, logger=None, default_model="modnet_resnet50vd_matting"):
|
|
self.logger = logger
|
|
self.default_model = default_model
|
|
self.modules = {} # 모델별 paddlehub 모듈 캐시
|
|
|
|
def get_module(self, model_name):
|
|
"""
|
|
paddlehub ppmatting 모델 캐싱 후 반환
|
|
"""
|
|
if model_name not in self.modules:
|
|
import paddlehub as hub
|
|
self.modules[model_name] = hub.Module(name=model_name)
|
|
if self.logger:
|
|
self.logger.log(f"ppmatting 모델 로드됨: {model_name}")
|
|
return self.modules[model_name]
|
|
|
|
def set_default_model(self, model_name):
|
|
"""
|
|
기본 배경제거 모델명 변경
|
|
"""
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
raise ValueError(f"지원하지 않는 모델명: {model_name}")
|
|
self.default_model = model_name
|
|
if self.logger:
|
|
self.logger.log(f"기본 배경제거 모델이 '{model_name}'(으)로 변경됨")
|
|
|
|
def get_default_model(self):
|
|
return self.default_model
|
|
|
|
def get_supported_models(self):
|
|
return self.SUPPORTED_MODELS.copy()
|
|
|
|
def get_model_description(self, model_name):
|
|
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
|
|
|
|
def to_white_background(self, img: Image.Image) -> Image.Image:
|
|
"""
|
|
알파(투명) 부분을 흰색으로 합성해서 RGB 이미지로 반환
|
|
Args:
|
|
img: PIL.Image (RGBA or BGRA)
|
|
Returns:
|
|
PIL.Image (RGB, 배경이 흰색)
|
|
"""
|
|
if img.mode in ("RGBA", "BGRA"):
|
|
bg = Image.new("RGB", img.size, (255, 255, 255))
|
|
bg.paste(img, mask=img.split()[-1]) # 알파채널로 합성
|
|
return bg
|
|
else:
|
|
return img.convert("RGB")
|
|
|
|
def remove_background(self, image_path, model_name=None):
|
|
if not os.path.exists(image_path):
|
|
self.logger.log(f"입력 이미지 없음: {image_path}", level=logging.ERROR); return None
|
|
img = cv2.imread(image_path)
|
|
if img is None:
|
|
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR); return None
|
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
model_name = model_name or self.default_model
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
self.logger.log(f"지원되지 않는 모델명: {model_name}", level=logging.ERROR); return None
|
|
try:
|
|
module = self.get_module(model_name)
|
|
results = module.predict(
|
|
image_list=[img_rgb],
|
|
trimap_list=None,
|
|
visualization=False,
|
|
save_path=None
|
|
)
|
|
self.logger.log(f"results type: {type(results)}, len={len(results)}", level=logging.DEBUG)
|
|
alpha = results[0]
|
|
if not isinstance(alpha, np.ndarray) or alpha.ndim != 2:
|
|
self.logger.log("매팅 결과가 유효한 알파마스크가 아닙니다", level=logging.ERROR); return None
|
|
|
|
alpha_img = alpha if alpha.dtype == np.uint8 else (alpha * 255).astype(np.uint8)
|
|
img_bgra = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
|
|
img_bgra[...,3] = alpha_img
|
|
pil_img = Image.fromarray(cv2.cvtColor(img_bgra, cv2.COLOR_BGRA2RGBA))
|
|
return pil_img
|
|
except Exception as e:
|
|
self.logger.log(f"배경제거 예외 발생: {e}", level=logging.ERROR, exc_info=True)
|
|
return None
|