124 lines
4.8 KiB
Python
124 lines
4.8 KiB
Python
import os
|
|
from rembg import new_session, remove
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
import logging
|
|
|
|
|
|
class BackgroundRemovalModule:
|
|
"""
|
|
rembg 기반 배경제거 모듈
|
|
|
|
주요 지원 모델 설명:
|
|
- 'u2net': 범용성, 속도/품질 밸런스, 사람/사물 모두 OK (기본값)
|
|
- 'u2netp': u2net 경량버전, 속도 빠름(저사양PC, 실시간)
|
|
- 'u2net_human_seg': 인물(사람) 세그멘테이션 특화
|
|
- 'u2net_cloth_seg': 옷(패션) 세그멘테이션 특화
|
|
- 'isnet-general-use': 범용, 디테일 강조, 크기 큼(최신 고성능)
|
|
- 'sam': Segment Anything, 사물/사람 모두, 고품질(고성능PC 권장)
|
|
- 'sam-mobile': SAM 경량화(속도↑, 성능↓), 모바일·저사양 PC도 사용 가능
|
|
"""
|
|
|
|
SUPPORTED_MODELS = {
|
|
"u2net": "범용 모델 (사람/사물 모두, 빠르고 정확함, 대부분 상황에서 권장)",
|
|
"u2netp": "u2net 경량화 버전 (속도 빠름, 품질은 약간 낮음, 실시간/저사양용)",
|
|
"u2net_human_seg": "사람 인식 특화 (프로필, 인물 사진)",
|
|
"u2net_cloth_seg": "의류/패션 이미지 특화",
|
|
"isnet-general-use": "최신 고성능 범용 모델 (디테일 강조, 고성능PC 추천)",
|
|
"sam": "Segment Anything Model (최고 품질, 고사양/메모리 충분할 때)",
|
|
"sam-mobile": "SAM의 경량화 (모바일, 저사양PC, 빠른 속도)"
|
|
}
|
|
|
|
def __init__(self, logger=None, default_model="u2net"):
|
|
self.logger = logger
|
|
self.default_model = default_model
|
|
self.sessions = {} # 모델별 세션 캐시
|
|
|
|
def get_session(self, model_name):
|
|
"""
|
|
모델별 세션을 캐싱하여 반환
|
|
"""
|
|
if model_name not in self.sessions:
|
|
self.sessions[model_name] = new_session(model_name=model_name)
|
|
if self.logger:
|
|
self.logger.log(f"rembg 세션 생성: {model_name}")
|
|
return self.sessions[model_name]
|
|
|
|
def set_default_model(self, model_name):
|
|
"""
|
|
배경제거 기본 모델을 변경
|
|
"""
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
raise ValueError(f"지원하지 않는 모델명: {model_name}")
|
|
self.default_model = model_name
|
|
# self.get_session(model_name) # 필요시 즉시 생성
|
|
if self.logger:
|
|
self.logger.log(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨")
|
|
|
|
def get_default_model(self):
|
|
"""
|
|
현재 사용 중인 기본 모델 반환
|
|
"""
|
|
return self.default_model
|
|
|
|
def get_supported_models(self):
|
|
"""
|
|
지원 모델/설명 dict 반환 (UI, 도움말 등 활용)
|
|
"""
|
|
return self.SUPPORTED_MODELS.copy()
|
|
|
|
def get_model_description(self, model_name):
|
|
"""
|
|
모델명에 대한 설명 반환
|
|
"""
|
|
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
|
|
|
|
def to_white_background(self, img: Image.Image) -> Image.Image:
|
|
"""
|
|
알파(투명) 부분을 흰색으로 합성해서 RGB 이미지로 반환
|
|
Args:
|
|
img: PIL.Image (RGBA or BGRA)
|
|
Returns:
|
|
PIL.Image (RGB, 배경이 흰색)
|
|
"""
|
|
if img.mode in ("RGBA", "BGRA"):
|
|
bg = Image.new("RGB", img.size, (255, 255, 255))
|
|
bg.paste(img, mask=img.split()[-1]) # 알파채널로 합성
|
|
return bg
|
|
else:
|
|
return img.convert("RGB")
|
|
|
|
def remove_background(self, image_path, model_name=None, **kwargs):
|
|
"""
|
|
이미지에서 배경을 제거한 결과(PIL.Image)를 반환
|
|
Args:
|
|
image_path (str): 입력 이미지 경로
|
|
model_name (str|None): None이면 기본 모델 사용
|
|
kwargs: rembg 옵션(alpha_matting 등)
|
|
Returns:
|
|
PIL.Image | None
|
|
"""
|
|
if not os.path.exists(image_path):
|
|
if self.logger:
|
|
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}")
|
|
return None
|
|
|
|
img = cv2.imread(image_path)
|
|
if img is None:
|
|
if self.logger:
|
|
self.logger.log(f"이미지 로드 실패: {image_path}")
|
|
return None
|
|
|
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
model_name = model_name or self.default_model
|
|
if model_name not in self.SUPPORTED_MODELS:
|
|
if self.logger:
|
|
self.logger.log(f"지원하지 않는 모델명: {model_name}")
|
|
return None
|
|
session = self.get_session(model_name)
|
|
result = remove(img_rgb, session=session, alpha_matting=kwargs.get("alpha_matting", False))
|
|
if not isinstance(result, Image.Image):
|
|
result = Image.fromarray(result)
|
|
return result
|