IMG_Worker/modules/old_modules/background_removal_module_p...

103 lines
4.4 KiB
Python

import os
import cv2
import numpy as np
from PIL import Image
import logging
class PPMattingBackgroundRemovalModule:
"""
PaddleHub ppmatting 기반 배경제거 모듈
- CPU/GPU 모두 지원
- 최초 1회 모델 자동 다운로드
- 인물, 사물 등 모두 고품질 분리 가능
"""
SUPPORTED_MODELS = {
"modnet_mobilenetv2_matting": "가장 가볍고 빠름 | 매우 빠름(0.5~1s/이미지) | 실시간 썸네일, 대량 처리",
"modnet_resnet50vd_matting": "균형형, 성능 우수 | 빠름~보통(1~2s/이미지) | 상품 이미지, 경계 중요",
"modnet_hrnet18_matting": "품질 우수, 중간 속도 | 보통(2~3s/이미지) | 인물/상품 정밀 분리",
"gfm_resnet34_matting": "균형 맞춘 품질+속도 | 빠름~보통 | 일반 상품/사물 이미지",
"dim_vgg16_matting": "중간 품질, 중간 속도 | 보통 | 예제/비교 테스트용",
}
def __init__(self, logger=None, default_model="modnet_resnet50vd_matting"):
self.logger = logger
self.default_model = default_model
self.modules = {} # 모델별 paddlehub 모듈 캐시
def get_module(self, model_name):
"""
paddlehub ppmatting 모델 캐싱 후 반환
"""
if model_name not in self.modules:
import paddlehub as hub
self.modules[model_name] = hub.Module(name=model_name)
if self.logger:
self.logger.log(f"ppmatting 모델 로드됨: {model_name}")
return self.modules[model_name]
def set_default_model(self, model_name):
"""
기본 배경제거 모델명 변경
"""
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"지원하지 않는 모델명: {model_name}")
self.default_model = model_name
if self.logger:
self.logger.log(f"기본 배경제거 모델이 '{model_name}'(으)로 변경됨")
def get_default_model(self):
return self.default_model
def get_supported_models(self):
return self.SUPPORTED_MODELS.copy()
def get_model_description(self, model_name):
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
def to_white_background(self, img: Image.Image) -> Image.Image:
"""
알파(투명) 부분을 흰색으로 합성해서 RGB 이미지로 반환
Args:
img: PIL.Image (RGBA or BGRA)
Returns:
PIL.Image (RGB, 배경이 흰색)
"""
if img.mode in ("RGBA", "BGRA"):
bg = Image.new("RGB", img.size, (255, 255, 255))
bg.paste(img, mask=img.split()[-1]) # 알파채널로 합성
return bg
else:
return img.convert("RGB")
def remove_background(self, image_path, model_name=None):
if not os.path.exists(image_path):
self.logger.log(f"입력 이미지 없음: {image_path}", level=logging.ERROR); return None
img = cv2.imread(image_path)
if img is None:
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR); return None
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
model_name = model_name or self.default_model
if model_name not in self.SUPPORTED_MODELS:
self.logger.log(f"지원되지 않는 모델명: {model_name}", level=logging.ERROR); return None
try:
module = self.get_module(model_name)
results = module.predict(
image_list=[img_rgb],
trimap_list=None,
visualization=False,
save_path=None
)
self.logger.log(f"results type: {type(results)}, len={len(results)}", level=logging.DEBUG)
alpha = results[0]
if not isinstance(alpha, np.ndarray) or alpha.ndim != 2:
self.logger.log("매팅 결과가 유효한 알파마스크가 아닙니다", level=logging.ERROR); return None
alpha_img = alpha if alpha.dtype == np.uint8 else (alpha * 255).astype(np.uint8)
img_bgra = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
img_bgra[...,3] = alpha_img
pil_img = Image.fromarray(cv2.cvtColor(img_bgra, cv2.COLOR_BGRA2RGBA))
return pil_img
except Exception as e:
self.logger.log(f"배경제거 예외 발생: {e}", level=logging.ERROR, exc_info=True)
return None