107 lines
3.7 KiB
Python
107 lines
3.7 KiB
Python
#!/usr/bin/env python3
import io
import os
import sys
import time
from datetime import datetime


def log(msg: str) -> None:
    """Print ``msg`` to stdout prefixed with the current local time.

    The line has the form ``[HH:MM:SS] <msg>`` and stdout is flushed after
    every call.

    Args:
        msg: Text to emit after the timestamp prefix.
    """
    ts = datetime.now().strftime('%H:%M:%S')
    print(f"[{ts}] {msg}", flush=True)


def main():
    """Smoke-test rembg background removal and log diagnostics along the way.

    Steps, in order:

    1. Resolve the model cache directory and set defaults for the checksum
       and ONNX Runtime logging environment variables.
    2. Import onnxruntime and rembg and log their versions, device and
       available execution providers.
    3. Create a rembg session for the selected model, requesting the CUDA
       and CPU execution providers.
    4. Run ``rembg.remove`` on a synthetic 512x384 image and save the result
       to ``<cwd>/outputs/rembg_test_<model>.png``.

    Environment variables read:
        U2NET_HOME: model cache directory; defaults to ``~/.u2net`` when
            unset or empty.
        REMBG_TEST_MODEL: model name; defaults to ``birefnet-general-lite``.
        REMBG_MODEL_PATH: model file path, required when the model name is
            ``ben_custom``.

    Raises:
        SystemExit: with status 2 when the model is ``ben_custom`` and
            ``REMBG_MODEL_PATH`` is not set.
    """
    # Environment setup.
    # Use the *effective* cache directory: an already-exported U2NET_HOME wins
    # over the default, so the directory created and scanned below is the same
    # one that is advertised through the environment variable.
    default_home = os.path.join(os.path.expanduser('~'), '.u2net')
    u2net_home = os.environ.get('U2NET_HOME') or default_home
    os.environ['U2NET_HOME'] = u2net_home
    os.makedirs(u2net_home, exist_ok=True)

    # Skip the checksum when a cached model file already exists.
    if any(name.endswith('.onnx') for name in os.listdir(u2net_home)):
        os.environ.setdefault('MODEL_CHECKSUM_DISABLED', '1')

    # Verbose ORT logging (for root-cause tracing).
    os.environ.setdefault('ORT_LOG_SEVERITY_LEVEL', '0')
    os.environ.setdefault('ORT_LOG_VERBOSITY_LEVEL', '1')

    model_name = os.environ.get('REMBG_TEST_MODEL', 'birefnet-general-lite')
    custom_path = os.environ.get('REMBG_MODEL_PATH')

    log(f"U2NET_HOME={os.environ.get('U2NET_HOME')}")
    log(f"MODEL_CHECKSUM_DISABLED={os.environ.get('MODEL_CHECKSUM_DISABLED')}")
    log(f"REMBG_TEST_MODEL={model_name}, REMBG_MODEL_PATH={custom_path}")

    # Imported here, after the environment variables above have been set.
    import onnxruntime as ort
    log(f"ORT version={getattr(ort, '__version__', 'unknown')}")
    log(f"ORT device={ort.get_device()}")
    log(f"ORT available providers={ort.get_available_providers()}")

    import rembg
    log(f"rembg version={getattr(rembg, '__version__', 'unknown')}")

    # Session creation.
    # Some versions may not expose the sessions dict, so probe it defensively;
    # any failure here is only logged and never aborts the test.
    try:
        from rembg.sessions import sessions as rembg_sessions
        log(f"rembg sessions registered (count)={len(rembg_sessions)}")
        if model_name not in rembg_sessions and model_name != 'ben_custom':
            log(f"WARN: session class not registered for '{model_name}'. Using available: {list(rembg_sessions.keys())[:10]} ...")
    except Exception as e:
        log(f"rembg.sessions registry read skipped: {e}")

    kwargs = {}
    if model_name == 'ben_custom':
        if not custom_path:
            print("ERROR: REMBG_MODEL_PATH must be set for ben_custom", file=sys.stderr)
            sys.exit(2)
        kwargs['model_path'] = custom_path

    # Explicitly request CUDA to avoid a TensorRT conflict on the
    # Jetson + ONNX 1.17 combination.
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    log(f"Using providers: {providers}")
    kwargs['providers'] = providers

    log("Creating rembg session...")
    # perf_counter is monotonic, so the measured interval is not affected by
    # system clock adjustments.
    session_start = time.perf_counter()
    sess = rembg.new_session(model_name, **kwargs)
    session_elapsed = time.perf_counter() - session_start

    # Best-effort read of the providers the session actually ended up with.
    active_providers = []
    try:
        inner = getattr(sess, 'inner_session', None)
        if inner is not None and hasattr(inner, 'get_providers'):
            active_providers = inner.get_providers() or []
    except Exception as e:
        log(f"provider read failed: {e}")
    log(f"Session created in {session_elapsed:.2f}s, providers={active_providers}")

    # Build the test image (solid background + foreground rectangle).
    from PIL import Image, ImageDraw
    W, H = 512, 384
    img = Image.new('RGB', (W, H), (200, 220, 240))
    draw = ImageDraw.Draw(img)
    draw.rectangle([W // 4, H // 4, 3 * W // 4, 3 * H // 4], fill=(180, 50, 50))

    # Run background removal.
    from rembg import remove
    log("Calling rembg.remove()...")
    remove_start = time.perf_counter()
    out = remove(img, session=sess)
    remove_elapsed = time.perf_counter() - remove_start
    log(f"remove() finished in {remove_elapsed:.2f}s")

    # Save the result. A failure here is logged, not raised.
    out_dir = os.path.join(os.getcwd(), 'outputs')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f'rembg_test_{model_name}.png')
    try:
        if hasattr(out, 'save'):
            out.save(out_path)
        elif isinstance(out, (bytes, bytearray, memoryview)):
            # Image.open() interprets a bytes argument as a file name, so
            # encoded image data must be wrapped in a file-like object.
            Image.open(io.BytesIO(out)).save(out_path)
        else:
            Image.open(out).save(out_path)
        log(f"Result saved: {out_path}")
    except Exception as e:
        log(f"Result save failed: {e}")


# Run the smoke test only when this file is executed as a script.
if __name__ == '__main__':
    main()