inpaintServer/tests/rembg_test.py

107 lines
3.7 KiB
Python

#!/usr/bin/env python3
import os
import sys
import time
from datetime import datetime
def log(msg: str):
    """Write *msg* to stdout, prefixed by the local time as ``[HH:MM:SS]``.

    The line is flushed immediately so progress messages show up in order
    even when stdout is piped.
    """
    clock = datetime.now().strftime('%H:%M:%S')
    line = '[{}] {}'.format(clock, msg)
    print(line, flush=True)
def _prepare_environment() -> str:
    """Set the environment variables rembg / onnxruntime read; return the model dir.

    Uses the caller's ``U2NET_HOME`` when it is set (non-empty), otherwise
    defaults it to ``~/.u2net``. The directory is created if missing and
    scanned for cached ``.onnx`` files.

    Returns:
        The user-expanded model directory, which exists on return.
    """
    # Use the directory rembg will actually read from, so the cache check
    # below looks in the right place even when U2NET_HOME was pre-set.
    model_home = os.environ.get('U2NET_HOME') or os.path.join(os.path.expanduser('~'), '.u2net')
    os.environ['U2NET_HOME'] = model_home
    model_dir = os.path.expanduser(model_home)
    os.makedirs(model_dir, exist_ok=True)

    # If a cached model file exists, skip checksum verification.
    # NOTE(review): assumes rembg's model download step reads
    # MODEL_CHECKSUM_DISABLED; confirm for the installed rembg version.
    if any(name.endswith('.onnx') for name in os.listdir(model_dir)):
        os.environ.setdefault('MODEL_CHECKSUM_DISABLED', '1')

    # Verbose ONNX Runtime logging, to help trace the root cause of failures.
    # NOTE(review): confirm the installed onnxruntime honours these variables.
    os.environ.setdefault('ORT_LOG_SEVERITY_LEVEL', '0')
    os.environ.setdefault('ORT_LOG_VERBOSITY_LEVEL', '1')
    return model_dir


def _log_session_registry(model_name: str) -> None:
    """Best-effort: log rembg's registered session count; warn if *model_name* is absent.

    ``ben_custom`` is exempt from the warning. Any failure while reading the
    registry is logged and ignored, since the registry may not be exposed in
    every rembg version.
    """
    try:
        from rembg.sessions import sessions as rembg_sessions
        log(f"rembg sessions registered (count)={len(rembg_sessions)}")
        if model_name not in rembg_sessions and model_name != 'ben_custom':
            log(f"WARN: session class not registered for '{model_name}'. Using available: {list(rembg_sessions.keys())[:10]} ...")
    except Exception as e:  # deliberate best-effort diagnostic
        log(f"rembg.sessions registry read skipped: {e}")


def _session_providers(sess) -> list:
    """Return the providers of ``sess.inner_session`` if readable, else ``[]``.

    A failure while reading is logged, not raised.
    """
    try:
        inner = getattr(sess, 'inner_session', None)
        if inner is not None and hasattr(inner, 'get_providers'):
            return inner.get_providers() or []
    except Exception as e:  # deliberate best-effort diagnostic
        log(f"provider read failed: {e}")
    return []


def _save_result(out, out_path: str) -> None:
    """Write the output of ``rembg.remove()`` to *out_path*; log success or failure.

    Handles objects with a ``save`` method (e.g. PIL images), raw bytes, and
    anything else ``PIL.Image.fromarray`` accepts (e.g. numpy arrays).
    Errors are logged rather than raised.
    """
    try:
        if hasattr(out, 'save'):
            out.save(out_path)
        elif isinstance(out, (bytes, bytearray)):
            # Assumed to be an already-encoded image (rembg encodes PNG for
            # bytes input; confirm for the installed version). Image.open()
            # would reject raw bytes, so write them through unchanged.
            with open(out_path, 'wb') as fh:
                fh.write(out)
        else:
            from PIL import Image
            Image.fromarray(out).save(out_path)
        log(f"Result saved: {out_path}")
    except Exception as e:  # a failed save should not abort the diagnostic run
        log(f"Result save failed: {e}")


def main():
    """Smoke-test rembg background removal and log environment details.

    Reads ``REMBG_TEST_MODEL`` (default ``birefnet-general-lite``) and, for
    ``ben_custom``, ``REMBG_MODEL_PATH``. Creates a rembg session, runs
    ``remove()`` on a synthetic 512x384 image, and writes the result to
    ``./outputs/rembg_test_<model>.png``.

    Exits with status 2 if ``ben_custom`` is requested without
    ``REMBG_MODEL_PATH``.
    """
    _prepare_environment()

    model_name = os.environ.get('REMBG_TEST_MODEL', 'birefnet-general-lite')
    custom_path = os.environ.get('REMBG_MODEL_PATH')
    log(f"U2NET_HOME={os.environ.get('U2NET_HOME')}")
    log(f"MODEL_CHECKSUM_DISABLED={os.environ.get('MODEL_CHECKSUM_DISABLED')}")
    log(f"REMBG_TEST_MODEL={model_name}, REMBG_MODEL_PATH={custom_path}")

    # Heavy imports happen only after the environment is configured, so
    # anything these packages read at import time sees the values set above.
    import onnxruntime as ort
    available = ort.get_available_providers()
    log(f"ORT version={getattr(ort, '__version__', 'unknown')}")
    log(f"ORT device={ort.get_device()}")
    log(f"ORT available providers={available}")

    import rembg
    log(f"rembg version={getattr(rembg, '__version__', 'unknown')}")

    # --- Session creation -------------------------------------------------
    _log_session_registry(model_name)

    kwargs = {}
    if model_name == 'ben_custom':
        if not custom_path:
            print("ERROR: REMBG_MODEL_PATH must be set for ben_custom", file=sys.stderr)
            sys.exit(2)
        kwargs['model_path'] = custom_path
    # Request CUDA explicitly to avoid TensorRT conflicts on the
    # Jetson + ONNX 1.17 combination.
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    missing = [p for p in providers if p not in available]
    if missing:
        # Surface a likely CPU fallback instead of leaving it silent.
        log(f"WARN: requested providers not in ORT available list: {missing}")
    log(f"Using providers: {providers}")
    kwargs['providers'] = providers

    log("Creating rembg session...")
    started = time.perf_counter()
    sess = rembg.new_session(model_name, **kwargs)
    elapsed = time.perf_counter() - started
    log(f"Session created in {elapsed:.2f}s, providers={_session_providers(sess)}")

    # --- Synthetic test image: flat background plus a foreground rectangle --
    from PIL import Image, ImageDraw
    width, height = 512, 384
    img = Image.new('RGB', (width, height), (200, 220, 240))
    ImageDraw.Draw(img).rectangle(
        [width // 4, height // 4, 3 * width // 4, 3 * height // 4],
        fill=(180, 50, 50),
    )

    # --- Background removal -----------------------------------------------
    log("Calling rembg.remove()...")
    started = time.perf_counter()
    out = rembg.remove(img, session=sess)
    log(f"remove() finished in {time.perf_counter() - started:.2f}s")

    # --- Save result --------------------------------------------------------
    out_dir = os.path.join(os.getcwd(), 'outputs')
    os.makedirs(out_dir, exist_ok=True)
    _save_result(out, os.path.join(out_dir, f'rembg_test_{model_name}.png'))


if __name__ == '__main__':
    main()