ImageProcessor_MainServer/worker/rembg_module.py

121 lines
4.8 KiB
Python

import os
import logging
from typing import Union
import cv2
import numpy as np
from PIL import Image
from rembg import remove
class RembgRemover:
"""배경 제거를 위해 rembg 패키지를 사용하는 로컬 모듈.
기존 AI 서버 호출 방식(`request_rembg`)과 동일하게 np.ndarray(BGR) 이미지를 반환합니다.
실패 시 None 을 반환합니다.
"""
def __init__(self, logger: logging.Logger = None):
self.logger = logger or logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------
def remove_background(self, image: Union[str, np.ndarray]) -> np.ndarray | None: # type: ignore[override]
"""배경을 제거 후 흰 배경 중앙 배치된 이미지를 반환한다.
Args:
image (str | np.ndarray): 파일 경로 또는 BGR np.ndarray.
Returns:
np.ndarray | None: 전처리 완료된 3채널 BGR 이미지. 실패 시 None.
"""
try:
img_bgr = self._load_image(image)
if img_bgr is None:
return None
# rembg 는 Pillow 이미지 또는 numpy RGB(A)를 입력으로 받음
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
pil_input = Image.fromarray(img_rgb)
# 배경 제거 (RGBA Pillow Image 반환)
pil_output = remove(pil_input)
result_rgba = np.array(pil_output) # RGBA
result_img = cv2.cvtColor(result_rgba, cv2.COLOR_RGBA2BGRA) # BGRA
return self._post_process(result_img)
except Exception as e:
self.logger.error(f"RembgRemover 오류: {e}", exc_info=True)
return None
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _load_image(self, image: Union[str, np.ndarray]) -> np.ndarray | None:
"""경로 또는 ndarray 로부터 BGR 이미지를 로드."""
if isinstance(image, str):
if not os.path.isfile(image):
self.logger.error(f"파일을 찾을 수 없습니다: {image}")
return None
img = cv2.imread(image, cv2.IMREAD_COLOR)
if img is None:
self.logger.error(f"cv2 로 이미지 읽기 실패: {image}")
return img
elif isinstance(image, np.ndarray):
return image
else:
self.logger.error(f"지원되지 않는 타입: {type(image)}")
return None
def _post_process(self, result_img: np.ndarray) -> np.ndarray:
"""마스크 정제, crop, 흰 배경 합성 등 후처리를 수행한다."""
if result_img.ndim != 3:
return result_img # 예외적인 반환(단채널 등) → 그대로 돌려줌
# ----- 1) 초기 마스크 생성 -----
if result_img.shape[2] == 4:
mask_init = (result_img[:, :, 3] > 200).astype(np.uint8)
rgba_img = result_img
else:
gray = cv2.cvtColor(result_img[:, :, :3], cv2.COLOR_BGR2GRAY)
mask_init = (gray < 230).astype(np.uint8)
alpha_channel = (mask_init * 255).astype(np.uint8)
rgba_img = np.dstack([result_img, alpha_channel])
# ----- 2) 모폴로지 정제 -----
kernel = np.ones((3, 3), np.uint8)
mask = cv2.erode(mask_init, kernel, iterations=1)
mask = cv2.dilate(mask, kernel, iterations=2)
# ----- 3) 최대 연결 요소 유지 -----
num, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
if num > 1:
largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
mask = (labels == largest).astype(np.uint8)
ys, xs = np.where(mask > 0)
if len(xs) == 0 or len(ys) == 0:
# 객체 감지 실패 → 흰 배경 합성만
white_bg = np.full_like(rgba_img[:, :, :3], 255)
return white_bg
# ----- 4) 바운딩 박스 + 마진 -----
top, left = ys.min(), xs.min()
bottom, right = ys.max(), xs.max()
crop_rgba = rgba_img[top:bottom + 1, left:right + 1]
ch, cw = crop_rgba.shape[:2]
margin = int(max(ch, cw) * 0.1)
crop_rgba = cv2.copyMakeBorder(
crop_rgba, margin, margin, margin, margin,
borderType=cv2.BORDER_CONSTANT,
value=[255, 255, 255, 0]
)
# ----- 5) 흰 배경 합성 -----
bgr_crop = crop_rgba[:, :, :3].astype(np.float32)
alpha_crop = crop_rgba[:, :, 3:4].astype(np.float32) / 255.0
white_bg = np.full_like(bgr_crop, 255.0)
final_img = (bgr_crop * alpha_crop + white_bg * (1 - alpha_crop)).astype(np.uint8)
return final_img