t_local_serv/modules/background_removal_module.py

124 lines
4.8 KiB
Python

import os
from rembg import new_session, remove
import cv2
import numpy as np
from PIL import Image
import logging
class BackgroundRemovalModule:
"""
rembg 기반 배경제거 모듈
주요 지원 모델 설명:
- 'u2net': 범용성, 속도/품질 밸런스, 사람/사물 모두 OK (기본값)
- 'u2netp': u2net 경량버전, 속도 빠름(저사양PC, 실시간)
- 'u2net_human_seg': 인물(사람) 세그멘테이션 특화
- 'u2net_cloth_seg': 옷(패션) 세그멘테이션 특화
- 'isnet-general-use': 범용, 디테일 강조, 크기 큼(최신 고성능)
- 'sam': Segment Anything, 사물/사람 모두, 고품질(고성능PC 권장)
- 'sam-mobile': SAM 경량화(속도↑, 성능↓), 모바일·저사양 PC도 사용 가능
"""
SUPPORTED_MODELS = {
"u2net": "범용 모델 (사람/사물 모두, 빠르고 정확함, 대부분 상황에서 권장)",
"u2netp": "u2net 경량화 버전 (속도 빠름, 품질은 약간 낮음, 실시간/저사양용)",
"u2net_human_seg": "사람 인식 특화 (프로필, 인물 사진)",
"u2net_cloth_seg": "의류/패션 이미지 특화",
"isnet-general-use": "최신 고성능 범용 모델 (디테일 강조, 고성능PC 추천)",
"sam": "Segment Anything Model (최고 품질, 고사양/메모리 충분할 때)",
"sam-mobile": "SAM의 경량화 (모바일, 저사양PC, 빠른 속도)"
}
def __init__(self, logger=None, default_model="u2net"):
self.logger = logger
self.default_model = default_model
self.sessions = {} # 모델별 세션 캐시
def get_session(self, model_name):
"""
모델별 세션을 캐싱하여 반환
"""
if model_name not in self.sessions:
self.sessions[model_name] = new_session(model_name=model_name)
if self.logger:
self.logger.log(f"rembg 세션 생성: {model_name}")
return self.sessions[model_name]
def set_default_model(self, model_name):
"""
배경제거 기본 모델을 변경
"""
if model_name not in self.SUPPORTED_MODELS:
raise ValueError(f"지원하지 않는 모델명: {model_name}")
self.default_model = model_name
# self.get_session(model_name) # 필요시 즉시 생성
if self.logger:
self.logger.log(f"rembg 기본 모델이 '{model_name}'(으)로 변경됨")
def get_default_model(self):
"""
현재 사용 중인 기본 모델 반환
"""
return self.default_model
def get_supported_models(self):
"""
지원 모델/설명 dict 반환 (UI, 도움말 등 활용)
"""
return self.SUPPORTED_MODELS.copy()
def get_model_description(self, model_name):
"""
모델명에 대한 설명 반환
"""
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
def to_white_background(self, img: Image.Image) -> Image.Image:
"""
알파(투명) 부분을 흰색으로 합성해서 RGB 이미지로 반환
Args:
img: PIL.Image (RGBA or BGRA)
Returns:
PIL.Image (RGB, 배경이 흰색)
"""
if img.mode in ("RGBA", "BGRA"):
bg = Image.new("RGB", img.size, (255, 255, 255))
bg.paste(img, mask=img.split()[-1]) # 알파채널로 합성
return bg
else:
return img.convert("RGB")
def remove_background(self, image_path, model_name=None, **kwargs):
"""
이미지에서 배경을 제거한 결과(PIL.Image)를 반환
Args:
image_path (str): 입력 이미지 경로
model_name (str|None): None이면 기본 모델 사용
kwargs: rembg 옵션(alpha_matting 등)
Returns:
PIL.Image | None
"""
if not os.path.exists(image_path):
if self.logger:
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}")
return None
img = cv2.imread(image_path)
if img is None:
if self.logger:
self.logger.log(f"이미지 로드 실패: {image_path}")
return None
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
model_name = model_name or self.default_model
if model_name not in self.SUPPORTED_MODELS:
if self.logger:
self.logger.log(f"지원하지 않는 모델명: {model_name}")
return None
session = self.get_session(model_name)
result = remove(img_rgb, session=session, alpha_matting=kwargs.get("alpha_matting", False))
if not isinstance(result, Image.Image):
result = Image.fromarray(result)
return result