Initial commit: IMG Worker project setup
|
|
@ -0,0 +1,22 @@
|
|||
# 큐 대기 한도(가득 차면 429)
|
||||
IMGWK_MAX_PENDING=400
|
||||
|
||||
IMGWK_WORKER_READY_TIMEOUT_SEC = 120 # 워커 작시 READY 타임아웃
|
||||
|
||||
# 워커 롤링 임계치
|
||||
IMGWK_ROLL_MAX_RSS_MB=3600
|
||||
IMGWK_ROLL_MAX_JOBS=500
|
||||
IMGWK_ROLL_MAX_UPTIME_SEC=3600
|
||||
|
||||
# 잡 타임아웃(초)
|
||||
IMGWK_JOB_TIMEOUT_SEC=600
|
||||
|
||||
# 테스트: 2건마다 워커 재시작 (1=활성화)
|
||||
IMGWK_TEST_ROLL_EVERY_2=0
|
||||
|
||||
IMGWK_OCR_PROVIDER=auto # auto|dml|cpu
|
||||
IMGWK_MIGAN_PROVIDER=auto # auto|dml|cpu
|
||||
IMGWK_REMBG_PROVIDER=auto # auto|dml|cpu
|
||||
|
||||
IMGWK_NO_TRAY = 0 # IMGWK_NO_TRAY=1 설정 시 트레이 비활성화
|
||||
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
Lib/
|
||||
Scripts/
|
||||
pyvenv.cfg
|
||||
include/
|
||||
share/
|
||||
pyvenv
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
## Image Worker API (FastAPI)
|
||||
|
||||
### 개요
|
||||
- 이미지 번역 파이프라인을 별도 프로세스로 격리한 로컬 API 서버입니다.
|
||||
- 단일 워커 프로세스가 순차 처리하며, OCR 이후에는 번역(I/O)과 마스크 생성(CPU)을 동시 실행해 성능을 높입니다.
|
||||
- 재시작/메모리 이슈를 대비해 워커 롤링(메모리/건수/업타임), 잡 타임아웃, 롤링 중 버퍼링, 429 백프레셔를 제공합니다.
|
||||
|
||||
### 아키텍처
|
||||
- FastAPI 서버(가벼움) + 워커 프로세스(무거운 모듈)
|
||||
- 모듈별 프로바이더 전략(DirectML 우선 → 실패 시 CPU 폴백) 및 성공 프로바이더 캐시
|
||||
- OCR: `modules/onnx_ocr_module/src/onnx_ocr_wrapper.py`
|
||||
- MIGAN: `modules/migan_module.py`
|
||||
- Rembg(BriaAI): `modules/bria_background_removal_module.py`
|
||||
- 데이터 전달은 파일 경로 기반(대용량 base64 불사용)
|
||||
- ProgramData 경로 표준화
|
||||
- 작업/임시: `C:\ProgramData\ImgWorker\work`
|
||||
- 출력: `C:\ProgramData\ImgWorker\outputs`
|
||||
|
||||
### 환경변수(.env)
|
||||
- IMGWK_MAX_PENDING=200 # 큐 대기 한도(가득 차면 429)
|
||||
- IMGWK_ROLL_MAX_RSS_MB=1800 # 워커 RSS 임계(MB)
|
||||
- IMGWK_ROLL_MAX_JOBS=500 # 워커 재시작 처리 건수 임계
|
||||
- IMGWK_ROLL_MAX_UPTIME_SEC=7200 # 워커 업타임 임계(초)
|
||||
- IMGWK_JOB_TIMEOUT_SEC=600 # 잡 타임아웃(초)
|
||||
- IMGWK_TEST_ROLL_EVERY_2=0 # (테스트) 2건마다 롤링
|
||||
- IMGWK_OCR_PROVIDER=auto # auto|dml|cpu (강제 가능)
|
||||
- IMGWK_REMBG_PROVIDER=auto # rembg용 provider 강제(auto|dml|cpu)
|
||||
- (선택) 기타 모델 경로/스레드 설정은 토글로 전달
|
||||
|
||||
### 기본 토글(default)
|
||||
- Rembg: 자체 로컬 사용(use_local_rembg=True)
|
||||
- Inpaint: MIGAN을 기본(GPU 시도; DirectML→CPU 폴백)
|
||||
- OCR provider override: `.env` IMGWK_OCR_PROVIDER(없으면 auto)
|
||||
|
||||
### 프로바이더 선택/폴백/캐시
|
||||
- 순서: DirectML 우선 → 실패 시 CPU 폴백
|
||||
- 성공한 프로바이더는 캐시 파일에 저장되어 이후 재초기화 시 재검증 생략
|
||||
- OCR: `user_data/ocr_provider.json`
|
||||
- MIGAN: `user_data/migan_provider.json`
|
||||
- Rembg: `user_data/rembg_provider.json`
|
||||
- 모델별 강제 설정
|
||||
- OCR: 토글 `ocr_provider_override` = auto|dml|cpu
|
||||
- MIGAN: 토글 `migan_provider_override` = auto|dml|cpu
|
||||
- Rembg: 환경변수 IMGWK_REMBG_PROVIDER = auto|dml|cpu
|
||||
|
||||
### 엔드포인트
|
||||
- GET `/health`
|
||||
- ready 여부, 워커 PID
|
||||
- GET `/info`
|
||||
- 환경/경로/워커 상태 정보
|
||||
- POST `/v1/process-image`
|
||||
- Body: `{ file_path, index, file_prefix?, toggle_overrides?, group_id?, seq? }`
|
||||
- Return: `{ accepted: true, job_id }`
|
||||
- POST `/v1/remove-background`
|
||||
- Body: `{ file_path, file_prefix?, toggle_overrides? }`
|
||||
- Return: `{ accepted: true, job_id }`
|
||||
- GET `/v1/jobs/{job_id}`
|
||||
- 상태: queued|running|done|error|cancelled, 결과/오류 포함
|
||||
- DELETE `/v1/jobs/{job_id}`
|
||||
- queued 상태일 때만 취소
|
||||
- 제어 엔드포인트
|
||||
- POST `/v1/ocr/reinit` `{ provider?: 'auto'|'dml'|'cpu' }` → OCR 재초기화(프로바이더 반영/캐시)
|
||||
- POST `/v1/migan/reset` `{ use_cuda?: true|false }` → MIGAN 세션 재설정(DirectML 시도/폴백)
|
||||
|
||||
### 파이프라인(1 이미지)
|
||||
1) 로드/검증 → 2) OCR → 3) [번역, 마스크] 병렬 → 4) 인페인트(MIGAN 기본) → 5) 렌더링 → 6) 저장
|
||||
- 그룹 순서 보장: 클라이언트는 `group_id + seq`로 정렬, 서버는 무순서 완료 허용
|
||||
|
||||
### 런타임 변경
|
||||
- OCR 프로바이더: `/v1/ocr/reinit` 로 dml/cpu/auto 적용 가능(재초기화)
|
||||
- MIGAN: `/v1/migan/reset` 으로 DirectML 사용 여부 전환 가능
|
||||
|
||||
### 운영 팁
|
||||
- DML 미지원 VM/환경에서는 자동으로 CPU 폴백
|
||||
- 큐 가득 차면 429 반환 → 재시도 로직 구현 권장
|
||||
- 롤링 임계는 .env 로 조정
|
||||
- 로그는 `logs/`에서 롤링 저장, 자세한 타이밍/프로바이더 사용 내역 확인 가능
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
## Why this file is included
|
||||
|
||||
This program has been frozen with cx_Freeze. The freezing process
|
||||
resulted in certain components from the cx_Freeze software being included
|
||||
in the frozen application, in particular bootstrap code for launching
|
||||
the frozen python script. The cx_Freeze software is subject to the
|
||||
license set out below.
|
||||
|
||||
# Licensing
|
||||
|
||||
- Copyright © 2020-2025, Marcelo Duarte.
|
||||
- Copyright © 2007-2019, Anthony Tuininga.
|
||||
- Copyright © 2001-2006, Computronix (Canada) Ltd., Edmonton, Alberta,
|
||||
Canada.
|
||||
- All rights reserved.
|
||||
|
||||
NOTE: This license is derived from the Python Software Foundation
|
||||
License which can be found at
|
||||
<https://docs.python.org/3/license.html#psf-license-agreement-for-python-release>
|
||||
|
||||
## License for cx_Freeze
|
||||
|
||||
1. This LICENSE AGREEMENT is between the copyright holders and the
|
||||
Individual or Organization ("Licensee") accessing and otherwise
|
||||
using cx_Freeze software in source or binary form and its associated
|
||||
documentation.
|
||||
2. Subject to the terms and conditions of this License Agreement, the
|
||||
copyright holders hereby grant Licensee a nonexclusive,
|
||||
royalty-free, world-wide license to reproduce, analyze, test,
|
||||
perform and/or display publicly, prepare derivative works,
|
||||
distribute, and otherwise use cx_Freeze alone or in any derivative
|
||||
version, provided, however, that this License Agreement and this
|
||||
notice of copyright are retained in cx_Freeze alone or in any
|
||||
derivative version prepared by Licensee.
|
||||
3. In the event Licensee prepares a derivative work that is based on or
|
||||
incorporates cx_Freeze or any part thereof, and wants to make the
|
||||
derivative work available to others as provided herein, then
|
||||
Licensee hereby agrees to include in any such work a brief summary
|
||||
of the changes made to cx_Freeze.
|
||||
4. The copyright holders are making cx_Freeze available to Licensee on
|
||||
an "AS IS" basis. THE COPYRIGHT HOLDERS MAKE NO REPRESENTATIONS OR
|
||||
WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT
|
||||
LIMITATION, THE COPYRIGHT HOLDERS MAKE NO AND DISCLAIM ANY
|
||||
REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY
|
||||
PARTICULAR PURPOSE OR THAT THE USE OF CX_FREEZE WILL NOT INFRINGE
|
||||
ANY THIRD PARTY RIGHTS.
|
||||
5. THE COPYRIGHT HOLDERS SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER
|
||||
USERS OF CX_FREEZE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL
|
||||
DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE
|
||||
USING CX_FREEZE, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY THEREOF.
|
||||
6. This License Agreement will automatically terminate upon a material
|
||||
breach of its terms and conditions.
|
||||
7. Nothing in this License Agreement shall be deemed to create any
|
||||
relationship of agency, partnership, or joint venture between the
|
||||
copyright holders and Licensee. This License Agreement does not
|
||||
grant permission to use copyright holder's trademarks or trade name
|
||||
in a trademark sense to endorse or promote products or services of
|
||||
Licensee, or any third party.
|
||||
8. By copying, installing or otherwise using cx_Freeze, Licensee agrees
|
||||
to be bound by the terms and conditions of this License Agreement.
|
||||
|
||||
Computronix® is a registered trademark of Computronix (Canada) Ltd.
|
||||
|
|
@ -0,0 +1,393 @@
|
|||
"""
|
||||
서버/CLI 친화 Logger 모듈 - 파일/콘솔 분리 로깅과 자동 정리
|
||||
|
||||
권장 로그 레벨:
|
||||
- 개발: logging.DEBUG
|
||||
- 운영: logging.INFO 또는 logging.WARNING
|
||||
|
||||
사용 예시:
|
||||
```python
|
||||
# 기본 사용
|
||||
logger = Logger(
|
||||
log_file="logs/app.log",
|
||||
file_log_level=logging.DEBUG
|
||||
)
|
||||
|
||||
logger.info("서버 시작")
|
||||
|
||||
# FastAPI에서 사용 (예)
|
||||
# from fastapi import FastAPI
|
||||
# app = FastAPI()
|
||||
# log = Logger(log_file="logs/api.log")
|
||||
# @app.get("/")
|
||||
# def root():
|
||||
# log.info("요청 수신")
|
||||
# return {"ok": True}
|
||||
```
|
||||
"""
|
||||
import re
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
|
||||
# 로그 레벨 이름 매핑 (안전하고 호환성 좋은 방법)
|
||||
def get_level_name(level):
|
||||
"""로그 레벨을 이름으로 변환 (호환성 보장)"""
|
||||
level_names = {
|
||||
logging.NOTSET: "NOTSET",
|
||||
logging.DEBUG: "DEBUG",
|
||||
logging.INFO: "INFO",
|
||||
logging.WARNING: "WARNING",
|
||||
logging.ERROR: "ERROR",
|
||||
logging.CRITICAL: "CRITICAL"
|
||||
}
|
||||
return level_names.get(level, f"Level {level}")
|
||||
|
||||
class Logger:
|
||||
|
||||
def __init__(self, gui_logger=None, log_file="Edit_PartTimer_log.log", logger_name="Edit_PartTimer_log",
|
||||
file_log_level=logging.DEBUG, gui_log_level=logging.INFO,
|
||||
max_days=3, cleanup_interval=3600):
|
||||
"""
|
||||
개선된 Logger 초기화 (서버/CLI용)
|
||||
:param gui_logger: 선택적 콜백 함수(메시지 수신용)
|
||||
:param log_file: 로그 파일 이름
|
||||
:param logger_name: 로거 이름
|
||||
:param file_log_level: 파일 로거의 로그 레벨
|
||||
:param gui_log_level: 콜백 호출 로그 레벨(기본 INFO)
|
||||
:param max_days: 로그 파일 보관 일수 (기본 3일)
|
||||
:param cleanup_interval: 정리 작업 간격(초, 기본 1시간)
|
||||
"""
|
||||
self.gui_logger = gui_logger
|
||||
self.file_log_level = file_log_level
|
||||
self.gui_log_level = gui_log_level
|
||||
self.max_days = max_days
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.log_dir = Path(log_file).parent
|
||||
self.log_base_name = Path(log_file).stem
|
||||
|
||||
# 로그 디렉토리 생성
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 로그 설정
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
self.logger.setLevel(file_log_level)
|
||||
# 상위 로거로의 전파 방지(중복 출력 방지)
|
||||
self.logger.propagate = False
|
||||
|
||||
# 기존 핸들러 제거 (중복 방지)
|
||||
for handler in self.logger.handlers[:]:
|
||||
self.logger.removeHandler(handler)
|
||||
|
||||
# 포맷 설정
|
||||
self.simple_format = "[%(asctime)s] [%(levelname)s] %(message)s"
|
||||
self.detailed_format = (
|
||||
"[%(asctime)s] [%(threadName)s] [%(levelname)s] "
|
||||
"[%(filename)s:%(funcName)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
# 콜백 전용 간결한 포맷 (시간:분:초 + 메시지만)
|
||||
# 필요 시 log()에서 Formatter를 생성하여 사용
|
||||
|
||||
# 핸들러 추가
|
||||
self._add_console_handler(file_log_level)
|
||||
self._add_file_handler(log_file, file_log_level)
|
||||
|
||||
# 자동 정리 스레드 시작
|
||||
self._start_cleanup_thread()
|
||||
|
||||
def _add_console_handler(self, level):
|
||||
"""콘솔 핸들러 추가"""
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(level)
|
||||
formatter = logging.Formatter(
|
||||
self.detailed_format if level <= logging.DEBUG else self.simple_format
|
||||
)
|
||||
console_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
def _add_file_handler(self, log_file, level):
|
||||
"""파일 크기 기반 로테이팅 + 수명 관리"""
|
||||
from logging.handlers import RotatingFileHandler
|
||||
# 확장자가 .log가 아니면 .log로 변경
|
||||
if not log_file.endswith('.log'):
|
||||
base_name, _ = os.path.splitext(log_file)
|
||||
log_file = base_name + '.log'
|
||||
|
||||
# 크기 기반 로테이션: 10MB, 최대 50개(정리 스레드가 3일/일5개로 실제 보존 제한)
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=50,
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
file_handler.setLevel(level)
|
||||
formatter = logging.Formatter(
|
||||
self.detailed_format if level <= logging.DEBUG else self.simple_format
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
# 핸들러 참조 저장 (정리 작업용)
|
||||
self.file_handler = file_handler
|
||||
|
||||
def _start_cleanup_thread(self):
|
||||
"""자동 정리 스레드 시작"""
|
||||
def cleanup_worker():
|
||||
while True:
|
||||
try:
|
||||
self._cleanup_old_logs()
|
||||
time.sleep(self.cleanup_interval)
|
||||
except Exception as e:
|
||||
# 정리 작업 실패 시에도 계속 동작
|
||||
pass
|
||||
|
||||
cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
def _cleanup_old_logs(self):
|
||||
"""오래된 로그(> max_days) 및 날짜별 최대 10개 초과분 정리"""
|
||||
cutoff_date = datetime.now() - timedelta(days=self.max_days)
|
||||
log_pattern = f"{self.log_base_name}*.log*" # .log, .log.1 등 모두 포함
|
||||
|
||||
files = []
|
||||
for log_file in self.log_dir.glob(log_pattern):
|
||||
try:
|
||||
if log_file.name == f"{self.log_base_name}.log":
|
||||
continue
|
||||
file_mtime = datetime.fromtimestamp(log_file.stat().st_mtime)
|
||||
files.append((log_file, file_mtime))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 1) 보존 기간 초과 파일 삭제
|
||||
for log_file, mtime in files:
|
||||
if mtime < cutoff_date:
|
||||
try:
|
||||
log_file.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2) 날짜별 최대 5개 유지 (최신순 보존)
|
||||
from collections import defaultdict
|
||||
by_day = defaultdict(list)
|
||||
for log_file, mtime in files:
|
||||
day_key = mtime.strftime('%Y-%m-%d')
|
||||
by_day[day_key].append((log_file, mtime))
|
||||
|
||||
for day_key, items in by_day.items():
|
||||
# 최신순으로 정렬
|
||||
items.sort(key=lambda x: x[1], reverse=True)
|
||||
for (log_file, _mtime) in items[5:]:
|
||||
try:
|
||||
if log_file.exists():
|
||||
log_file.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def log(self, message, level=logging.INFO, exc_info=False):
|
||||
"""로그 메시지 기록"""
|
||||
if exc_info:
|
||||
message = f"{message}\n{traceback.format_exc()}"
|
||||
|
||||
# 호출 위치 정보를 동적으로 추출
|
||||
caller_frame = logging.currentframe().f_back
|
||||
record = self.logger.makeRecord(
|
||||
self.logger.name, level, caller_frame.f_code.co_filename,
|
||||
caller_frame.f_lineno, message, None, None, caller_frame.f_code.co_name
|
||||
)
|
||||
|
||||
# 파일/콘솔 핸들러에 메시지 전달
|
||||
if level >= self.file_log_level:
|
||||
self.logger.handle(record)
|
||||
|
||||
# 선택적 콜백으로 전달 (간결 포맷 적용)
|
||||
if self.gui_logger and level >= self.gui_log_level:
|
||||
gui_formatter = logging.Formatter(
|
||||
fmt="[%(asctime)s] %(message)s",
|
||||
datefmt="%H:%M:%S"
|
||||
)
|
||||
formatted_message = gui_formatter.format(record)
|
||||
try:
|
||||
self.gui_logger(formatted_message)
|
||||
except Exception:
|
||||
# 콜백 오류는 로깅 실패로 간주하지 않음
|
||||
pass
|
||||
|
||||
def set_gui_logger(self, gui_logger, gui_log_level=None):
|
||||
"""
|
||||
로그 콜백 함수를 설정합니다.
|
||||
선택적으로 콜백 로그 레벨도 변경할 수 있습니다.
|
||||
"""
|
||||
self.gui_logger = gui_logger
|
||||
|
||||
if gui_log_level is not None:
|
||||
self.gui_log_level = gui_log_level
|
||||
|
||||
def set_gui_log_level(self, level):
|
||||
"""콜백 로그 레벨을 동적으로 변경합니다"""
|
||||
self.gui_log_level = level
|
||||
level_name = get_level_name(level)
|
||||
self.logger.info(f"콜백 로그 레벨이 {level_name}로 변경되었습니다")
|
||||
|
||||
def set_file_log_level(self, level):
|
||||
"""파일 로그 레벨을 동적으로 변경합니다"""
|
||||
self.file_log_level = level
|
||||
self.logger.setLevel(level)
|
||||
for handler in self.logger.handlers:
|
||||
if isinstance(handler, (logging.FileHandler, TimedRotatingFileHandler)):
|
||||
handler.setLevel(level)
|
||||
level_name = get_level_name(level)
|
||||
self.logger.info(f"파일 로그 레벨이 {level_name}로 변경되었습니다")
|
||||
|
||||
def get_log_levels(self):
|
||||
"""현재 로그 레벨 정보 반환"""
|
||||
return {
|
||||
"file_level": get_level_name(self.file_log_level),
|
||||
"gui_level": get_level_name(self.gui_log_level),
|
||||
"logger_level": get_level_name(self.logger.level)
|
||||
}
|
||||
|
||||
def get_log_info(self):
|
||||
"""로그 파일 정보 반환"""
|
||||
try:
|
||||
log_files = list(self.log_dir.glob(f"{self.log_base_name}*.log"))
|
||||
total_size = sum(f.stat().st_size for f in log_files)
|
||||
|
||||
return {
|
||||
"log_dir": str(self.log_dir),
|
||||
"total_files": len(log_files),
|
||||
"total_size_mb": total_size / (1024 * 1024),
|
||||
"max_days": self.max_days,
|
||||
"files": [f.name for f in log_files]
|
||||
}
|
||||
except Exception:
|
||||
return {"error": "로그 정보 조회 실패"}
|
||||
|
||||
def force_cleanup(self):
|
||||
"""수동 정리 실행"""
|
||||
self._cleanup_old_logs()
|
||||
|
||||
# 파이썬 표준 로깅 인터페이스 지원
|
||||
def debug(self, message, *args, **kwargs):
|
||||
"""DEBUG 레벨 로그"""
|
||||
if args:
|
||||
message = message % args
|
||||
self.log(message, level=logging.DEBUG, exc_info=kwargs.get('exc_info', False))
|
||||
|
||||
def info(self, message, *args, **kwargs):
|
||||
"""INFO 레벨 로그"""
|
||||
if args:
|
||||
message = message % args
|
||||
self.log(message, level=logging.INFO, exc_info=kwargs.get('exc_info', False))
|
||||
|
||||
def warning(self, message, *args, **kwargs):
|
||||
"""WARNING 레벨 로그"""
|
||||
if args:
|
||||
message = message % args
|
||||
self.log(message, level=logging.WARNING, exc_info=kwargs.get('exc_info', False))
|
||||
|
||||
def error(self, message, *args, **kwargs):
|
||||
"""ERROR 레벨 로그"""
|
||||
if args:
|
||||
message = message % args
|
||||
self.log(message, level=logging.ERROR, exc_info=kwargs.get('exc_info', False))
|
||||
|
||||
def critical(self, message, *args, **kwargs):
|
||||
"""CRITICAL 레벨 로그"""
|
||||
if args:
|
||||
message = message % args
|
||||
self.log(message, level=logging.CRITICAL, exc_info=kwargs.get('exc_info', False))
|
||||
|
||||
def exception(self, message, *args, **kwargs):
|
||||
"""ERROR 레벨로 예외 정보와 함께 로그"""
|
||||
if args:
|
||||
message = message % args
|
||||
self.log(message, level=logging.ERROR, exc_info=True)
|
||||
|
||||
# 전역 기본 로거 인스턴스
|
||||
_default_logger = None
|
||||
|
||||
def get_default_logger():
|
||||
"""기본 로거 인스턴스 반환"""
|
||||
global _default_logger
|
||||
if _default_logger is None:
|
||||
_default_logger = Logger(log_file="logs/default.log")
|
||||
return _default_logger
|
||||
|
||||
def set_default_logger(logger):
|
||||
"""기본 로거 설정"""
|
||||
global _default_logger
|
||||
_default_logger = logger
|
||||
|
||||
# 편의 함수들 - 실제 구현
|
||||
def debug(msg, *args, **kwargs):
|
||||
"""편의 함수 - DEBUG 레벨 로그"""
|
||||
get_default_logger().debug(msg, *args, **kwargs)
|
||||
|
||||
def info(msg, *args, **kwargs):
|
||||
"""편의 함수 - INFO 레벨 로그"""
|
||||
get_default_logger().info(msg, *args, **kwargs)
|
||||
|
||||
def warning(msg, *args, **kwargs):
|
||||
"""편의 함수 - WARNING 레벨 로그"""
|
||||
get_default_logger().warning(msg, *args, **kwargs)
|
||||
|
||||
def error(msg, *args, **kwargs):
|
||||
"""편의 함수 - ERROR 레벨 로그"""
|
||||
get_default_logger().error(msg, *args, **kwargs)
|
||||
|
||||
def critical(msg, *args, **kwargs):
|
||||
"""편의 함수 - CRITICAL 레벨 로그"""
|
||||
get_default_logger().critical(msg, *args, **kwargs)
|
||||
|
||||
def exception(msg, *args, **kwargs):
|
||||
"""편의 함수 - 예외 정보와 함께 로그"""
|
||||
get_default_logger().exception(msg, *args, **kwargs)
|
||||
|
||||
|
||||
# 구조화된 로깅을 위한 추가 클래스
|
||||
class StructuredLogger(Logger):
|
||||
"""JSON 형태의 구조화된 로깅을 지원하는 로거"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def log_structured(self, event, level=logging.INFO, **context):
|
||||
"""구조화된 로그 기록"""
|
||||
import json
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"event": event,
|
||||
"level": get_level_name(level),
|
||||
**context
|
||||
}
|
||||
|
||||
message = json.dumps(log_data, ensure_ascii=False, separators=(',', ':'))
|
||||
self.log(message, level)
|
||||
|
||||
def log_performance(self, operation, duration, **context):
|
||||
"""성능 로그 기록"""
|
||||
self.log_structured(
|
||||
"performance",
|
||||
level=logging.INFO,
|
||||
operation=operation,
|
||||
duration_ms=round(duration * 1000, 2),
|
||||
**context
|
||||
)
|
||||
|
||||
def log_error_with_context(self, error, **context):
|
||||
"""에러와 컨텍스트를 함께 로그"""
|
||||
self.log_structured(
|
||||
"error",
|
||||
level=logging.ERROR,
|
||||
error_type=type(error).__name__,
|
||||
error_message=str(error),
|
||||
**context
|
||||
)
|
||||
|
|
@ -0,0 +1,776 @@
|
|||
"""
|
||||
BriaAI RMBG 1.4 ONNXRuntime 기반 배경제거 모듈
|
||||
기존 background_removal_module.py를 완전 대체할 수 있는 호환 인터페이스 제공
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import logging
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class BriaBackgroundRemovalModule:
|
||||
"""
|
||||
BriaAI RMBG 1.4 ONNX 모델 기반 배경제거 모듈
|
||||
기존 BackgroundRemovalModule과 완전 호환되는 인터페이스 제공
|
||||
"""
|
||||
|
||||
# 지원하는 모델 목록 (BriaAI 기반)
|
||||
SUPPORTED_MODELS = {
|
||||
"bria-rmbg-1.4": "BriaAI RMBG 1.4 | 고품질 | 빠름 | 범용 (기본값)",
|
||||
"bria-rmbg-aggressive": "BriaAI RMBG 1.4 | 강력한 배경제거 | 빠름 | 깔끔한 결과",
|
||||
"bria-rmbg-gentle": "BriaAI RMBG 1.4 | 부드러운 배경제거 | 빠름 | 세밀한 경계"
|
||||
}
|
||||
|
||||
# 모델별 aggressiveness 매핑
|
||||
MODEL_AGGRESSIVENESS = {
|
||||
"bria-rmbg-1.4": 0.5, # 기본값
|
||||
"bria-rmbg-aggressive": 0.8, # 강력한 배경제거
|
||||
"bria-rmbg-gentle": 0.2 # 부드러운 배경제거
|
||||
}
|
||||
|
||||
def __init__(self, logger=None, default_model="bria-rmbg-1.4", gpu_manager=None, local_rembg_model_path: str | None = None):
|
||||
self.logger = logger
|
||||
self.default_model = default_model
|
||||
self.gpu_manager = gpu_manager
|
||||
self.local_model_path = local_rembg_model_path # BriaAI ONNX 모델 경로
|
||||
|
||||
# ONNX 세션 관련
|
||||
self._session: Optional = None
|
||||
self._input_name: Optional[str] = None
|
||||
self._output_name: Optional[str] = None
|
||||
self._model_input_size: Tuple[int, int] = (1024, 1024) # (W, H)
|
||||
self._model_loaded = False
|
||||
self._init_error = None
|
||||
|
||||
if self.logger:
|
||||
self.logger.log("BriaAI 배경제거 모듈 초기화 시작", level=logging.INFO)
|
||||
|
||||
# ONNX Runtime 사용 가능성 확인
|
||||
self._check_onnxruntime_availability()
|
||||
|
||||
if self.logger:
|
||||
self.logger.log("BriaAI 배경제거 모듈 초기화 완료", level=logging.INFO)
|
||||
|
||||
def _check_onnxruntime_availability(self):
|
||||
"""ONNX Runtime 사용 가능 여부 확인"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
|
||||
# 사용 가능한 프로바이더 확인
|
||||
available_providers = ort.get_available_providers()
|
||||
if self.logger:
|
||||
self.logger.log(f"ONNX Runtime 사용 가능한 프로바이더: {available_providers}", level=logging.INFO)
|
||||
|
||||
# Override/캐시 우선 결정
|
||||
provider_override = os.environ.get('IMGWK_REMBG_PROVIDER', 'auto').lower()
|
||||
cached_provider = self._read_provider_cache()
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
|
||||
# GPU 가속 설정 (BriaAI는 가벼워서 1GB도 충분)
|
||||
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
||||
gpu_memory_mb = int(getattr(self.gpu_manager, 'gpu_memory_total', 0) or 0)
|
||||
|
||||
def can_use_dml():
|
||||
return ('DmlExecutionProvider' in available_providers) and (gpu_memory_mb >= 1024)
|
||||
|
||||
def can_use_cuda():
|
||||
return ('CUDAExecutionProvider' in available_providers) and (gpu_memory_mb >= 1024)
|
||||
|
||||
if provider_override == 'cpu':
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log("BriaAI provider override=cpu", level=logging.INFO)
|
||||
elif provider_override == 'dml':
|
||||
if can_use_dml():
|
||||
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI provider override=dml → {self.providers}", level=logging.INFO)
|
||||
elif cached_provider == 'dml':
|
||||
if can_use_dml():
|
||||
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log("BriaAI provider cache=dml 적용", level=logging.INFO)
|
||||
else:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
else:
|
||||
# auto: DML → CUDA → CPU 순으로 선택
|
||||
if can_use_dml():
|
||||
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI DirectML 가속 사용 가능 (VRAM: {gpu_memory_mb}MB)", level=logging.INFO)
|
||||
elif can_use_cuda():
|
||||
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI CUDA 가속 사용 가능 (VRAM: {gpu_memory_mb}MB)", level=logging.INFO)
|
||||
else:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log(f"VRAM/Provider 조건 불충족으로 CPU 모드 사용 (VRAM: {gpu_memory_mb}MB)", level=logging.WARNING)
|
||||
else:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log("BriaAI CPU 모드로 설정", level=logging.INFO)
|
||||
|
||||
self._onnxruntime_available = True
|
||||
return True
|
||||
|
||||
except ImportError as e:
|
||||
self._init_error = f"ONNX Runtime이 설치되지 않음: {e}"
|
||||
self._onnxruntime_available = False
|
||||
if self.logger:
|
||||
self.logger.log(self._init_error, level=logging.ERROR)
|
||||
return False
|
||||
except Exception as e:
|
||||
self._init_error = f"ONNX Runtime 초기화 실패: {e}"
|
||||
self._onnxruntime_available = False
|
||||
if self.logger:
|
||||
self.logger.log(self._init_error, level=logging.ERROR)
|
||||
return False
|
||||
|
||||
def _load_model(self) -> bool:
|
||||
"""BriaAI ONNX 모델을 로드합니다."""
|
||||
if self._model_loaded:
|
||||
return True
|
||||
|
||||
if not self._onnxruntime_available:
|
||||
if self.logger:
|
||||
self.logger.log("ONNX Runtime을 사용할 수 없어 모델 로드 실패", level=logging.ERROR)
|
||||
return False
|
||||
|
||||
if not self.local_model_path or not os.path.exists(self.local_model_path):
|
||||
self._init_error = f"BriaAI 모델 파일을 찾을 수 없음: {self.local_model_path}"
|
||||
if self.logger:
|
||||
self.logger.log(self._init_error, level=logging.ERROR)
|
||||
return False
|
||||
|
||||
# 단계별 폴백 시도
|
||||
fallback_providers = [
|
||||
(self.providers, "원본 providers"),
|
||||
([("CPUExecutionProvider", {})], "CPU 폴백")
|
||||
]
|
||||
|
||||
for attempt_providers, attempt_name in fallback_providers:
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI ONNX 모델 로딩 시도 ({attempt_name}): {self.local_model_path}", level=logging.INFO)
|
||||
|
||||
# 세션 옵션 설정 (안전성 우선)
|
||||
sess_options = ort.SessionOptions()
|
||||
|
||||
# CPU 모드에서는 더 보수적인 설정
|
||||
if attempt_providers == [("CPUExecutionProvider", {})]:
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
sess_options.enable_mem_pattern = False
|
||||
sess_options.enable_cpu_mem_arena = False
|
||||
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
if self.logger:
|
||||
self.logger.log("🔒 CPU 안전 모드: 모든 최적화 비활성화", level=logging.INFO)
|
||||
else:
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
sess_options.enable_mem_pattern = True
|
||||
sess_options.enable_cpu_mem_arena = True
|
||||
|
||||
# DirectML 사용시 추가 설정
|
||||
if 'DmlExecutionProvider' in str(attempt_providers):
|
||||
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
if self.logger:
|
||||
self.logger.log("DirectML 최적화 설정 적용", level=logging.DEBUG)
|
||||
|
||||
# ONNX 세션 생성 (타임아웃 설정)
|
||||
import signal
|
||||
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutError("모델 로딩 타임아웃")
|
||||
|
||||
# Windows에서는 signal.alarm이 작동하지 않으므로 threading 사용
|
||||
import threading
|
||||
import time
|
||||
|
||||
session_created = [False]
|
||||
session_result = [None]
|
||||
session_error = [None]
|
||||
|
||||
def create_session():
|
||||
try:
|
||||
session_result[0] = ort.InferenceSession(
|
||||
self.local_model_path,
|
||||
sess_options=sess_options,
|
||||
providers=attempt_providers
|
||||
)
|
||||
session_created[0] = True
|
||||
except Exception as e:
|
||||
session_error[0] = e
|
||||
|
||||
# 세션 생성을 별도 스레드에서 실행 (30초 타임아웃)
|
||||
session_thread = threading.Thread(target=create_session)
|
||||
session_thread.daemon = True
|
||||
session_thread.start()
|
||||
session_thread.join(timeout=30)
|
||||
|
||||
if not session_created[0]:
|
||||
if session_error[0]:
|
||||
raise session_error[0]
|
||||
else:
|
||||
raise TimeoutError("모델 로딩이 30초 내에 완료되지 않음")
|
||||
|
||||
self._session = session_result[0]
|
||||
|
||||
# 입출력 정보 가져오기
|
||||
inputs = self._session.get_inputs()
|
||||
outputs = self._session.get_outputs()
|
||||
|
||||
if not inputs or not outputs:
|
||||
raise RuntimeError("ONNX 모델의 입출력 정의를 찾을 수 없습니다")
|
||||
|
||||
# 입력/출력 이름 설정
|
||||
self._input_name = inputs[0].name
|
||||
self._output_name = outputs[0].name
|
||||
|
||||
# 실제 사용된 프로바이더 확인
|
||||
actual_providers = self._session.get_providers()
|
||||
|
||||
if self.logger:
|
||||
self.logger.log(
|
||||
f"✅ BriaAI ONNX 모델 로딩 완료 ({attempt_name}) | "
|
||||
f"Providers: {actual_providers} | "
|
||||
f"Input: {self._input_name} | Output: {self._output_name}",
|
||||
level=logging.INFO
|
||||
)
|
||||
|
||||
self._model_loaded = True
|
||||
self._providers_used = actual_providers
|
||||
# 성공 프로바이더 캐시 기록
|
||||
try:
|
||||
prov = 'dml' if any('Dml' in p for p in actual_providers) else 'cpu'
|
||||
self._write_provider_cache(prov)
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"BriaAI ONNX 모델 로딩 실패 ({attempt_name}): {e}"
|
||||
if self.logger:
|
||||
self.logger.log(error_msg, level=logging.WARNING if attempt_name != "CPU 폴백" else logging.ERROR, exc_info=True)
|
||||
|
||||
# 마지막 시도가 아니면 계속 진행
|
||||
if attempt_name != "CPU 폴백":
|
||||
continue
|
||||
|
||||
# 모든 시도 실패
|
||||
self._init_error = "모든 provider에서 BriaAI ONNX 모델 로딩 실패"
|
||||
if self.logger:
|
||||
self.logger.log(self._init_error, level=logging.ERROR)
|
||||
return False
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 프로바이더 캐시 유틸
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
def _cache_path(self):
|
||||
try:
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
user_data = os.path.join(root_dir, 'user_data')
|
||||
os.makedirs(user_data, exist_ok=True)
|
||||
return os.path.join(user_data, 'rembg_provider.json')
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _read_provider_cache(self):
|
||||
try:
|
||||
path = self._cache_path()
|
||||
if not path or not os.path.exists(path):
|
||||
return None
|
||||
import json
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prov = (data or {}).get('last_success_provider', '').lower()
|
||||
return prov if prov in ('dml', 'cpu') else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _write_provider_cache(self, provider: str):
|
||||
try:
|
||||
path = self._cache_path()
|
||||
if not path:
|
||||
return
|
||||
import json
|
||||
with open(path, 'w', encoding='utf-8') as f:
|
||||
json.dump({"last_success_provider": provider}, f, ensure_ascii=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _preprocess(self, image_bgr: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
|
||||
"""BGR uint8 이미지를 모델 입력(NCHW float32, 정규화)로 변환 (허깅페이스 호환)"""
|
||||
orig_h, orig_w = image_bgr.shape[:2]
|
||||
|
||||
# 입력 검증
|
||||
if len(image_bgr.shape) != 3 or image_bgr.shape[2] != 3:
|
||||
raise ValueError(f"입력 이미지는 3채널 BGR이어야 합니다. 현재: {image_bgr.shape}")
|
||||
|
||||
# BGR -> RGB (허깅페이스는 RGB 입력 가정)
|
||||
image_rgb = image_bgr[:, :, ::-1].copy()
|
||||
|
||||
# 리사이즈 (W,H) - 허깅페이스와 동일한 bilinear 보간
|
||||
target_w, target_h = self._model_input_size
|
||||
resized = cv2.resize(image_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# 허깅페이스 방식: float32 변환 후 정규화
|
||||
# 1. [0,255] -> [0,1]
|
||||
tensor = resized.astype(np.float32) / 255.0
|
||||
|
||||
# 2. normalize(mean=[0.5,0.5,0.5], std=[1.0,1.0,1.0]) -> (x - 0.5) / 1.0
|
||||
tensor = tensor - 0.5
|
||||
|
||||
# 3. HWC -> CHW, 배치 축 추가 (NCHW)
|
||||
nchw = np.transpose(tensor, (2, 0, 1))[np.newaxis, ...]
|
||||
|
||||
# 디버그 정보
|
||||
if self.logger:
|
||||
self.logger.log(f"전처리 완료: {image_bgr.shape} -> {nchw.shape}, 값 범위: [{nchw.min():.3f}, {nchw.max():.3f}]", level=logging.DEBUG)
|
||||
|
||||
return nchw, (orig_h, orig_w)
|
||||
|
||||
def _infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||
"""ONNX 추론 수행 후 [H,W] 마스크 확률맵 반환 (BriaAI 모델 특화)"""
|
||||
outputs = self._session.run(None, {self._input_name: input_tensor})
|
||||
|
||||
# BriaAI 모델은 여러 출력을 가질 수 있음 (6개 side outputs)
|
||||
# 첫 번째 출력(d1)이 가장 정확한 결과
|
||||
if isinstance(outputs, (list, tuple)) and len(outputs) > 0:
|
||||
pred = outputs[0] # 첫 번째 출력 사용 (d1)
|
||||
if self.logger:
|
||||
self.logger.log(f"ONNX 모델 출력 개수: {len(outputs)}, 첫 번째 출력 shape: {pred.shape}", level=logging.DEBUG)
|
||||
else:
|
||||
pred = outputs
|
||||
|
||||
# numpy array로 변환
|
||||
pred = np.array(pred)
|
||||
|
||||
# 차원 정리: [B, C, H, W] -> [H, W]
|
||||
if pred.ndim == 4:
|
||||
# [1, 1, H, W] -> [H, W]
|
||||
pred = pred[0, 0]
|
||||
elif pred.ndim == 3:
|
||||
if pred.shape[0] == 1:
|
||||
# [1, H, W] -> [H, W]
|
||||
pred = pred[0]
|
||||
else:
|
||||
# [C, H, W] -> [H, W] (첫 번째 채널 사용)
|
||||
pred = pred[0]
|
||||
elif pred.ndim == 2:
|
||||
# 이미 [H, W] 형태
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"예상하지 못한 출력 차원: {pred.shape}")
|
||||
|
||||
# 확률값 범위 확인 및 정규화 (0~1 사이가 아닐 경우)
|
||||
if pred.max() > 1.0 or pred.min() < 0.0:
|
||||
if self.logger:
|
||||
self.logger.log(f"출력 값 범위 이상: [{pred.min():.3f}, {pred.max():.3f}] -> sigmoid 적용", level=logging.WARNING)
|
||||
# sigmoid 함수 적용 (모델에서 sigmoid가 적용되지 않은 경우)
|
||||
pred = 1.0 / (1.0 + np.exp(-pred))
|
||||
|
||||
if self.logger:
|
||||
self.logger.log(f"추론 완료: {pred.shape}, 값 범위: [{pred.min():.3f}, {pred.max():.3f}]", level=logging.DEBUG)
|
||||
|
||||
return pred
|
||||
|
||||
def _postprocess(self, mask_pred: np.ndarray, orig_size: Tuple[int, int], aggressiveness: float = 0.5) -> np.ndarray:
|
||||
"""모델 출력 마스크를 원본 해상도로 보간하고 0..255 uint8로 변환 (허깅페이스 방식)"""
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
# 모델 출력(H,W)을 원본 크기로 리사이즈 (W,H) - 허깅페이스와 동일하게 bilinear 보간
|
||||
mask_resized = cv2.resize(mask_pred, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# 허깅페이스 방식의 min-max 정규화 (임계값 없이)
|
||||
ma = float(mask_resized.max())
|
||||
mi = float(mask_resized.min())
|
||||
denom = (ma - mi) if (ma - mi) != 0 else 1.0
|
||||
mask_norm = (mask_resized - mi) / denom
|
||||
|
||||
# 허깅페이스 방식: 직접 0-255로 스케일링 (임계값 적용 안함)
|
||||
mask_255 = (mask_norm * 255).astype(np.uint8)
|
||||
|
||||
# aggressiveness 파라미터 적용 (필요시에만)
|
||||
if aggressiveness != 0.5:
|
||||
# aggressiveness가 0.5가 아닐 때만 약간의 조정 적용
|
||||
if aggressiveness > 0.5:
|
||||
# 더 공격적: 밝기 증가
|
||||
factor = 1.0 + (aggressiveness - 0.5) * 0.4
|
||||
mask_255 = np.clip(mask_255 * factor, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
# 더 부드럽게: 밝기 감소
|
||||
factor = 0.6 + aggressiveness * 0.8
|
||||
mask_255 = np.clip(mask_255 * factor, 0, 255).astype(np.uint8)
|
||||
|
||||
return mask_255
|
||||
|
||||
# ===================================================================================
|
||||
# 기존 BackgroundRemovalModule 호환 인터페이스
|
||||
# ===================================================================================
|
||||
|
||||
def _refine_mask(
|
||||
self,
|
||||
alpha_mask: np.ndarray,
|
||||
*,
|
||||
alpha_threshold: int | None = None,
|
||||
keep_top_k_components: int | None = None,
|
||||
min_component_area: int | None = None,
|
||||
morph_kernel: int | None = 3,
|
||||
morph_open_iters: int = 0,
|
||||
morph_close_iters: int = 0,
|
||||
dilate_iters: int = 0,
|
||||
erode_iters: int = 0,
|
||||
fill_holes: bool = False,
|
||||
edge_feather: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""알파 마스크를 후처리하여 과도 제거/잔여 잡음을 완화합니다.
|
||||
|
||||
매개변수 설명:
|
||||
- alpha_threshold: 이진화 임계값(0~255). 설정 시 확실한 전경만 남깁니다.
|
||||
- keep_top_k_components: 가장 큰 연결요소 K개만 유지(사람+상품=2 권장).
|
||||
- min_component_area: 이 면적 미만의 작은 요소 제거.
|
||||
- morph_*: 모폴로지 연산으로 노이즈 제거/홀 메움.
|
||||
- dilate/erode: 경계 확장/축소 미세 조정.
|
||||
- fill_holes: 내부 구멍 채우기(큰 구멍 포함).
|
||||
- edge_feather: 가장자리 페더링(>0이면 가우시안 블러, 값은 커널 반경).
|
||||
"""
|
||||
|
||||
try:
|
||||
mask_uint8 = alpha_mask.astype(np.uint8)
|
||||
|
||||
# 1) 이진화 준비
|
||||
if alpha_threshold is not None:
|
||||
_, bin_mask = cv2.threshold(mask_uint8, int(alpha_threshold), 255, cv2.THRESH_BINARY)
|
||||
else:
|
||||
# 소프트 마스크라도 후단 처리를 위해 이진 마스크 생성
|
||||
bin_mask = (mask_uint8 > 0).astype(np.uint8) * 255
|
||||
|
||||
# 2) 연결요소 기반 정리
|
||||
if keep_top_k_components is not None or (min_component_area is not None and min_component_area > 1):
|
||||
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(bin_mask, connectivity=8)
|
||||
# 0은 배경
|
||||
component_indices = list(range(1, num_labels))
|
||||
if min_component_area is not None:
|
||||
component_indices = [i for i in component_indices if int(stats[i, cv2.CC_STAT_AREA]) >= int(min_component_area)]
|
||||
|
||||
# 가장 큰 K개만 남기기
|
||||
if keep_top_k_components is not None and keep_top_k_components > 0:
|
||||
component_indices = sorted(
|
||||
component_indices,
|
||||
key=lambda i: int(stats[i, cv2.CC_STAT_AREA]),
|
||||
reverse=True,
|
||||
)[: int(keep_top_k_components)]
|
||||
|
||||
filtered = np.zeros_like(bin_mask)
|
||||
for i in component_indices:
|
||||
filtered[labels == i] = 255
|
||||
bin_mask = filtered
|
||||
|
||||
# 3) 모폴로지 연산으로 노이즈 제거/홀 메움
|
||||
k = 3 if morph_kernel is None or morph_kernel <= 0 else int(morph_kernel)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
||||
|
||||
if morph_open_iters > 0:
|
||||
bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_OPEN, kernel, iterations=int(morph_open_iters))
|
||||
|
||||
if morph_close_iters > 0:
|
||||
bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel, iterations=int(morph_close_iters))
|
||||
|
||||
if dilate_iters > 0:
|
||||
bin_mask = cv2.dilate(bin_mask, kernel, iterations=int(dilate_iters))
|
||||
if erode_iters > 0:
|
||||
bin_mask = cv2.erode(bin_mask, kernel, iterations=int(erode_iters))
|
||||
|
||||
# 4) 내부 큰 구멍 채우기
|
||||
if fill_holes:
|
||||
flood = bin_mask.copy()
|
||||
h, w = flood.shape[:2]
|
||||
flood_mask = np.zeros((h + 2, w + 2), np.uint8)
|
||||
cv2.floodFill(flood, flood_mask, (0, 0), 255)
|
||||
flood_inv = cv2.bitwise_not(flood)
|
||||
bin_mask = cv2.bitwise_or(bin_mask, flood_inv)
|
||||
|
||||
# 5) 가장자리 페더링(부드럽게)
|
||||
if edge_feather and edge_feather > 0:
|
||||
r = int(edge_feather)
|
||||
r = max(1, min(31, r * 2 + 1)) # 홀수 커널 보장, 과도한 값 방지
|
||||
soft = cv2.GaussianBlur(bin_mask, (r, r), 0)
|
||||
return soft.astype(np.uint8)
|
||||
|
||||
return bin_mask.astype(np.uint8)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log(f"마스크 후처리 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
return alpha_mask
|
||||
|
||||
def is_available(self):
|
||||
"""배경제거 모듈 사용 가능 여부 반환"""
|
||||
return self._onnxruntime_available and (self.local_model_path and os.path.exists(self.local_model_path))
|
||||
|
||||
def get_init_error(self):
|
||||
"""초기화 에러 메시지 반환"""
|
||||
return self._init_error
|
||||
|
||||
def get_supported_models(self):
|
||||
"""지원하는 모델 목록 반환"""
|
||||
return self.SUPPORTED_MODELS.copy()
|
||||
|
||||
def get_default_model(self):
|
||||
"""기본 모델명 반환"""
|
||||
return self.default_model
|
||||
|
||||
def set_default_model(self, model_name):
|
||||
"""기본 모델 설정"""
|
||||
if model_name not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(f"지원하지 않는 모델명: {model_name}")
|
||||
self.default_model = model_name
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI 기본 모델이 '{model_name}'으로 변경됨", level=logging.INFO)
|
||||
|
||||
def get_model_description(self, model_name):
|
||||
"""모델 설명 반환"""
|
||||
return self.SUPPORTED_MODELS.get(model_name, "모델 설명 없음")
|
||||
|
||||
def to_white_background(self, img: Image.Image) -> Image.Image:
|
||||
"""RGBA 이미지를 흰 배경과 합성 (이미 RGB라면 그대로 반환)"""
|
||||
if img.mode in ("RGBA", "BGRA"):
|
||||
bg = Image.new("RGB", img.size, (255, 255, 255))
|
||||
bg.paste(img, mask=img.split()[-1])
|
||||
return bg
|
||||
else:
|
||||
# 이미 RGB이거나 다른 모드라면 RGB로 변환
|
||||
return img.convert("RGB")
|
||||
|
||||
def remove_background(self, image_path, model_name=None, force_cpu=None, **kwargs):
|
||||
"""
|
||||
이미지에서 배경을 제거하여 PIL Image 반환
|
||||
기존 BackgroundRemovalModule.remove_background와 동일한 인터페이스
|
||||
"""
|
||||
if not self.is_available():
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI 모듈 사용 불가: {self._init_error}", level=logging.ERROR)
|
||||
return None
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
if self.logger:
|
||||
self.logger.log(f"입력 이미지가 존재하지 않습니다: {image_path}", level=logging.ERROR)
|
||||
return None
|
||||
|
||||
# force_cpu 매개변수 처리
|
||||
original_providers = self.providers
|
||||
if force_cpu is True:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log("⚠️ CPU 모드 강제 실행 (BriaAI)", level=logging.WARNING)
|
||||
elif force_cpu is False:
|
||||
# GPU 가속 강제 사용 (DirectML 테스트용)
|
||||
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
||||
if 'DmlExecutionProvider' in original_providers:
|
||||
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
||||
elif 'CUDAExecutionProvider' in original_providers:
|
||||
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
else:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log(f"🔥 GPU 모드 강제 실행 (BriaAI): {self.providers}", level=logging.WARNING)
|
||||
|
||||
# 모델 로드 (지연 로딩)
|
||||
model_loaded = self._load_model()
|
||||
|
||||
# 원래 providers 복원
|
||||
if force_cpu is not None:
|
||||
self.providers = original_providers
|
||||
|
||||
if not model_loaded:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 이미지 로드
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
if self.logger:
|
||||
self.logger.log(f"이미지 로드 실패: {image_path}", level=logging.ERROR)
|
||||
return None
|
||||
|
||||
# 모델명 결정 및 aggressiveness 설정
|
||||
effective_model_name = model_name or self.default_model
|
||||
if effective_model_name not in self.SUPPORTED_MODELS:
|
||||
if self.logger:
|
||||
self.logger.log(f"지원하지 않는 모델명: {effective_model_name}. bria-rmbg-1.4로 대체 사용", level=logging.WARNING)
|
||||
effective_model_name = "bria-rmbg-1.4"
|
||||
|
||||
aggressiveness = self.MODEL_AGGRESSIVENESS.get(effective_model_name, 0.5)
|
||||
custom_aggressiveness = kwargs.get("aggressiveness", aggressiveness)
|
||||
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI 배경제거 시작: {effective_model_name} (aggressiveness={custom_aggressiveness})", level=logging.DEBUG)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 전처리
|
||||
input_tensor, (orig_h, orig_w) = self._preprocess(img)
|
||||
|
||||
# 추론
|
||||
mask_pred = self._infer(input_tensor)
|
||||
|
||||
# 후처리
|
||||
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w), aggressiveness=custom_aggressiveness)
|
||||
|
||||
# 선택적 마스크 보정
|
||||
if any(k in kwargs for k in (
|
||||
'alpha_threshold', 'keep_top_k_components', 'min_component_area',
|
||||
'morph_kernel', 'morph_open_iters', 'morph_close_iters',
|
||||
'dilate_iters', 'erode_iters', 'fill_holes', 'edge_feather'
|
||||
)):
|
||||
alpha_mask = self._refine_mask(
|
||||
alpha_mask,
|
||||
alpha_threshold=kwargs.get('alpha_threshold'),
|
||||
keep_top_k_components=kwargs.get('keep_top_k_components'),
|
||||
min_component_area=kwargs.get('min_component_area'),
|
||||
morph_kernel=kwargs.get('morph_kernel', 3),
|
||||
morph_open_iters=kwargs.get('morph_open_iters', 0),
|
||||
morph_close_iters=kwargs.get('morph_close_iters', 0),
|
||||
dilate_iters=kwargs.get('dilate_iters', 0),
|
||||
erode_iters=kwargs.get('erode_iters', 0),
|
||||
fill_holes=kwargs.get('fill_holes', False),
|
||||
edge_feather=kwargs.get('edge_feather', 0),
|
||||
)
|
||||
|
||||
# DirectML 알파 채널 이슈 회피: 바로 흰 배경 합성된 BGR로 처리
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 알파 마스크를 이용해 흰 배경 합성 (DirectML 알파 채널 이슈 회피)
|
||||
alpha_normalized = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_3d = np.stack([alpha_normalized] * 3, axis=-1)
|
||||
|
||||
# 흰 배경과 합성
|
||||
white_bg = np.full_like(img_rgb, 255, dtype=np.uint8)
|
||||
blended_rgb = (
|
||||
img_rgb.astype(np.float32) * alpha_3d +
|
||||
white_bg.astype(np.float32) * (1.0 - alpha_3d)
|
||||
).astype(np.uint8)
|
||||
|
||||
# PIL Image로 변환 (RGB 모드로 - 알파 채널 없음)
|
||||
result = Image.fromarray(blended_rgb, 'RGB')
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
if self.logger:
|
||||
provider_status = "GPU" if any('CUDA' in p or 'Dml' in p for p in self._session.get_providers()) else "CPU"
|
||||
self.logger.log(f"✅ BriaAI 배경제거 성공: {effective_model_name} ({provider_status}, {processing_time:.2f}초)", level=logging.INFO)
|
||||
|
||||
# 마스크 통계 로깅
|
||||
mask_stats = {
|
||||
'min': int(alpha_mask.min()),
|
||||
'max': int(alpha_mask.max()),
|
||||
'mean': float(alpha_mask.mean()),
|
||||
'nonzero_count': int(np.count_nonzero(alpha_mask))
|
||||
}
|
||||
self.logger.log(f"BriaAI 마스크 통계: {mask_stats}", level=logging.DEBUG)
|
||||
|
||||
# GPU 메모리 사용량 로깅
|
||||
if self.gpu_manager and hasattr(self.gpu_manager, 'log_gpu_memory_usage'):
|
||||
self.gpu_manager.log_gpu_memory_usage()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI 배경제거 처리 중 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
return None
|
||||
|
||||
def remove_background_array(self, image_bgr: np.ndarray, model_name=None, force_cpu=None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
배경제거 결과를 numpy array로 직접 반환 (추가 편의 메서드)
|
||||
Returns: (result_bgr, alpha_mask)
|
||||
"""
|
||||
# force_cpu 매개변수 처리
|
||||
original_providers = self.providers
|
||||
if force_cpu is True:
|
||||
self.providers = ['CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log("⚠️ CPU 모드 강제 실행 (BriaAI array)", level=logging.WARNING)
|
||||
elif force_cpu is False:
|
||||
# GPU 가속 강제 사용
|
||||
if self.gpu_manager and self.gpu_manager.can_use_cuda:
|
||||
if 'DmlExecutionProvider' in original_providers:
|
||||
self.providers = ['DmlExecutionProvider', 'CPUExecutionProvider']
|
||||
elif 'CUDAExecutionProvider' in original_providers:
|
||||
self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
if self.logger:
|
||||
self.logger.log(f"🔥 GPU 모드 강제 실행 (BriaAI array): {self.providers}", level=logging.WARNING)
|
||||
|
||||
# 모델 로드
|
||||
model_loaded = self.is_available() and self._load_model()
|
||||
|
||||
# 원래 providers 복원
|
||||
if force_cpu is not None:
|
||||
self.providers = original_providers
|
||||
|
||||
if not model_loaded:
|
||||
return image_bgr, None
|
||||
|
||||
try:
|
||||
effective_model_name = model_name or self.default_model
|
||||
aggressiveness = self.MODEL_AGGRESSIVENESS.get(effective_model_name, 0.5)
|
||||
custom_aggressiveness = kwargs.get("aggressiveness", aggressiveness)
|
||||
|
||||
# 전처리 -> 추론 -> 후처리
|
||||
input_tensor, (orig_h, orig_w) = self._preprocess(image_bgr)
|
||||
mask_pred = self._infer(input_tensor)
|
||||
alpha_mask = self._postprocess(mask_pred, (orig_h, orig_w), aggressiveness=custom_aggressiveness)
|
||||
|
||||
# 선택적 마스크 보정
|
||||
if any(k in kwargs for k in (
|
||||
'alpha_threshold', 'keep_top_k_components', 'min_component_area',
|
||||
'morph_kernel', 'morph_open_iters', 'morph_close_iters',
|
||||
'dilate_iters', 'erode_iters', 'fill_holes', 'edge_feather'
|
||||
)):
|
||||
alpha_mask = self._refine_mask(
|
||||
alpha_mask,
|
||||
alpha_threshold=kwargs.get('alpha_threshold'),
|
||||
keep_top_k_components=kwargs.get('keep_top_k_components'),
|
||||
min_component_area=kwargs.get('min_component_area'),
|
||||
morph_kernel=kwargs.get('morph_kernel', 3),
|
||||
morph_open_iters=kwargs.get('morph_open_iters', 0),
|
||||
morph_close_iters=kwargs.get('morph_close_iters', 0),
|
||||
dilate_iters=kwargs.get('dilate_iters', 0),
|
||||
erode_iters=kwargs.get('erode_iters', 0),
|
||||
fill_holes=kwargs.get('fill_holes', False),
|
||||
edge_feather=kwargs.get('edge_feather', 0),
|
||||
)
|
||||
|
||||
# 흰색 배경 합성
|
||||
mask_3d = np.stack([alpha_mask] * 3, axis=-1)
|
||||
result_bgr = (
|
||||
image_bgr.astype(np.float32) * (mask_3d.astype(np.float32) / 255.0)
|
||||
+ 255.0 * (1.0 - (mask_3d.astype(np.float32) / 255.0))
|
||||
).clip(0, 255).astype(np.uint8)
|
||||
|
||||
return result_bgr, alpha_mask
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.log(f"BriaAI 배경제거 array 처리 중 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
return image_bgr, None
|
||||
|
||||
def _preload_sessions(self):
|
||||
"""BriaAI 모델 미리 로딩"""
|
||||
if self.logger:
|
||||
self.logger.log("🔄 BriaAI 모델 미리 로딩 중...", level=logging.INFO)
|
||||
|
||||
if self._load_model():
|
||||
if self.logger:
|
||||
self.logger.log("✅ BriaAI 모델 미리 로딩 완료", level=logging.INFO)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.log("⚠️ BriaAI 모델 로딩 실패", level=logging.WARNING)
|
||||
|
|
@ -0,0 +1,549 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ImageWorker FastAPI 클라이언트 헬퍼
|
||||
|
||||
- 이미지 URL 다운로드 → 로컬 경로 전달 방식으로 서버에 제출
|
||||
- /v1/process-image, /v1/remove-background 호출 및 대기
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import shutil
|
||||
import mimetypes
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
import cv2 # for cleanup (destroyAllWindows)
|
||||
import requests
|
||||
import psutil
|
||||
from urllib.parse import urlparse
|
||||
import json
|
||||
|
||||
|
||||
def _compat_result_from_job(job: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""ImageProcessor3 결과와 호환되도록 키를 정규화."""
|
||||
rr = (job or {}).get("result") or {}
|
||||
result = {
|
||||
"status": rr.get("status") or job.get("status"),
|
||||
"path": rr.get("path"),
|
||||
"inpaint_method": rr.get("inpaint_method"),
|
||||
"inpaint_device": rr.get("inpaint_device"),
|
||||
"timings": rr.get("timings"),
|
||||
"child_results": rr.get("child_results"),
|
||||
"message": rr.get("message") or rr.get("msg"),
|
||||
"error": rr.get("error"),
|
||||
}
|
||||
# 누락 키 보장
|
||||
for k in ("status", "path", "inpaint_method", "inpaint_device", "timings", "child_results", "message", "error"):
|
||||
result.setdefault(k, None)
|
||||
return result
|
||||
|
||||
def _read_server_info() -> Optional[Dict[str, Any]]:
|
||||
"""ProgramData/ImgWorker/server.json에서 서버 정보를 읽어온다."""
|
||||
try:
|
||||
program_data = os.environ.get("PROGRAMDATA", r"C:\\ProgramData")
|
||||
info_path = os.path.join(program_data, "ImgWorker", "server.json")
|
||||
if os.path.isfile(info_path):
|
||||
with open(info_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
class ImageWorkerClient:
|
||||
def __init__(self, logger, api_base: str = "http://127.0.0.1:8009", work_dir: Optional[str] = None, timeout: int = 30, max_concurrency: int = 8):
|
||||
self.logger = logger
|
||||
|
||||
# API base 우선순위: 명시 인자 > IMGWK_API_BASE 환경변수 > server.json > 기본값
|
||||
api = api_base
|
||||
try:
|
||||
env_api = os.environ.get("IMGWK_API_BASE")
|
||||
if env_api and isinstance(env_api, str) and env_api.strip():
|
||||
api = env_api.strip()
|
||||
else:
|
||||
# 인자가 기본값일 때만 server.json 자동 사용(명시적 인자 우선)
|
||||
if api_base == "http://127.0.0.1:8009":
|
||||
info = _read_server_info()
|
||||
if isinstance(info, dict):
|
||||
base = info.get("base") or (f"http://{info.get('host','127.0.0.1')}:{info.get('port',8009)}")
|
||||
if base:
|
||||
api = base
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.api = (api or "http://127.0.0.1:8009").rstrip("/")
|
||||
|
||||
# 기본 작업 디렉토리: C:\ProgramData\ImgWorker\incoming
|
||||
if work_dir is None:
|
||||
program_data = os.environ.get("PROGRAMDATA", r"C:\\ProgramData")
|
||||
work_dir = os.path.join(program_data, "ImgWorker", "incoming")
|
||||
self.logger.log(f"work_dir: {work_dir}", level=logging.DEBUG)
|
||||
os.makedirs(work_dir, exist_ok=True)
|
||||
self.work_dir = work_dir
|
||||
self.TEMP_IMAGE_DIR = work_dir # download_image 메서드 호환성을 위해
|
||||
self.timeout = timeout
|
||||
# 동시요청 제한(세마포어)
|
||||
# try:
|
||||
# n = int(max_concurrency)
|
||||
# except Exception:
|
||||
# n = 8
|
||||
# self._sema = asyncio.Semaphore(max(1, n))
|
||||
|
||||
# def set_max_concurrency(self, n: int):
|
||||
# """동시 요청 제한을 런타임에 조정"""
|
||||
# try:
|
||||
# n = int(n)
|
||||
# except Exception:
|
||||
# n = 1
|
||||
# self._sema = asyncio.Semaphore(max(1, n))
|
||||
|
||||
def is_valid_image_data(self, data):
|
||||
"""이미지 데이터의 유효성을 검사합니다"""
|
||||
if not data or len(data) < 100: # 최소 크기 검사
|
||||
return False
|
||||
|
||||
# JPEG, PNG, GIF, WebP 시그니처 검사
|
||||
if data.startswith(b'\xff\xd8\xff'): # JPEG
|
||||
return True
|
||||
elif data.startswith(b'\x89PNG\r\n\x1a\n'): # PNG
|
||||
return True
|
||||
elif data.startswith(b'GIF87a') or data.startswith(b'GIF89a'): # GIF
|
||||
return True
|
||||
elif data.startswith(b'RIFF') and b'WEBP' in data[:12]: # WebP
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------- 유틸 ----------------------
|
||||
def _guess_ext(self, url: str, content_type: Optional[str]) -> str:
|
||||
# 1) 헤더 content-type
|
||||
if content_type:
|
||||
ext = mimetypes.guess_extension(content_type.split(";")[0].strip())
|
||||
if ext:
|
||||
return ext
|
||||
# 2) URL 경로
|
||||
try:
|
||||
basename = os.path.basename(url.split("?")[0])
|
||||
_, ext = os.path.splitext(basename)
|
||||
if ext:
|
||||
return ext
|
||||
except Exception:
|
||||
pass
|
||||
# 3) 기본값
|
||||
return ".jpg"
|
||||
|
||||
|
||||
|
||||
def download_image(self, image_url, index, file_prefix="", max_retries=3):
|
||||
"""Requests를 사용해 이미지를 다운로드합니다"""
|
||||
|
||||
# 로컬 파일 경로면 바로 반환
|
||||
if os.path.isfile(image_url):
|
||||
self.logger.log(f"로컬 파일 경로 감지, 다운로드 생략: {image_url}", level=logging.DEBUG)
|
||||
return image_url
|
||||
|
||||
# 로컬 파일 경로가 아니면 다운로드 시도
|
||||
try:
|
||||
# "https://assets.alicdn.com"으로 시작하는 URL은 건너뛰기
|
||||
if image_url.startswith("https://assets.alicdn.com") or image_url.startswith("https://gtms01.alicdn.com"):
|
||||
self.logger.log(f"다운로드 제외 URL: {image_url}", level=logging.DEBUG)
|
||||
return None
|
||||
|
||||
# URL에서 파일명 추출 및 접두사 포함
|
||||
parsed_url = urlparse(image_url)
|
||||
base_filename = f"image_{index:03d}_{os.path.basename(parsed_url.path)}"
|
||||
if not base_filename.endswith(('.jpg', '.jpeg', '.png', '.gif', '.webp')):
|
||||
base_filename += '.jpg'
|
||||
|
||||
# 접두사가 있으면 파일명에 포함
|
||||
if file_prefix:
|
||||
filename = f"{file_prefix}_{base_filename}"
|
||||
else:
|
||||
filename = base_filename
|
||||
|
||||
local_path = os.path.join(self.TEMP_IMAGE_DIR, filename)
|
||||
|
||||
# HTTP 헤더 설정
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/88.0.4324.150 Safari/537.36",
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"DNT": "1", # Do Not Track 요청 헤더
|
||||
"Connection": "keep-alive",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"Cache-Control": "max-age=0"
|
||||
}
|
||||
|
||||
retries = 0
|
||||
while retries < max_retries:
|
||||
try:
|
||||
# 메모리 추적: 다운로드 시작 전
|
||||
before_mem = psutil.virtual_memory()
|
||||
before_mb = before_mem.used / 1024 / 1024
|
||||
|
||||
self.logger.log(f"이미지 다운로드 중: {filename}", level=logging.DEBUG)
|
||||
response = requests.get(image_url, headers=headers, stream=True, timeout=30)
|
||||
|
||||
if response.status_code == 200:
|
||||
image_data = response.content
|
||||
|
||||
# 이미지 데이터 유효성 검사
|
||||
if self.is_valid_image_data(image_data):
|
||||
with open(local_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
# 메모리 추적: 다운로드 완료 후
|
||||
after_mem = psutil.virtual_memory()
|
||||
after_mb = after_mem.used / 1024 / 1024
|
||||
change_mb = after_mb - before_mb
|
||||
change_percent = (change_mb / before_mb) * 100 if before_mb > 0 else 0
|
||||
self.logger.log(
|
||||
f"메모리 변화 [다운로드 완료]: {before_mb:.1f}MB -> {after_mb:.1f}MB "
|
||||
f"({change_mb:+.1f}MB, {change_percent:+.1f}%) - {filename}",
|
||||
level=logging.DEBUG if abs(change_mb) < 10 else logging.INFO
|
||||
)
|
||||
|
||||
self.logger.log(f"이미지 다운로드 완료: {filename}", level=logging.DEBUG)
|
||||
return local_path
|
||||
else:
|
||||
self.logger.log(f"유효하지 않은 이미지 데이터: {image_url}", level=logging.WARNING)
|
||||
return None
|
||||
else:
|
||||
self.logger.log(f"이미지 다운로드 실패 (HTTP {response.status_code}): {image_url}. 재시도 {retries + 1}/{max_retries}", level=logging.ERROR)
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
time.sleep(random.randint(2, 5)) # 2~5초 대기 후 재시도
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self.logger.log(f"이미지 다운로드 중 네트워크 오류: {e}. 재시도 {retries + 1}/{max_retries}", level=logging.ERROR)
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
time.sleep(random.randint(2, 5)) # 예외 발생 시 대기 후 재시도
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"이미지 다운로드 중 예상치 못한 오류: {e}. 재시도 {retries + 1}/{max_retries}", level=logging.ERROR)
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
time.sleep(random.randint(2, 5))
|
||||
|
||||
self.logger.log(f"이미지 다운로드 최대 재시도 횟수를 초과했습니다: {image_url}", level=logging.ERROR)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"이미지 다운로드 중 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def health(self) -> Dict[str, Any]:
|
||||
try:
|
||||
hr = requests.get(f"{self.api}/health", timeout=10)
|
||||
hr.raise_for_status()
|
||||
return hr.json()
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
# ---------------------- 제출/대기 ----------------------
|
||||
async def submit_process_image(self, file_path: str, index: int, file_prefix: str,
|
||||
font_type: str, unwanted_texts: List[str],
|
||||
is_member_valid: bool, authenticated_by_admin: bool,
|
||||
extra_overrides: Optional[Dict[str, Any]] = None,
|
||||
ocr: Optional[bool] = None) -> str:
|
||||
payload: Dict[str, Any] = {
|
||||
"file_path": file_path,
|
||||
"index": int(index),
|
||||
"file_prefix": file_prefix,
|
||||
# per-request 토글 오버라이드
|
||||
"toggle_overrides": {
|
||||
"font_type": font_type,
|
||||
"unwanted_texts": list(unwanted_texts or []),
|
||||
"is_member_valid": bool(is_member_valid),
|
||||
"authenticated_by_admin": bool(authenticated_by_admin),
|
||||
},
|
||||
}
|
||||
if extra_overrides:
|
||||
payload["toggle_overrides"].update(extra_overrides)
|
||||
# ocr 플래그가 명시되면 서버로도 전달
|
||||
if ocr is not None:
|
||||
payload["ocr"] = bool(ocr)
|
||||
|
||||
r = requests.post(f"{self.api}/v1/process-image", json=payload, timeout=self.timeout)
|
||||
r.raise_for_status()
|
||||
return r.json().get("job_id")
|
||||
|
||||
async def submit_remove_background(self, file_path: str, file_prefix: str,
|
||||
extra_overrides: Optional[Dict[str, Any]] = None) -> str:
|
||||
payload: Dict[str, Any] = {
|
||||
"file_path": file_path,
|
||||
"file_prefix": file_prefix,
|
||||
}
|
||||
if extra_overrides:
|
||||
payload["toggle_overrides"] = extra_overrides
|
||||
|
||||
r = requests.post(f"{self.api}/v1/remove-background", json=payload, timeout=self.timeout)
|
||||
r.raise_for_status()
|
||||
return r.json().get("job_id")
|
||||
|
||||
async def wait_job(self, job_id: str, timeout_sec: int = 900) -> Dict[str, Any]:
|
||||
end = time.time() + timeout_sec
|
||||
while time.time() < end:
|
||||
r = requests.get(f"{self.api}/v1/jobs/{job_id}", timeout=15)
|
||||
if r.status_code == 200:
|
||||
data = r.json()
|
||||
if data.get("status") in ("done", "error", "cancelled"):
|
||||
return data
|
||||
await asyncio.sleep(0.2)
|
||||
raise TimeoutError("job wait timeout")
|
||||
|
||||
# ---------------------- 고수준 URL 편의 ----------------------
|
||||
async def process_image_url(self, image_url: str, index: int, file_prefix: str,
|
||||
font_type: str, unwanted_texts: List[str],
|
||||
is_member_valid: bool, authenticated_by_admin: bool,
|
||||
extra_overrides: Optional[Dict[str, Any]] = None,
|
||||
download_first: bool = True,
|
||||
ocr: Optional[bool] = None) -> Optional[Dict[str, Any]]:
|
||||
# 420 에러 방지를 위해 순차 처리 (세마포어 제거)
|
||||
path = image_url
|
||||
if download_first and (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
# 동기식 download_image를 executor로 실행
|
||||
loop = asyncio.get_event_loop()
|
||||
path = await loop.run_in_executor(None, self.download_image, image_url, index, file_prefix)
|
||||
|
||||
jid = await self.submit_process_image(
|
||||
file_path=path,
|
||||
index=index,
|
||||
file_prefix=file_prefix,
|
||||
font_type=font_type,
|
||||
unwanted_texts=unwanted_texts,
|
||||
is_member_valid=is_member_valid,
|
||||
authenticated_by_admin=authenticated_by_admin,
|
||||
extra_overrides=extra_overrides,
|
||||
ocr=ocr,
|
||||
)
|
||||
job = await self.wait_job(jid)
|
||||
return _compat_result_from_job(job)
|
||||
|
||||
async def remove_background_url(self, image_url: str, file_prefix: str,
|
||||
extra_overrides: Optional[Dict[str, Any]] = None,
|
||||
download_first: bool = True) -> Optional[Dict[str, Any]]:
|
||||
# 420 에러 방지를 위해 순차 처리 (세마포어 제거)
|
||||
path = image_url
|
||||
if download_first and (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
# 동기식 download_image를 executor로 실행
|
||||
loop = asyncio.get_event_loop()
|
||||
path = await loop.run_in_executor(None, self.download_image, image_url, 0, file_prefix)
|
||||
|
||||
jid = await self.submit_remove_background(
|
||||
file_path=path,
|
||||
file_prefix=file_prefix,
|
||||
extra_overrides=extra_overrides,
|
||||
)
|
||||
job = await self.wait_job(jid)
|
||||
return _compat_result_from_job(job)
|
||||
|
||||
# ---------------------- 간단 제어 API (트레이에서 사용) ----------------------
|
||||
def worker_status(self) -> Dict[str, Any]:
|
||||
try:
|
||||
r = requests.get(f"{self.api}/v1/worker/status", timeout=5)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
return {"ready": False, "error": str(e)}
|
||||
|
||||
def worker_start(self) -> Dict[str, Any]:
|
||||
try:
|
||||
r = requests.post(f"{self.api}/v1/worker/start", timeout=10)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
|
||||
def worker_stop(self) -> Dict[str, Any]:
|
||||
try:
|
||||
r = requests.post(f"{self.api}/v1/worker/stop", timeout=10)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
|
||||
def shutdown_server(self) -> Dict[str, Any]:
|
||||
try:
|
||||
r = requests.post(f"{self.api}/v1/server/shutdown", timeout=5)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
|
||||
# ---------------------- 호환성 메서드 (기존 ImageProcessor3와 동일한 인터페이스) ----------------------
|
||||
|
||||
async def process_single_image(self, original_image_url, index, delay=1.0, file_prefix="", ocr: Optional[bool] = None):
|
||||
"""
|
||||
기존 ImageProcessor3.process_single_image과 호환되는 메서드
|
||||
|
||||
Args:
|
||||
original_image_url (str): 처리할 이미지 URL
|
||||
index (int): 이미지 인덱스
|
||||
delay (float): 요청 간격 (초) - 호환성을 위해 유지
|
||||
file_prefix (str): 파일명에 추가할 접두사
|
||||
|
||||
Returns:
|
||||
dict: 기존 ImageProcessor3과 동일한 포맷의 결과
|
||||
- status: 'translated', 'original', 'exclude', 'failed' 중 하나
|
||||
- path: 처리된 이미지 파일 경로 또는 원본 이미지 파일 경로
|
||||
- error: 오류 메시지 (status가 'failed'인 경우에만 포함)
|
||||
- inpaint_method: 사용된 인페인팅 방법
|
||||
- inpaint_device: 사용된 인페인팅 장치
|
||||
"""
|
||||
try:
|
||||
# 요청 간격 조절 (호환성을 위해)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# ImageWorkerClient를 통한 처리
|
||||
result = await self.process_image_url(
|
||||
image_url=original_image_url,
|
||||
index=index,
|
||||
file_prefix=file_prefix,
|
||||
font_type="", # 기본값 사용 (필요시 외부에서 설정 가능)
|
||||
unwanted_texts=[], # 기본값 사용 (필요시 외부에서 설정 가능)
|
||||
is_member_valid=False, # 기본값 사용 (필요시 외부에서 설정 가능)
|
||||
authenticated_by_admin=False, # 기본값 사용 (필요시 외부에서 설정 가능)
|
||||
ocr=ocr,
|
||||
)
|
||||
|
||||
if result and isinstance(result, dict):
|
||||
# 서버 결과를 기존 포맷으로 변환
|
||||
status = result.get("status", "failed")
|
||||
path = result.get("path", original_image_url)
|
||||
|
||||
# status 매핑 (서버 결과에 따라 조정)
|
||||
if status == "translated":
|
||||
return {
|
||||
'status': 'translated',
|
||||
'path': path,
|
||||
'inpaint_method': result.get('inpaint_method', 'unknown'),
|
||||
'inpaint_device': result.get('inpaint_device', 'unknown')
|
||||
}
|
||||
elif status == "original":
|
||||
return {
|
||||
'status': 'original',
|
||||
'path': path,
|
||||
'inpaint_method': None,
|
||||
'inpaint_device': None
|
||||
}
|
||||
elif status == "exclude":
|
||||
return {
|
||||
'status': 'exclude',
|
||||
'path': path,
|
||||
'inpaint_method': None,
|
||||
'inpaint_device': None
|
||||
}
|
||||
else:
|
||||
# 기타 상태는 실패로 처리
|
||||
return {
|
||||
'status': 'failed',
|
||||
'path': original_image_url,
|
||||
'error': f'Unknown status: {status}',
|
||||
'inpaint_method': None,
|
||||
'inpaint_device': None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'status': 'failed',
|
||||
'path': original_image_url,
|
||||
'error': 'No result from server',
|
||||
'inpaint_method': None,
|
||||
'inpaint_device': None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"process_single_image 호환성 메서드 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
return {
|
||||
'status': 'failed',
|
||||
'path': original_image_url,
|
||||
'error': str(e),
|
||||
'inpaint_method': None,
|
||||
'inpaint_device': None
|
||||
}
|
||||
|
||||
async def remove_background(self, original_image_url, file_prefix=""):
|
||||
"""
|
||||
기존 ImageProcessor3.remove_background과 호환되는 메서드
|
||||
|
||||
Args:
|
||||
original_image_url (str): 처리할 이미지 URL
|
||||
file_prefix (str): 파일명에 추가할 접두사
|
||||
|
||||
Returns:
|
||||
dict: 기존 ImageProcessor3과 동일한 포맷의 결과
|
||||
- status: 'success', 'failed' 중 하나
|
||||
- path: 처리된 이미지 파일 경로
|
||||
- error: 오류 메시지 (status가 'failed'인 경우에만 포함)
|
||||
"""
|
||||
try:
|
||||
# ImageWorkerClient를 통한 배경제거
|
||||
result = await self.remove_background_url(
|
||||
image_url=original_image_url,
|
||||
file_prefix=file_prefix
|
||||
)
|
||||
|
||||
if result and isinstance(result, dict):
|
||||
status = result.get("status", "failed")
|
||||
path = result.get("path", original_image_url)
|
||||
|
||||
if status == "success":
|
||||
return {
|
||||
'status': 'success',
|
||||
'path': path
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'status': 'failed',
|
||||
'path': original_image_url,
|
||||
'error': f'Remove background failed: {status}'
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'status': 'failed',
|
||||
'path': original_image_url,
|
||||
'error': 'No result from server'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"remove_background 호환성 메서드 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
return {
|
||||
'status': 'failed',
|
||||
'path': original_image_url,
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
def __del__(self):
|
||||
"""소멸자에서 리소스 정리"""
|
||||
self.cleanup()
|
||||
self.logger.log("이미지 프로세서 소멸", level=logging.DEBUG)
|
||||
|
||||
def cleanup(self):
|
||||
"""리소스 정리"""
|
||||
try:
|
||||
# Python GC 강제 실행
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# OpenCV 윈도우 정리
|
||||
try:
|
||||
cv2.destroyAllWindows()
|
||||
except:
|
||||
pass
|
||||
|
||||
# 임시 폴더 삭제
|
||||
if hasattr(self, 'TEMP_IMAGE_DIR') and os.path.exists(self.TEMP_IMAGE_DIR):
|
||||
# shutil.rmtree(self.TEMP_IMAGE_DIR)
|
||||
self.logger.log(f"임시 폴더 삭제됨: {self.TEMP_IMAGE_DIR}", level=logging.DEBUG)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"리소스 정리 중 오류: {e}", level=logging.ERROR, exc_info=True)
|
||||
|
After Width: | Height: | Size: 1.1 MiB |
|
After Width: | Height: | Size: 1.2 MiB |
|
After Width: | Height: | Size: 1.4 MiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
|
@ -0,0 +1,253 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Gemma Translation API Python Client
|
||||
- FastAPI 서버(/batch_translate, /translate_ocr_step1, /translate_ocr_step2 ...) 래퍼
|
||||
- 안전한 재시도, 타임아웃, 배치 슬라이싱 지원
|
||||
- 네 OCR 결과(dict 리스트) -> 번역 문자열 리스트 정렬 유지
|
||||
|
||||
사용 예:
|
||||
from gemma_client import GemmaTranslator
|
||||
gt = GemmaTranslator(base_url="http://<SERVER_IP>", timeout=120)
|
||||
|
||||
# A) OCR 결과를 그대로 번역 (id, text 유지)
|
||||
ko_list = gt.translate_ocr_texts(
|
||||
product_name="휴대용 선풍기",
|
||||
category="가전/계절가전",
|
||||
ocr_results=[{"text":"强力送风"}, {"text":"USB-C 快速充电"}]
|
||||
)
|
||||
|
||||
# B) 순수 텍스트 리스트를 번역
|
||||
ko_list = gt.batch_translate_texts(
|
||||
product_name="휴대용 선풍기",
|
||||
category="가전/계절가전",
|
||||
text_list=["大风力无刷电机","Type-C 充电"]
|
||||
)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
_JSON = Dict[str, Any]
|
||||
|
||||
|
||||
class GemmaTranslatorError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class GemmaTranslator:
|
||||
"""
|
||||
vLLM 번역 서버 클라이언트.
|
||||
|
||||
Params
|
||||
------
|
||||
base_url : str
|
||||
예) "http://localhost" (HAProxy 경유 시 포트 생략 / 80)
|
||||
개별 인스턴스 직접 붙으려면 "http://<IP>:8000"
|
||||
timeout : int
|
||||
요청 타임아웃(초)
|
||||
max_retries : int
|
||||
요청 재시도 횟수
|
||||
backoff : float
|
||||
재시도 backoff base (지수)
|
||||
session : requests.Session | None
|
||||
세션 주입 가능
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 120,
|
||||
max_retries: int = 2,
|
||||
backoff: float = 0.6,
|
||||
session: Optional[requests.Session] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self.base_url = (base_url or os.getenv("GEMMA_API_BASE") or "http://localhost").rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.backoff = backoff
|
||||
self.sess = session or requests.Session()
|
||||
self.log = logger or logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------
|
||||
# 내부 HTTP 헬퍼 (retry 포함)
|
||||
# -----------------------------
|
||||
def _post(self, path: str, payload: _JSON) -> _JSON:
|
||||
url = f"{self.base_url}{path}"
|
||||
last_err: Optional[Exception] = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
r = self.sess.post(url, json=payload, timeout=self.timeout)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
# 429/5xx/연결오류 등 재시도
|
||||
if attempt < self.max_retries:
|
||||
sleep_s = (self.backoff ** attempt) + random.uniform(0, 0.2)
|
||||
self.log.warning(f"[GemmaTranslator] POST {url} 실패({e}), 재시도 {attempt+1}/{self.max_retries} 대기 {sleep_s:.2f}s")
|
||||
time.sleep(sleep_s)
|
||||
else:
|
||||
break
|
||||
raise GemmaTranslatorError(f"POST {url} 실패: {last_err}")
|
||||
|
||||
# -----------------------------
|
||||
# 공개 API
|
||||
# -----------------------------
|
||||
def health(self) -> _JSON:
|
||||
url = f"{self.base_url}/health"
|
||||
r = self.sess.get(url, timeout=10)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
def metrics(self) -> _JSON:
|
||||
url = f"{self.base_url}/metrics"
|
||||
r = self.sess.get(url, timeout=10)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
# ---- A) 순수 텍스트 리스트 번역 (/batch_translate) ----
|
||||
def batch_translate_texts(
|
||||
self,
|
||||
product_name: str,
|
||||
category: str,
|
||||
text_list: List[str],
|
||||
delimiter: str = " / ",
|
||||
batch_size: int = 8,
|
||||
) -> List[str]:
|
||||
"""
|
||||
text_list → 같은 길이의 ko 리스트 반환.
|
||||
서버에서 추가 배치/토큰 슬라이싱을 하므로, 클라에서는 적당한 batch_size만 넘기면 됨.
|
||||
"""
|
||||
if not text_list:
|
||||
return []
|
||||
|
||||
out: List[str] = []
|
||||
for i in range(0, len(text_list), batch_size):
|
||||
chunk = text_list[i : i + batch_size]
|
||||
payload = {
|
||||
"product_name": product_name,
|
||||
"category": category,
|
||||
"text_list": chunk,
|
||||
"delimiter": delimiter,
|
||||
"batch_size": min(batch_size, len(chunk)),
|
||||
}
|
||||
resp = self._post("/batch-translate", payload)
|
||||
# 서버는 translated_texts 길이를 입력과 동일하게 맞춰줌
|
||||
out.extend(resp.get("translated_texts", chunk))
|
||||
return out
|
||||
|
||||
# ---- B) OCR 결과 번역: [{text:str, ...}] -> [ko, ...] ----
|
||||
def translate_ocr_texts(
|
||||
self,
|
||||
product_name: str,
|
||||
category: str,
|
||||
ocr_results: List[Dict[str, Any]],
|
||||
batch_size: int = 16,
|
||||
) -> List[str]:
|
||||
"""
|
||||
입력: OCR 결과 리스트(각 항목에 최소 'text' 키 필요)
|
||||
처리: 서버 /ocr-translate 로 {id, source} 배열을 보내고, 반환 {id, result} 정렬
|
||||
출력: 원래 순서 유지한 ko 문자열 리스트
|
||||
"""
|
||||
if not ocr_results:
|
||||
return []
|
||||
|
||||
# 유효성 검증 (새 스키마 준수)
|
||||
if not product_name or len(product_name.strip()) < 1:
|
||||
raise GemmaTranslatorError("product_name은 1자 이상이어야 합니다.")
|
||||
if not category or len(category.strip()) < 1:
|
||||
raise GemmaTranslatorError("category는 1자 이상이어야 합니다.")
|
||||
|
||||
# id 부여 및 source 필터링 (빈 텍스트 스킵, 최소 1자)
|
||||
items = []
|
||||
source_to_orig_idx = {} # source id to original index mapping
|
||||
for i, d in enumerate(ocr_results):
|
||||
source = (d.get("text") or "").strip()
|
||||
if len(source) >= 1: # 최소 1자 이상
|
||||
item_id = len(items) + 1
|
||||
items.append({"id": item_id, "source": source})
|
||||
source_to_orig_idx[item_id] = i
|
||||
|
||||
if not items:
|
||||
return [""] * len(ocr_results) # 원래 길이만큼 빈 문자열 반환
|
||||
|
||||
# 서버에 배치로 전송
|
||||
out_ko = [""] * len(ocr_results) # 원래 ocr_results 길이 유지
|
||||
for i in range(0, len(items), batch_size):
|
||||
chunk = items[i : i + batch_size]
|
||||
payload = {
|
||||
"product_name": product_name,
|
||||
"category": category,
|
||||
"items": chunk,
|
||||
}
|
||||
resp = self._post("/ocr-translate", payload)
|
||||
|
||||
result_items = resp.get("items", [])
|
||||
# {id, result} 배열 → id 기준으로 매핑
|
||||
for obj in result_items:
|
||||
item_id = int(obj.get("id", 0))
|
||||
orig_idx = source_to_orig_idx.get(item_id)
|
||||
if orig_idx is not None:
|
||||
out_ko[orig_idx] = str(obj.get("result", "")).strip()
|
||||
|
||||
return out_ko
|
||||
|
||||
# ---- C) 옵션 번역: [{id, source:[...]}] → [{id, translations:[...]}] ----
|
||||
def translate_option_groups(
|
||||
self,
|
||||
product_name: str,
|
||||
category: str,
|
||||
option_groups: List[Dict[str, Any]],
|
||||
batch_size: int = 8,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
option_groups 예:
|
||||
[{"id": 1, "source": ["红色","蓝色"]}, {"id": 2, "source": ["小号","大号"]}]
|
||||
반환(서버 결과 그대로):
|
||||
[{"id": 1, "translations": ["핑크","블루"]}, {"id": 2, "translations": ["소형","대형"]}]
|
||||
"""
|
||||
if not option_groups:
|
||||
return []
|
||||
|
||||
out: List[Dict[str, Any]] = []
|
||||
for i in range(0, len(option_groups), batch_size):
|
||||
chunk = option_groups[i : i + batch_size]
|
||||
payload = {
|
||||
"product_name": product_name,
|
||||
"category": category,
|
||||
"items": chunk,
|
||||
}
|
||||
resp = self._post("/option-translate", payload)
|
||||
out.extend(resp.get("result", []))
|
||||
return out
|
||||
|
||||
# ---- D) (선택) 카피 다듬기 ---- (이 메서드는 새 스펙에서 OCR가 단일 단계이므로 제거 또는 주석 처리)
|
||||
# def polish_translations(
|
||||
# self,
|
||||
# product_name: str,
|
||||
# category: str,
|
||||
# id_text_pairs: List[Dict[str, Any]],
|
||||
# batch_size: int = 16,
|
||||
# ) -> List[Dict[str, Any]]:
|
||||
# """
|
||||
# 입력: [{"id":1,"translation":"..."}]
|
||||
# 반환: [{"id":1,"result":"..."}]
|
||||
# """
|
||||
# if not id_text_pairs:
|
||||
# return []
|
||||
# out: List[Dict[str, Any]] = []
|
||||
# for i in range(0, len(id_text_pairs), batch_size):
|
||||
# chunk = id_text_pairs[i : i + batch_size]
|
||||
# payload = {"product_name": product_name, "category": category, "items": chunk}
|
||||
# resp = self._post("/translate_ocr_step2", payload)
|
||||
# out.extend(resp.get("result", []))
|
||||
# return out
|
||||
|
|
@ -0,0 +1,835 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GPU 상태 확인 및 설치 가이드 모듈
|
||||
|
||||
기능:
|
||||
- NVIDIA GPU 감지
|
||||
- CUDA Toolkit 설치 확인
|
||||
- cuDNN 설치 확인
|
||||
- TensorRT 설치 확인
|
||||
- cuDNN PATH 자동 설정
|
||||
- 설치 가이드 제공
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import platform
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import glob
|
||||
import json
|
||||
|
||||
|
||||
class GPUStatusChecker:
|
||||
"""GPU 환경 상태를 종합적으로 확인하고 설치 가이드를 제공하는 클래스"""
|
||||
|
||||
def __init__(self, logger: Optional[object] = None):
|
||||
self.logger = logger or self._create_dummy_logger()
|
||||
self.status = {}
|
||||
|
||||
def _create_dummy_logger(self):
|
||||
"""로거가 없을 때 사용할 더미 로거"""
|
||||
class DummyLogger:
|
||||
def log(self, msg, level=logging.INFO, exc_info=False):
|
||||
print(f"[GPU_CHECK] {msg}")
|
||||
return DummyLogger()
|
||||
|
||||
def check_all_gpu_components(self) -> Dict[str, any]:
|
||||
"""모든 GPU 관련 구성요소를 확인하고 상태 반환"""
|
||||
self.status = {
|
||||
'nvidia_gpu': self._check_nvidia_gpu(),
|
||||
'cuda_toolkit': self._check_cuda_toolkit(),
|
||||
'cudnn': self._check_cudnn(),
|
||||
'tensorrt': self._check_tensorrt(),
|
||||
'onnxruntime_gpu': self._check_onnxruntime_gpu(),
|
||||
'path_configured': self._check_path_configuration(),
|
||||
'recommendations': []
|
||||
}
|
||||
|
||||
# 권장사항 생성
|
||||
self._generate_recommendations()
|
||||
|
||||
return self.status
|
||||
|
||||
def _check_nvidia_gpu(self) -> Dict[str, any]:
|
||||
"""NVIDIA GPU 및 드라이버 확인 (상세 정보 포함)"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=name,driver_version,memory.total,compute_cap", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW if platform.system() == "Windows" else 0
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
lines = result.stdout.strip().split('\n')
|
||||
gpus = []
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip():
|
||||
parts = [p.strip() for p in line.split(',')]
|
||||
if len(parts) >= 3:
|
||||
gpu_info = {
|
||||
'id': i,
|
||||
'name': parts[0],
|
||||
'driver_version': parts[1],
|
||||
'memory_mb': parts[2].replace(' MiB', ''),
|
||||
'compute_capability': parts[3] if len(parts) > 3 else 'Unknown',
|
||||
'architecture': self._detect_gpu_architecture(parts[0]),
|
||||
'cuda_support': self._check_cuda_compatibility(parts[0])
|
||||
}
|
||||
gpus.append(gpu_info)
|
||||
|
||||
return {
|
||||
'installed': True,
|
||||
'gpus': gpus,
|
||||
'count': len(gpus),
|
||||
'message': f"{len(gpus)}개의 NVIDIA GPU 감지됨",
|
||||
'primary_gpu': gpus[0] if gpus else None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'installed': False,
|
||||
'gpus': [],
|
||||
'count': 0,
|
||||
'message': "NVIDIA GPU 또는 드라이버가 설치되지 않음"
|
||||
}
|
||||
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return {
|
||||
'installed': False,
|
||||
'gpus': [],
|
||||
'count': 0,
|
||||
'message': "nvidia-smi 명령을 찾을 수 없음 (NVIDIA 드라이버 미설치)"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'installed': False,
|
||||
'gpus': [],
|
||||
'count': 0,
|
||||
'message': f"GPU 확인 중 오류: {e}"
|
||||
}
|
||||
|
||||
def _detect_gpu_architecture(self, gpu_name: str) -> str:
|
||||
"""GPU 이름을 기반으로 아키텍처 감지"""
|
||||
gpu_name_lower = gpu_name.lower()
|
||||
|
||||
# RTX 40 시리즈 (Ada Lovelace)
|
||||
if any(model in gpu_name_lower for model in ['rtx 40', 'rtx 41', 'rtx 42', 'rtx 43', 'rtx 44']):
|
||||
return 'Ada Lovelace (RTX 40 시리즈)'
|
||||
|
||||
# RTX 30 시리즈 (Ampere)
|
||||
elif any(model in gpu_name_lower for model in ['rtx 30', 'rtx 31', 'rtx 32', 'rtx 33', 'rtx 34']):
|
||||
return 'Ampere (RTX 30 시리즈)'
|
||||
|
||||
# RTX 20 시리즈 (Turing)
|
||||
elif any(model in gpu_name_lower for model in ['rtx 20', 'rtx 21', 'rtx 22', 'rtx 23', 'rtx 24']):
|
||||
return 'Turing (RTX 20 시리즈)'
|
||||
|
||||
# GTX 16 시리즈 (Turing)
|
||||
elif any(model in gpu_name_lower for model in ['gtx 16', 'gtx 165', 'gtx 166']):
|
||||
return 'Turing (GTX 16 시리즈)'
|
||||
|
||||
# GTX 10 시리즈 (Pascal)
|
||||
elif any(model in gpu_name_lower for model in ['gtx 10', 'gtx 105', 'gtx 106', 'gtx 107', 'gtx 108']):
|
||||
return 'Pascal (GTX 10 시리즈)'
|
||||
|
||||
# GTX 900 시리즈 (Maxwell)
|
||||
elif any(model in gpu_name_lower for model in ['gtx 9', 'gtx 90']):
|
||||
return 'Maxwell (GTX 900 시리즈)'
|
||||
|
||||
# Quadro 카드들
|
||||
elif 'quadro' in gpu_name_lower:
|
||||
return 'Quadro 워크스테이션'
|
||||
|
||||
# Tesla 카드들
|
||||
elif 'tesla' in gpu_name_lower:
|
||||
return 'Tesla 데이터센터'
|
||||
|
||||
else:
|
||||
return 'Unknown'
|
||||
|
||||
def _check_cuda_compatibility(self, gpu_name: str) -> Dict[str, any]:
|
||||
"""GPU의 CUDA 12.1 호환성 확인"""
|
||||
gpu_name_lower = gpu_name.lower()
|
||||
|
||||
# CUDA 12.1을 완전히 지원하는 GPU들 (Compute Capability 6.0 이상)
|
||||
fully_supported = [
|
||||
'rtx 40', 'rtx 41', 'rtx 42', 'rtx 43', 'rtx 44', # Ada Lovelace
|
||||
'rtx 30', 'rtx 31', 'rtx 32', 'rtx 33', 'rtx 34', # Ampere
|
||||
'rtx 20', 'rtx 21', 'rtx 22', 'rtx 23', 'rtx 24', # Turing
|
||||
'gtx 16', 'gtx 165', 'gtx 166', # Turing GTX
|
||||
'gtx 10', 'gtx 105', 'gtx 106', 'gtx 107', 'gtx 108' # Pascal
|
||||
]
|
||||
|
||||
# 제한적 지원 (구형 아키텍처)
|
||||
limited_support = [
|
||||
'gtx 9', 'gtx 90' # Maxwell
|
||||
]
|
||||
|
||||
if any(model in gpu_name_lower for model in fully_supported):
|
||||
return {
|
||||
'supported': True,
|
||||
'level': 'full',
|
||||
'message': 'CUDA 12.1 완전 지원'
|
||||
}
|
||||
elif any(model in gpu_name_lower for model in limited_support):
|
||||
return {
|
||||
'supported': True,
|
||||
'level': 'limited',
|
||||
'message': 'CUDA 12.1 제한적 지원 (구형 아키텍처)'
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'supported': False,
|
||||
'level': 'none',
|
||||
'message': 'CUDA 12.1 지원 여부 확인 필요'
|
||||
}
|
||||
|
||||
def _check_cuda_toolkit(self) -> Dict[str, any]:
|
||||
"""CUDA Toolkit 설치 확인"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["nvcc", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW if platform.system() == "Windows" else 0
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
output = result.stdout
|
||||
version_info = ""
|
||||
for line in output.split('\n'):
|
||||
if 'release' in line.lower():
|
||||
version_info = line.strip()
|
||||
break
|
||||
|
||||
return {
|
||||
'installed': True,
|
||||
'version': version_info,
|
||||
'message': f"CUDA Toolkit 설치됨: {version_info}"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'installed': False,
|
||||
'version': None,
|
||||
'message': "CUDA Toolkit이 설치되지 않음"
|
||||
}
|
||||
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return {
|
||||
'installed': False,
|
||||
'version': None,
|
||||
'message': "nvcc 명령을 찾을 수 없음 (CUDA Toolkit 미설치)"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'installed': False,
|
||||
'version': None,
|
||||
'message': f"CUDA Toolkit 확인 중 오류: {e}"
|
||||
}
|
||||
|
||||
def _check_cudnn(self) -> Dict[str, any]:
|
||||
"""cuDNN 설치 확인"""
|
||||
cudnn_paths = []
|
||||
cudnn_versions = []
|
||||
|
||||
# Windows에서 일반적인 cuDNN 설치 경로들
|
||||
if platform.system() == "Windows":
|
||||
search_paths = [
|
||||
"C:/Program Files/NVIDIA/CUDNN",
|
||||
"C:/tools/cuda/cudnn",
|
||||
"C:/cudnn"
|
||||
]
|
||||
|
||||
# 환경변수에서도 찾기
|
||||
cuda_path = os.environ.get('CUDA_PATH', '')
|
||||
if cuda_path:
|
||||
search_paths.append(os.path.join(cuda_path, 'cudnn'))
|
||||
|
||||
for base_path in search_paths:
|
||||
if os.path.exists(base_path):
|
||||
# cuDNN 버전 디렉토리 찾기
|
||||
version_dirs = glob.glob(os.path.join(base_path, "v*"))
|
||||
for version_dir in version_dirs:
|
||||
bin_dirs = glob.glob(os.path.join(version_dir, "bin", "*"))
|
||||
for bin_dir in bin_dirs:
|
||||
cudnn_dll = os.path.join(bin_dir, "cudnn64_*.dll")
|
||||
dll_files = glob.glob(cudnn_dll)
|
||||
if dll_files:
|
||||
version_name = os.path.basename(version_dir)
|
||||
cuda_version = os.path.basename(bin_dir)
|
||||
cudnn_paths.append({
|
||||
'path': bin_dir,
|
||||
'version': version_name,
|
||||
'cuda_version': cuda_version,
|
||||
'dll_files': [os.path.basename(f) for f in dll_files]
|
||||
})
|
||||
|
||||
if cudnn_paths:
|
||||
return {
|
||||
'installed': True,
|
||||
'paths': cudnn_paths,
|
||||
'count': len(cudnn_paths),
|
||||
'message': f"cuDNN 설치됨: {len(cudnn_paths)}개 버전 발견"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'installed': False,
|
||||
'paths': [],
|
||||
'count': 0,
|
||||
'message': "cuDNN이 설치되지 않음"
|
||||
}
|
||||
|
||||
def _check_tensorrt(self) -> Dict[str, any]:
|
||||
"""TensorRT 설치 확인"""
|
||||
try:
|
||||
# ONNXRuntime에서 TensorRT provider 확인
|
||||
import onnxruntime as ort
|
||||
available_providers = ort.get_available_providers()
|
||||
tensorrt_available = 'TensorrtExecutionProvider' in available_providers
|
||||
|
||||
if tensorrt_available:
|
||||
# TensorRT 설치 경로 찾기 시도
|
||||
tensorrt_paths = []
|
||||
if platform.system() == "Windows":
|
||||
search_paths = [
|
||||
"C:/Program Files/NVIDIA/TensorRT",
|
||||
"C:/TensorRT"
|
||||
]
|
||||
for path in search_paths:
|
||||
if os.path.exists(path):
|
||||
tensorrt_paths.append(path)
|
||||
|
||||
return {
|
||||
'installed': True,
|
||||
'available_in_onnx': True,
|
||||
'paths': tensorrt_paths,
|
||||
'message': "TensorRT 설치됨 (ONNXRuntime에서 사용 가능)"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'installed': False,
|
||||
'available_in_onnx': False,
|
||||
'paths': [],
|
||||
'message': "TensorRT가 설치되지 않음 또는 ONNXRuntime에서 인식되지 않음"
|
||||
}
|
||||
|
||||
except ImportError:
|
||||
return {
|
||||
'installed': False,
|
||||
'available_in_onnx': False,
|
||||
'paths': [],
|
||||
'message': "ONNXRuntime이 설치되지 않아 TensorRT 확인 불가"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'installed': False,
|
||||
'available_in_onnx': False,
|
||||
'paths': [],
|
||||
'message': f"TensorRT 확인 중 오류: {e}"
|
||||
}
|
||||
|
||||
def _check_onnxruntime_gpu(self) -> Dict[str, any]:
|
||||
"""ONNXRuntime GPU 지원 확인"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
available_providers = ort.get_available_providers()
|
||||
|
||||
gpu_providers = []
|
||||
if 'CUDAExecutionProvider' in available_providers:
|
||||
gpu_providers.append('CUDA')
|
||||
if 'TensorrtExecutionProvider' in available_providers:
|
||||
gpu_providers.append('TensorRT')
|
||||
|
||||
return {
|
||||
'installed': True,
|
||||
'gpu_providers': gpu_providers,
|
||||
'all_providers': available_providers,
|
||||
'message': f"ONNXRuntime GPU 지원: {', '.join(gpu_providers) if gpu_providers else 'CPU만 지원'}"
|
||||
}
|
||||
|
||||
except ImportError:
|
||||
return {
|
||||
'installed': False,
|
||||
'gpu_providers': [],
|
||||
'all_providers': [],
|
||||
'message': "ONNXRuntime이 설치되지 않음"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'installed': False,
|
||||
'gpu_providers': [],
|
||||
'all_providers': [],
|
||||
'message': f"ONNXRuntime 확인 중 오류: {e}"
|
||||
}
|
||||
|
||||
def _check_path_configuration(self) -> Dict[str, any]:
|
||||
"""PATH 환경변수에 cuDNN이 설정되어 있는지 확인"""
|
||||
current_path = os.environ.get('PATH', '')
|
||||
cudnn_in_path = []
|
||||
|
||||
# cuDNN 상태를 직접 확인 (self.status가 아직 완전하지 않을 수 있음)
|
||||
cudnn_status = self._check_cudnn()
|
||||
|
||||
if cudnn_status.get('installed', False):
|
||||
cudnn_paths = cudnn_status['paths']
|
||||
path_parts = current_path.split(os.pathsep)
|
||||
|
||||
for cudnn_info in cudnn_paths:
|
||||
cudnn_bin_path = cudnn_info['path']
|
||||
|
||||
# PATH에 있는지 더 정확하게 확인
|
||||
path_exists = any(
|
||||
os.path.normpath(part.strip()) == os.path.normpath(cudnn_bin_path)
|
||||
for part in path_parts if part.strip()
|
||||
)
|
||||
|
||||
if path_exists:
|
||||
cudnn_in_path.append(cudnn_info)
|
||||
|
||||
return {
|
||||
'cudnn_in_path': len(cudnn_in_path) > 0,
|
||||
'configured_paths': cudnn_in_path,
|
||||
'message': f"PATH 설정: {'올바름' if cudnn_in_path else 'cuDNN 경로 누락'}"
|
||||
}
|
||||
|
||||
def _generate_recommendations(self):
|
||||
"""현재 상태에 따른 권장사항 생성"""
|
||||
recommendations = []
|
||||
|
||||
# NVIDIA GPU 확인
|
||||
if not self.status['nvidia_gpu']['installed']:
|
||||
recommendations.append({
|
||||
'priority': 'critical',
|
||||
'component': 'NVIDIA GPU',
|
||||
'action': 'NVIDIA GPU 드라이버 설치',
|
||||
'description': 'NVIDIA 공식 사이트에서 최신 드라이버를 다운로드하여 설치하세요.',
|
||||
'url': 'https://www.nvidia.com/drivers/'
|
||||
})
|
||||
|
||||
# CUDA Toolkit 확인
|
||||
if not self.status['cuda_toolkit']['installed']:
|
||||
recommendations.append({
|
||||
'priority': 'critical',
|
||||
'component': 'CUDA Toolkit',
|
||||
'action': 'CUDA Toolkit 12.1 설치',
|
||||
'description': 'NVIDIA CUDA Toolkit 12.1을 설치하세요.',
|
||||
'url': 'https://developer.nvidia.com/cuda-toolkit-archive'
|
||||
})
|
||||
|
||||
# cuDNN 확인
|
||||
if not self.status['cudnn']['installed']:
|
||||
recommendations.append({
|
||||
'priority': 'critical',
|
||||
'component': 'cuDNN',
|
||||
'action': 'cuDNN v9.x 설치',
|
||||
'description': 'NVIDIA cuDNN v9.x를 다운로드하여 설치하세요.',
|
||||
'url': 'https://developer.nvidia.com/cudnn'
|
||||
})
|
||||
elif not self.status['path_configured']['cudnn_in_path']:
|
||||
recommendations.append({
|
||||
'priority': 'high',
|
||||
'component': 'cuDNN PATH',
|
||||
'action': 'cuDNN PATH 환경변수 설정',
|
||||
'description': 'cuDNN bin 디렉토리를 시스템 PATH에 추가하세요.',
|
||||
'auto_fix': True
|
||||
})
|
||||
|
||||
# TensorRT 확인 (선택사항)
|
||||
if not self.status['tensorrt']['installed']:
|
||||
recommendations.append({
|
||||
'priority': 'medium',
|
||||
'component': 'TensorRT',
|
||||
'action': 'TensorRT 설치 (선택사항)',
|
||||
'description': 'TensorRT를 설치하면 최고 성능의 GPU 가속을 사용할 수 있습니다.',
|
||||
'url': 'https://developer.nvidia.com/tensorrt'
|
||||
})
|
||||
|
||||
# ONNXRuntime GPU 확인
|
||||
if not self.status['onnxruntime_gpu']['installed']:
|
||||
recommendations.append({
|
||||
'priority': 'critical',
|
||||
'component': 'ONNXRuntime GPU',
|
||||
'action': 'onnxruntime-gpu 설치',
|
||||
'description': 'pip install onnxruntime-gpu 명령으로 설치하세요.',
|
||||
'command': 'pip install onnxruntime-gpu'
|
||||
})
|
||||
|
||||
self.status['recommendations'] = recommendations
|
||||
|
||||
def auto_fix_cudnn_path(self) -> bool:
|
||||
"""cuDNN PATH를 자동으로 설정"""
|
||||
try:
|
||||
# 현재 상태 다시 확인
|
||||
cudnn_status = self._check_cudnn()
|
||||
|
||||
if not cudnn_status.get('installed', False):
|
||||
self.logger.log("cuDNN이 설치되지 않아 PATH 설정을 할 수 없습니다", level=logging.ERROR)
|
||||
return False
|
||||
|
||||
cudnn_paths = cudnn_status['paths']
|
||||
if not cudnn_paths:
|
||||
self.logger.log("cuDNN 설치 경로를 찾을 수 없습니다", level=logging.ERROR)
|
||||
return False
|
||||
|
||||
# 가장 최신 버전의 cuDNN 경로 선택
|
||||
latest_cudnn = max(cudnn_paths, key=lambda x: x['version'])
|
||||
cudnn_bin_path = latest_cudnn['path']
|
||||
|
||||
self.logger.log(f"cuDNN 경로 확인됨: {cudnn_bin_path}", level=logging.INFO)
|
||||
|
||||
# 현재 세션의 PATH에 추가
|
||||
current_path = os.environ.get('PATH', '')
|
||||
|
||||
# PATH에 이미 있는지 더 정확하게 확인
|
||||
path_parts = current_path.split(os.pathsep)
|
||||
path_already_exists = any(
|
||||
os.path.normpath(part.strip()) == os.path.normpath(cudnn_bin_path)
|
||||
for part in path_parts if part.strip()
|
||||
)
|
||||
|
||||
if not path_already_exists:
|
||||
# PATH 맨 앞에 추가 (우선순위 높게)
|
||||
new_path = cudnn_bin_path + os.pathsep + current_path
|
||||
os.environ['PATH'] = new_path
|
||||
self.logger.log(f"✅ cuDNN PATH 추가됨: {cudnn_bin_path}", level=logging.INFO)
|
||||
|
||||
# 설정 후 DLL 로딩 테스트
|
||||
test_result = self._test_cudnn_loading(cudnn_bin_path)
|
||||
if test_result:
|
||||
self.logger.log("✅ cuDNN DLL 로딩 테스트 성공", level=logging.INFO)
|
||||
else:
|
||||
self.logger.log("⚠️ cuDNN DLL 로딩 테스트 실패 - 재부팅이 필요할 수 있습니다", level=logging.WARNING)
|
||||
|
||||
return True
|
||||
else:
|
||||
self.logger.log("cuDNN PATH가 이미 설정되어 있습니다", level=logging.INFO)
|
||||
|
||||
# 이미 있어도 DLL 로딩 테스트
|
||||
test_result = self._test_cudnn_loading(cudnn_bin_path)
|
||||
if not test_result:
|
||||
self.logger.log("⚠️ PATH는 설정되어 있지만 cuDNN DLL 로딩에 실패했습니다", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"cuDNN PATH 설정 실패: {e}", level=logging.ERROR, exc_info=True)
|
||||
return False
|
||||
|
||||
def _test_cudnn_loading(self, cudnn_bin_path: str) -> bool:
|
||||
"""cuDNN DLL 로딩 테스트"""
|
||||
try:
|
||||
import ctypes
|
||||
import glob
|
||||
|
||||
# cudnn64_*.dll 파일 찾기
|
||||
dll_pattern = os.path.join(cudnn_bin_path, "cudnn64_*.dll")
|
||||
dll_files = glob.glob(dll_pattern)
|
||||
|
||||
if not dll_files:
|
||||
self.logger.log(f"cuDNN DLL 파일을 찾을 수 없음: {dll_pattern}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
# 첫 번째 DLL 로딩 테스트
|
||||
dll_file = dll_files[0]
|
||||
try:
|
||||
dll = ctypes.CDLL(dll_file)
|
||||
self.logger.log(f"cuDNN DLL 로딩 성공: {os.path.basename(dll_file)}", level=logging.DEBUG)
|
||||
return True
|
||||
except OSError as e:
|
||||
self.logger.log(f"cuDNN DLL 로딩 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"cuDNN DLL 테스트 중 오류: {e}", level=logging.DEBUG)
|
||||
return False
|
||||
|
||||
def get_installation_commands(self) -> List[str]:
|
||||
"""설치가 필요한 구성요소들의 설치 명령어 반환"""
|
||||
commands = []
|
||||
|
||||
for rec in self.status.get('recommendations', []):
|
||||
if rec.get('command'):
|
||||
commands.append(rec['command'])
|
||||
|
||||
return commands
|
||||
|
||||
def generate_status_report(self) -> str:
|
||||
"""상태 보고서 생성"""
|
||||
report = ["=== GPU 환경 상태 보고서 ===\n"]
|
||||
|
||||
# NVIDIA GPU
|
||||
gpu_status = self.status.get('nvidia_gpu', {})
|
||||
report.append(f"🎮 NVIDIA GPU: {'✅' if gpu_status.get('installed') else '❌'}")
|
||||
if gpu_status.get('installed'):
|
||||
for gpu in gpu_status.get('gpus', []):
|
||||
report.append(f" - {gpu['name']} ({gpu['memory_mb']}MB, 드라이버 {gpu['driver_version']})")
|
||||
else:
|
||||
report.append(f" - {gpu_status.get('message', 'Unknown')}")
|
||||
|
||||
# CUDA Toolkit
|
||||
cuda_status = self.status.get('cuda_toolkit', {})
|
||||
report.append(f"\n🛠️ CUDA Toolkit: {'✅' if cuda_status.get('installed') else '❌'}")
|
||||
report.append(f" - {cuda_status.get('message', 'Unknown')}")
|
||||
|
||||
# cuDNN
|
||||
cudnn_status = self.status.get('cudnn', {})
|
||||
report.append(f"\n🧠 cuDNN: {'✅' if cudnn_status.get('installed') else '❌'}")
|
||||
if cudnn_status.get('installed'):
|
||||
for path_info in cudnn_status.get('paths', []):
|
||||
report.append(f" - {path_info['version']} (CUDA {path_info['cuda_version']})")
|
||||
else:
|
||||
report.append(f" - {cudnn_status.get('message', 'Unknown')}")
|
||||
|
||||
# TensorRT
|
||||
tensorrt_status = self.status.get('tensorrt', {})
|
||||
report.append(f"\n🚀 TensorRT: {'✅' if tensorrt_status.get('installed') else '❌'}")
|
||||
report.append(f" - {tensorrt_status.get('message', 'Unknown')}")
|
||||
|
||||
# ONNXRuntime
|
||||
onnx_status = self.status.get('onnxruntime_gpu', {})
|
||||
report.append(f"\n📦 ONNXRuntime: {'✅' if onnx_status.get('installed') else '❌'}")
|
||||
if onnx_status.get('installed'):
|
||||
providers = onnx_status.get('gpu_providers', [])
|
||||
report.append(f" - GPU 지원: {', '.join(providers) if providers else 'CPU만'}")
|
||||
else:
|
||||
report.append(f" - {onnx_status.get('message', 'Unknown')}")
|
||||
|
||||
# PATH 설정
|
||||
path_status = self.status.get('path_configured', {})
|
||||
report.append(f"\n⚙️ PATH 설정: {'✅' if path_status.get('cudnn_in_path') else '❌'}")
|
||||
report.append(f" - {path_status.get('message', 'Unknown')}")
|
||||
|
||||
# 권장사항
|
||||
recommendations = self.status.get('recommendations', [])
|
||||
if recommendations:
|
||||
report.append(f"\n📋 권장사항 ({len(recommendations)}개):")
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
priority = rec['priority'].upper()
|
||||
report.append(f" {i}. [{priority}] {rec['action']}")
|
||||
report.append(f" {rec['description']}")
|
||||
|
||||
return '\n'.join(report)
|
||||
|
||||
def generate_customized_installation_guide(self) -> str:
|
||||
"""감지된 GPU에 맞는 맞춤형 설치 가이드 생성"""
|
||||
gpu_status = self.status.get('nvidia_gpu', {})
|
||||
primary_gpu = gpu_status.get('primary_gpu')
|
||||
|
||||
if not gpu_status.get('installed', False) or not primary_gpu:
|
||||
return self._generate_generic_installation_guide()
|
||||
|
||||
gpu_name = primary_gpu.get('name', 'Unknown GPU')
|
||||
architecture = primary_gpu.get('architecture', 'Unknown')
|
||||
cuda_support = primary_gpu.get('cuda_support', {})
|
||||
driver_version = primary_gpu.get('driver_version', 'Unknown')
|
||||
memory_mb = primary_gpu.get('memory_mb', 'Unknown')
|
||||
|
||||
guide_parts = []
|
||||
|
||||
# GPU 정보 헤더
|
||||
guide_parts.append(f"""
|
||||
# 🎮 {gpu_name} 맞춤형 GPU 가속 설치 가이드
|
||||
|
||||
## 📊 감지된 GPU 정보
|
||||
- **GPU 모델**: {gpu_name}
|
||||
- **아키텍처**: {architecture}
|
||||
- **메모리**: {memory_mb} MB
|
||||
- **드라이버 버전**: {driver_version}
|
||||
- **CUDA 12.1 지원**: {cuda_support.get('message', 'Unknown')}
|
||||
""")
|
||||
|
||||
# CUDA 호환성에 따른 안내
|
||||
if cuda_support.get('supported', False):
|
||||
if cuda_support.get('level') == 'full':
|
||||
guide_parts.append("""
|
||||
## ✅ 호환성 확인
|
||||
귀하의 GPU는 CUDA 12.1을 완전히 지원합니다! 최고 성능의 GPU 가속을 사용할 수 있습니다.
|
||||
""")
|
||||
elif cuda_support.get('level') == 'limited':
|
||||
guide_parts.append("""
|
||||
## ⚠️ 호환성 확인
|
||||
귀하의 GPU는 CUDA 12.1을 제한적으로 지원합니다. 기본적인 GPU 가속은 가능하지만 최신 기능은 제한될 수 있습니다.
|
||||
""")
|
||||
else:
|
||||
guide_parts.append("""
|
||||
## ❓ 호환성 확인
|
||||
귀하의 GPU의 CUDA 12.1 호환성을 확인할 수 없습니다. 설치 후 정상 작동 여부를 확인해주세요.
|
||||
""")
|
||||
|
||||
# 맞춤형 드라이버 링크
|
||||
driver_link = self._get_driver_download_link(gpu_name)
|
||||
guide_parts.append(f"""
|
||||
## 🎯 맞춤형 설치 링크
|
||||
|
||||
### 1단계: NVIDIA 드라이버 업데이트
|
||||
현재 드라이버: **{driver_version}**
|
||||
|
||||
{driver_link}
|
||||
|
||||
### 2단계: CUDA Toolkit 12.1 설치
|
||||
**권장 버전**: CUDA Toolkit 12.1 (편집알바생 최적화 버전)
|
||||
- 📥 [CUDA 12.1 다운로드](https://developer.nvidia.com/cuda-12-1-0-download-archive)
|
||||
- 설치 시 "Express 설치" 선택 권장
|
||||
- 환경변수는 자동으로 설정됩니다
|
||||
|
||||
### 3단계: cuDNN v9.x 설치
|
||||
**권장 버전**: cuDNN v9.12 for CUDA 12.1
|
||||
- 📥 [cuDNN 다운로드](https://developer.nvidia.com/cudnn) (NVIDIA 계정 필요)
|
||||
- **Windows 설치 파일**: `cudnn-windows-x86_64-9.x.x.x_cuda12.exe`
|
||||
- **설치 방법**: 다운로드한 .exe 파일을 실행하여 자동 설치
|
||||
- **자동 설정**: 설치 후 PATH 환경변수 자동 설정됨
|
||||
|
||||
### 4단계: 프로그램 실행
|
||||
- 프로그램을 실행하면 자동으로 GPU 가속 상태를 확인합니다
|
||||
- 별도의 Python 패키지 설치가 필요하지 않습니다
|
||||
""")
|
||||
|
||||
# GPU별 성능 최적화 팁
|
||||
optimization_tips = self._get_optimization_tips(gpu_name, architecture, memory_mb)
|
||||
guide_parts.append(optimization_tips)
|
||||
|
||||
# 설치 확인 방법
|
||||
guide_parts.append("""
|
||||
## 🧪 설치 확인
|
||||
|
||||
설치 완료 후 다음 버튼들을 클릭하여 확인:
|
||||
- **🎮 GPU 상태 확인**: nvidia-smi 명령 실행
|
||||
- **🔧 CUDA 버전 확인**: nvcc --version 명령 실행
|
||||
- **🐍 ONNXRuntime GPU 확인**: 프로그램 내 GPU 가속 지원 확인
|
||||
|
||||
또는 직접 명령 프롬프트에서 확인:
|
||||
```cmd
|
||||
nvidia-smi # GPU 상태 확인
|
||||
nvcc --version # CUDA 확인
|
||||
```
|
||||
|
||||
**참고**: 이 프로그램은 cx_Freeze로 패키징되어 있어 Python 환경 설정이 필요하지 않습니다.
|
||||
|
||||
## 🆘 문제 해결
|
||||
|
||||
### "cudnn64_9.dll을 찾을 수 없음" 오류
|
||||
1. 이 다이얼로그에서 **"⚙️ cuDNN PATH 자동 설정"** 버튼 클릭
|
||||
2. 또는 수동으로 cuDNN bin 디렉토리를 PATH에 추가
|
||||
|
||||
### 성능이 예상보다 느린 경우
|
||||
- GPU 온도 확인 (과열 시 성능 저하)
|
||||
- 다른 GPU 사용 프로그램 종료
|
||||
- 드라이버를 최신 버전으로 업데이트
|
||||
|
||||
### 메모리 부족 오류
|
||||
- 이미지 크기 줄이기
|
||||
- 다른 프로그램 종료하여 GPU 메모리 확보
|
||||
- 배치 크기 조정 (고급 사용자)
|
||||
""")
|
||||
|
||||
return ''.join(guide_parts).strip()
|
||||
|
||||
def _get_driver_download_link(self, gpu_name: str) -> str:
|
||||
"""GPU에 맞는 드라이버 다운로드 링크 생성"""
|
||||
gpu_name_lower = gpu_name.lower()
|
||||
|
||||
# RTX 40 시리즈
|
||||
if any(model in gpu_name_lower for model in ['rtx 40', 'rtx 41', 'rtx 42', 'rtx 43', 'rtx 44']):
|
||||
return """- 📥 [RTX 40 시리즈 최신 드라이버](https://www.nvidia.com/drivers/results/218194/)
|
||||
- **권장**: Game Ready Driver 또는 Studio Driver 최신 버전"""
|
||||
|
||||
# RTX 30 시리즈
|
||||
elif any(model in gpu_name_lower for model in ['rtx 30', 'rtx 31', 'rtx 32', 'rtx 33', 'rtx 34']):
|
||||
return """- 📥 [RTX 30 시리즈 최신 드라이버](https://www.nvidia.com/drivers/results/218194/)
|
||||
- **권장**: Game Ready Driver 또는 Studio Driver 최신 버전"""
|
||||
|
||||
# RTX 20 시리즈
|
||||
elif any(model in gpu_name_lower for model in ['rtx 20', 'rtx 21', 'rtx 22', 'rtx 23', 'rtx 24']):
|
||||
return """- 📥 [RTX 20 시리즈 최신 드라이버](https://www.nvidia.com/drivers/results/218194/)
|
||||
- **권장**: Game Ready Driver 최신 버전"""
|
||||
|
||||
# GTX 16 시리즈
|
||||
elif any(model in gpu_name_lower for model in ['gtx 16', 'gtx 165', 'gtx 166']):
|
||||
return """- 📥 [GTX 16 시리즈 최신 드라이버](https://www.nvidia.com/drivers/results/218194/)
|
||||
- **권장**: Game Ready Driver 최신 버전"""
|
||||
|
||||
# GTX 10 시리즈
|
||||
elif any(model in gpu_name_lower for model in ['gtx 10', 'gtx 105', 'gtx 106', 'gtx 107', 'gtx 108']):
|
||||
return """- 📥 [GTX 10 시리즈 최신 드라이버](https://www.nvidia.com/drivers/results/218194/)
|
||||
- **참고**: 구형 GPU이므로 최신 기능 일부가 제한될 수 있습니다"""
|
||||
|
||||
# 기타
|
||||
else:
|
||||
return """- 📥 [NVIDIA 드라이버 자동 감지](https://www.nvidia.com/drivers/)
|
||||
- **방법**: 사이트에서 GPU 모델 선택 후 다운로드"""
|
||||
|
||||
def _get_optimization_tips(self, gpu_name: str, architecture: str, memory_mb: str) -> str:
|
||||
"""GPU별 성능 최적화 팁"""
|
||||
tips = ["\n## 🚀 성능 최적화 팁\n"]
|
||||
|
||||
try:
|
||||
memory_gb = int(memory_mb) // 1024 if memory_mb.isdigit() else 0
|
||||
except:
|
||||
memory_gb = 0
|
||||
|
||||
# 메모리 기반 최적화
|
||||
if memory_gb >= 16:
|
||||
tips.append("### 🎯 대용량 메모리 (16GB+) 최적화")
|
||||
tips.append("- **TensorRT 설치 권장**: 최고 성능을 위해 TensorRT 추가 설치")
|
||||
tips.append("- **배치 크기 증가**: 더 큰 이미지나 배치 처리 가능")
|
||||
tips.append("- **동시 처리**: 여러 이미지 동시 처리 가능")
|
||||
elif memory_gb >= 8:
|
||||
tips.append("### ⚡ 중간 메모리 (8-16GB) 최적화")
|
||||
tips.append("- **기본 설정 사용**: 현재 설정이 최적화됨")
|
||||
tips.append("- **메모리 모니터링**: 다른 프로그램과 메모리 공유 주의")
|
||||
elif memory_gb >= 4:
|
||||
tips.append("### 💡 제한적 메모리 (4-8GB) 최적화")
|
||||
tips.append("- **이미지 크기 제한**: 큰 이미지는 성능 저하 가능")
|
||||
tips.append("- **순차 처리**: 동시 처리보다는 순차 처리 권장")
|
||||
tips.append("- **메모리 정리**: 사용 후 자동 메모리 정리 활성화됨")
|
||||
else:
|
||||
tips.append("### ⚠️ 저용량 메모리 (4GB 미만)")
|
||||
tips.append("- **CPU 모드 권장**: GPU 메모리 부족으로 CPU 모드가 더 안정적일 수 있음")
|
||||
tips.append("- **작은 이미지만**: 큰 이미지 처리 시 오류 발생 가능")
|
||||
|
||||
# 아키텍처별 최적화
|
||||
if 'Ada Lovelace' in architecture or 'RTX 40' in architecture:
|
||||
tips.append("\n### 🔥 RTX 40 시리즈 특화 최적화")
|
||||
tips.append("- **TensorRT 필수**: RTX 40 시리즈의 성능을 최대화")
|
||||
tips.append("- **AV1 인코딩**: 영상 처리 시 AV1 인코더 활용 가능")
|
||||
elif 'Ampere' in architecture or 'RTX 30' in architecture:
|
||||
tips.append("\n### ⚡ RTX 30 시리즈 특화 최적화")
|
||||
tips.append("- **TensorRT 권장**: 30 시리즈의 RT Core 활용")
|
||||
tips.append("- **DLSS 기술**: AI 가속 기능 최적화됨")
|
||||
elif 'Pascal' in architecture or 'GTX 10' in architecture:
|
||||
tips.append("\n### 🛠️ GTX 10 시리즈 최적화")
|
||||
tips.append("- **기본 CUDA**: TensorRT보다는 기본 CUDA가 안정적")
|
||||
tips.append("- **드라이버 업데이트**: 정기적인 드라이버 업데이트 중요")
|
||||
|
||||
return '\n'.join(tips)
|
||||
|
||||
def _generate_generic_installation_guide(self) -> str:
|
||||
"""GPU가 감지되지 않은 경우의 일반적인 설치 가이드"""
|
||||
return """
|
||||
# GPU 가속을 위한 일반 설치 가이드
|
||||
|
||||
## ❓ GPU를 감지할 수 없습니다
|
||||
|
||||
NVIDIA GPU가 감지되지 않았습니다. 다음 사항을 확인해주세요:
|
||||
|
||||
### 1. GPU 확인
|
||||
- NVIDIA GPU가 설치되어 있는지 확인
|
||||
- 장치 관리자에서 디스플레이 어댑터 확인
|
||||
|
||||
### 2. 드라이버 설치
|
||||
- NVIDIA 공식 사이트에서 드라이버 다운로드
|
||||
- https://www.nvidia.com/drivers/
|
||||
|
||||
### 3. 일반적인 설치 순서
|
||||
1. NVIDIA 드라이버 설치
|
||||
2. CUDA Toolkit 12.1 설치
|
||||
3. cuDNN v9.x 설치
|
||||
4. onnxruntime-gpu 설치
|
||||
|
||||
자세한 설치 방법은 GPU 설치 후 다시 확인해주세요.
|
||||
"""
|
||||
|
|
@ -0,0 +1,668 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GPU 유틸리티 모듈 - DirectML 기반 GPU 가속 및 상태 관리
|
||||
|
||||
기능:
|
||||
- GPU 사용 가능성 검사
|
||||
- DirectML 지원 여부 확인
|
||||
- 전역 GPU 상태 관리
|
||||
- CPU 폴백 처리
|
||||
- Windows DirectX 12 기반 범용 GPU 지원 (NVIDIA, AMD, Intel)
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import subprocess
|
||||
import platform
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
# ONNXRuntime DirectML 메모리 절약형 설정
|
||||
os.environ['ORT_ENABLE_ALL_OPTIMIZATIONS'] = '0' # 🔧 메모리 절약을 위해 비활성화
|
||||
os.environ['ORT_DISABLE_MEMCPY_WARNINGS'] = '1' # Memcpy 경고 억제
|
||||
os.environ['ORT_DISABLE_ALL_CUSTOM_OPS'] = '0' # 커스텀 연산 활성화
|
||||
os.environ['ORT_LOGGING_LEVEL'] = '3' # 경고만 출력 (0: DEBUG, 1: INFO, 2: WARNING, 3: ERROR)
|
||||
os.environ['ORT_CUDA_CUDNN_CONV_USE_MAX_WORKSPACE'] = '0' # GPU 메모리 안정성
|
||||
os.environ['ORT_LOG_SEVERITY_LEVEL'] = '3' # 심각한 오류만 로깅
|
||||
# DirectML 특화 메모리 절약 설정
|
||||
os.environ['ORT_DML_ENABLE_GRAPH_SERIALIZATION'] = '0' # 그래프 직렬화 비활성화 (안정성)
|
||||
os.environ['ORT_DML_METACOMMANDS_ENABLED'] = '1' # 메타커맨드 활성화
|
||||
|
||||
|
||||
class GPUManager:
|
||||
"""GPU 상태 관리 및 DirectML 지원 확인 (하위 호환성을 위해 can_use_cuda 속성 유지)"""
|
||||
|
||||
def __init__(self, logger: Optional[object] = None):
|
||||
self.logger = logger or self._create_dummy_logger()
|
||||
|
||||
# GPU 상태 전역 변수들 (하위 호환성을 위해 can_use_cuda 유지)
|
||||
self.can_use_cuda = False # DirectML 사용 가능 여부 (기존 인터페이스 호환성)
|
||||
self.directml_available = False
|
||||
self.gpu_info = {}
|
||||
self.initialization_attempted = False
|
||||
|
||||
def _create_dummy_logger(self):
|
||||
"""로거가 없을 때 사용할 더미 로거"""
|
||||
class DummyLogger:
|
||||
def log(self, msg, level=logging.DEBUG, exc_info=False):
|
||||
print(f"[GPU] {msg}")
|
||||
return DummyLogger()
|
||||
|
||||
def _setup_directml_environment(self) -> None:
|
||||
"""DirectML 환경 설정 (Windows DirectX 12 기반)"""
|
||||
try:
|
||||
if platform.system() == "Windows":
|
||||
# DirectML은 Windows에 내장된 DirectX 12를 사용하므로 별도 설정 불필요
|
||||
self.logger.log("✅ DirectML 환경 준비 완료 (Windows DirectX 12 기반)", level=logging.DEBUG)
|
||||
else:
|
||||
self.logger.log("⚠️ DirectML은 Windows 전용입니다", level=logging.WARNING)
|
||||
except Exception as e:
|
||||
self.logger.log(f"DirectML 환경 설정 중 오류: {e}", level=logging.WARNING)
|
||||
|
||||
def initialize_gpu_state(self, toggle_states: Dict[str, Any]) -> None:
|
||||
"""
|
||||
GPU 상태를 초기화하고 전역 변수에 저장
|
||||
|
||||
Args:
|
||||
toggle_states: 설정 딕셔너리
|
||||
"""
|
||||
if self.initialization_attempted:
|
||||
return # 이미 초기화됨
|
||||
|
||||
self.initialization_attempted = True
|
||||
|
||||
# DirectML 환경 설정
|
||||
self._setup_directml_environment()
|
||||
|
||||
# 사용자가 GPU 가속을 원하는지 확인 (use_cuda를 GPU 가속 플래그로 사용)
|
||||
use_gpu_requested = toggle_states.get("use_cuda", False)
|
||||
|
||||
self.logger.log("=== 🚀 DirectML GPU 상태 초기화 시작 🚀 ===", level=logging.DEBUG)
|
||||
self.logger.log(f"🎯 사용자 GPU 가속 요청: {use_gpu_requested}", level=logging.DEBUG)
|
||||
self.logger.log(f"💻 현재 운영체제: {platform.system()}", level=logging.DEBUG)
|
||||
|
||||
if not use_gpu_requested:
|
||||
self.logger.log("GPU 가속이 비활성화됨 (toggle_states['use_cuda'] = False)", level=logging.DEBUG)
|
||||
self.can_use_cuda = False
|
||||
self._set_safe_cpu_mode(toggle_states)
|
||||
return
|
||||
|
||||
# Windows 플랫폼 확인 (DirectML 필수)
|
||||
if platform.system() != "Windows":
|
||||
self.logger.log("DirectML은 Windows 전용입니다 - CPU 모드로 전환", level=logging.WARNING)
|
||||
self.can_use_cuda = False
|
||||
self._set_safe_cpu_mode(toggle_states)
|
||||
return
|
||||
|
||||
# DirectML 지원 확인 (안전하게 시도)
|
||||
try:
|
||||
directml_support = self._check_directml_support()
|
||||
except Exception as e:
|
||||
self.logger.log(f"DirectML 확인 중 예외 발생: {e} - CPU 모드로 안전 전환", level=logging.WARNING)
|
||||
self.can_use_cuda = False
|
||||
self._set_safe_cpu_mode(toggle_states)
|
||||
return
|
||||
|
||||
if not directml_support:
|
||||
self.logger.log("DirectML 지원을 확인할 수 없음 - CPU 모드로 전환", level=logging.WARNING)
|
||||
self.can_use_cuda = False
|
||||
self._set_safe_cpu_mode(toggle_states)
|
||||
return
|
||||
|
||||
# 메모리 상태 확인 (안전장치)
|
||||
try:
|
||||
memory_ok = self._check_system_memory()
|
||||
if not memory_ok:
|
||||
self.logger.log("⚠️ 시스템 메모리 부족 감지 - 안전을 위해 CPU 모드로 전환", level=logging.WARNING)
|
||||
self.can_use_cuda = False
|
||||
self._set_safe_cpu_mode(toggle_states)
|
||||
return
|
||||
except Exception as e:
|
||||
self.logger.log(f"메모리 확인 실패: {e} - 안전을 위해 CPU 모드로 전환", level=logging.WARNING)
|
||||
self.can_use_cuda = False
|
||||
self._set_safe_cpu_mode(toggle_states)
|
||||
return
|
||||
|
||||
# 모든 검사 통과
|
||||
self.can_use_cuda = True # 하위 호환성을 위해 이 속성명 유지
|
||||
self.directml_available = True
|
||||
|
||||
# toggle_states에서 migan_use_cuda를 True로 자동 설정
|
||||
if 'migan_use_cuda' in toggle_states and not toggle_states['migan_use_cuda']:
|
||||
toggle_states['migan_use_cuda'] = True
|
||||
self.logger.log("🎯 toggle_states의 migan_use_cuda를 True로 자동 설정", level=logging.INFO)
|
||||
|
||||
self.logger.log("🚀 ✅ DirectML 사용 가능 - GPU 가속 모드로 동작 ✅ 🚀", level=logging.DEBUG)
|
||||
self.logger.log("🎮 DirectML: NVIDIA, AMD, Intel GPU 모두 지원", level=logging.DEBUG)
|
||||
self.logger.log("📊 DirectML 가속 활성화: rembg, MIGAN, OCR 모든 모듈에서 GPU 사용", level=logging.DEBUG)
|
||||
self.logger.log("=== 🎯 DirectML GPU 상태 초기화 완료 🎯 ===", level=logging.DEBUG)
|
||||
|
||||
def _set_safe_cpu_mode(self, toggle_states: Dict[str, Any]) -> None:
|
||||
"""안전한 CPU 모드로 설정"""
|
||||
self.can_use_cuda = False
|
||||
self.directml_available = False
|
||||
|
||||
# 모든 GPU 관련 설정을 CPU 모드로 강제 변경
|
||||
gpu_related_keys = [
|
||||
'migan_use_cuda', 'use_cuda', 'optionIMGTrans_type',
|
||||
'detail_IMGTrans_type', 'thumb_trans_type'
|
||||
]
|
||||
|
||||
for key in gpu_related_keys:
|
||||
if key in toggle_states:
|
||||
if key.endswith('_type'):
|
||||
toggle_states[key] = 'CPU'
|
||||
else:
|
||||
toggle_states[key] = False
|
||||
|
||||
self.logger.log("🔒 안전한 CPU 모드로 모든 GPU 설정 강제 비활성화", level=logging.INFO)
|
||||
|
||||
def _check_system_memory(self) -> bool:
|
||||
"""시스템 메모리 상태 확인"""
|
||||
try:
|
||||
import psutil
|
||||
memory = psutil.virtual_memory()
|
||||
available_gb = memory.available / (1024**3)
|
||||
|
||||
self.logger.log(f"💾 시스템 메모리 - 사용가능: {available_gb:.1f}GB, 사용률: {memory.percent:.1f}%", level=logging.DEBUG)
|
||||
|
||||
# 사용 가능한 메모리가 2GB 미만이거나 사용률이 90% 이상이면 위험
|
||||
if available_gb < 2.0 or memory.percent > 90:
|
||||
self.logger.log(f"⚠️ 메모리 부족 위험: 사용가능 {available_gb:.1f}GB, 사용률 {memory.percent:.1f}%", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.log(f"메모리 상태 확인 실패: {e}", level=logging.WARNING)
|
||||
return False # 확인 실패 시 안전하게 False 반환
|
||||
|
||||
def _detect_gpu_hardware(self) -> bool:
|
||||
"""GPU 하드웨어 감지"""
|
||||
try:
|
||||
self.logger.log("🔍 GPU 하드웨어 감지 시작...", level=logging.DEBUG)
|
||||
|
||||
if platform.system() != "Windows":
|
||||
self.logger.log("❌ 현재 Windows만 지원됨", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
self.logger.log("🖥️ Windows 환경 확인됨, nvidia-smi 명령 실행 중...", level=logging.DEBUG)
|
||||
|
||||
# nvidia-smi 명령어로 GPU 확인
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=name,memory.total,driver_version", "--format=csv,noheader,nounits"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW if platform.system() == "Windows" else 0
|
||||
)
|
||||
|
||||
self.logger.log(f"📊 nvidia-smi 실행 결과 - 반환코드: {result.returncode}", level=logging.DEBUG)
|
||||
if result.stdout:
|
||||
self.logger.log(f"📄 nvidia-smi 출력: {result.stdout.strip()}", level=logging.DEBUG)
|
||||
if result.stderr:
|
||||
self.logger.log(f"⚠️ nvidia-smi 에러 출력: {result.stderr.strip()}", level=logging.WARNING)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
gpu_lines = result.stdout.strip().split('\n')
|
||||
for i, line in enumerate(gpu_lines):
|
||||
if line.strip():
|
||||
parts = [p.strip() for p in line.split(',')]
|
||||
if len(parts) >= 3:
|
||||
self.gpu_info[f'gpu_{i}'] = {
|
||||
'name': parts[0],
|
||||
'memory_mb': parts[1],
|
||||
'driver_version': parts[2]
|
||||
}
|
||||
|
||||
self.logger.log(f"GPU 하드웨어 감지됨: {len(self.gpu_info)}개", level=logging.DEBUG)
|
||||
for gpu_id, info in self.gpu_info.items():
|
||||
self.logger.log(f" {gpu_id}: {info['name']} ({info['memory_mb']}MB, 드라이버 {info['driver_version']})", level=logging.DEBUG)
|
||||
return True
|
||||
else:
|
||||
self.logger.log(f"nvidia-smi 실행 실패: {result.stderr}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError) as e:
|
||||
self.logger.log(f"GPU 하드웨어 감지 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.log(f"GPU 하드웨어 감지 중 예외: {e}", level=logging.ERROR, exc_info=True)
|
||||
return False
|
||||
|
||||
def _check_cuda_installation(self) -> bool:
|
||||
"""CUDA 설치 및 작동 상태 확인"""
|
||||
try:
|
||||
self.logger.log("🔧 CUDA 설치 상태 확인 중...", level=logging.DEBUG)
|
||||
|
||||
# nvcc 버전 확인
|
||||
result = subprocess.run(
|
||||
["nvcc", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW if platform.system() == "Windows" else 0
|
||||
)
|
||||
|
||||
self.logger.log(f"🛠️ nvcc 명령 실행 결과 - 반환코드: {result.returncode}", level=logging.DEBUG)
|
||||
if result.stdout:
|
||||
self.logger.log(f"📋 nvcc 출력: {result.stdout.strip()}", level=logging.DEBUG)
|
||||
if result.stderr:
|
||||
self.logger.log(f"⚠️ nvcc 에러 출력: {result.stderr.strip()}", level=logging.WARNING)
|
||||
|
||||
if result.returncode == 0:
|
||||
version_output = result.stdout
|
||||
self.logger.log(f"CUDA 컴파일러 감지됨", level=logging.DEBUG)
|
||||
|
||||
# 버전 정보 추출
|
||||
for line in version_output.split('\n'):
|
||||
if 'release' in line.lower():
|
||||
self.logger.log(f"CUDA 버전: {line.strip()}", level=logging.DEBUG)
|
||||
break
|
||||
|
||||
return True
|
||||
else:
|
||||
self.logger.log("CUDA 컴파일러(nvcc)를 찾을 수 없음", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError) as e:
|
||||
self.logger.log(f"CUDA 설치 확인 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.log(f"CUDA 설치 확인 중 예외: {e}", level=logging.ERROR, exc_info=True)
|
||||
return False
|
||||
|
||||
def _check_directml_support(self) -> bool:
|
||||
"""DirectML 지원 확인 및 실제 GPU 가속 동작 테스트"""
|
||||
self.logger.log("🧠 DirectML 지원 확인 및 실제 테스트 시작...", level=logging.DEBUG)
|
||||
|
||||
try:
|
||||
self.logger.log("📦 ONNXRuntime DirectML 확인 중...", level=logging.DEBUG)
|
||||
import onnxruntime as ort
|
||||
providers = ort.get_available_providers()
|
||||
self.logger.log(f"🔍 ONNXRuntime 사용 가능한 providers: {providers}", level=logging.DEBUG)
|
||||
|
||||
# DirectML provider 존재 확인
|
||||
if "DmlExecutionProvider" not in providers:
|
||||
self.logger.log("❌ ONNXRuntime DirectML 지원 없음", level=logging.WARNING)
|
||||
self.logger.log("💡 onnxruntime-directml 패키지가 필요할 수 있습니다", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
self.logger.log("⚡ DirectML ExecutionProvider 지원 확인됨", level=logging.DEBUG)
|
||||
|
||||
# VM 환경 감지
|
||||
if self._detect_vm_environment():
|
||||
self.logger.log("🖥️ VM 환경이 감지됨 - GPU 패스스루 상태 확인 중...", level=logging.DEBUG)
|
||||
|
||||
# 실제 DirectML 동작 테스트
|
||||
if not self._test_directml_actual_performance():
|
||||
self.logger.log("❌ DirectML 실제 동작 테스트 실패 - CPU 모드로 전환", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
self.logger.log("✅ DirectML 실제 GPU 가속 동작 확인됨", level=logging.DEBUG)
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
self.logger.log("ONNXRuntime가 설치되지 않음", level=logging.WARNING)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.log(f"ONNXRuntime DirectML 지원 확인 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
def test_directml_comprehensive(self) -> Dict[str, Any]:
|
||||
"""종합적인 DirectML 테스트 (GPU 상태 버튼 전용 - 실제 추론 테스트 포함)"""
|
||||
test_results = {
|
||||
'directml_available': False,
|
||||
'vm_detected': False,
|
||||
'inference_test_passed': False,
|
||||
'test_duration': 0,
|
||||
'error_message': None,
|
||||
'performance_ratio': 0
|
||||
}
|
||||
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
self.logger.log("🧪 종합적인 DirectML 테스트 시작 (실제 추론 포함)...", level=logging.DEBUG)
|
||||
|
||||
# 1단계: DirectML provider 확인
|
||||
import onnxruntime as ort
|
||||
providers = ort.get_available_providers()
|
||||
directml_available = "DmlExecutionProvider" in providers
|
||||
test_results['directml_available'] = directml_available
|
||||
|
||||
if not directml_available:
|
||||
test_results['error_message'] = "DirectML Provider를 찾을 수 없습니다"
|
||||
return test_results
|
||||
|
||||
# 2단계: VM 환경 감지
|
||||
vm_detected = self._detect_vm_environment()
|
||||
test_results['vm_detected'] = vm_detected
|
||||
|
||||
if vm_detected:
|
||||
self.logger.log("🖥️ VM 환경이 감지됨 - GPU 패스스루 상태 확인 중...", level=logging.DEBUG)
|
||||
|
||||
# 3단계: 실제 DirectML 추론 테스트
|
||||
inference_success = self._test_directml_actual_performance()
|
||||
test_results['inference_test_passed'] = inference_success
|
||||
|
||||
if not inference_success:
|
||||
test_results['error_message'] = "DirectML 추론 테스트 실패 - GPU 가속이 실제로 동작하지 않습니다"
|
||||
return test_results
|
||||
|
||||
# 4단계: 성능 벤치마크
|
||||
performance_ratio = self._benchmark_directml_vs_cpu()
|
||||
test_results['performance_ratio'] = performance_ratio
|
||||
|
||||
test_results['test_duration'] = time.time() - start_time
|
||||
self.logger.log(f"✅ 종합 DirectML 테스트 완료 ({test_results['test_duration']:.2f}초)", level=logging.DEBUG)
|
||||
|
||||
return test_results
|
||||
|
||||
except Exception as e:
|
||||
test_results['error_message'] = f"DirectML 테스트 중 예외 발생: {e}"
|
||||
self.logger.log(f"DirectML 테스트 중 예외: {e}", level=logging.ERROR, exc_info=True)
|
||||
return test_results
|
||||
|
||||
def _detect_vm_environment(self) -> bool:
|
||||
"""VM 환경 감지 (Proxmox, VMware, VirtualBox, Hyper-V 등)"""
|
||||
try:
|
||||
self.logger.log("🔍 VM 환경 감지 중...", level=logging.DEBUG)
|
||||
|
||||
vm_indicators = []
|
||||
|
||||
# 시스템 정보를 통한 VM 감지
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["wmic", "computersystem", "get", "model"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
if result.returncode == 0:
|
||||
output = result.stdout.lower()
|
||||
vm_keywords = ['virtualbox', 'vmware', 'qemu', 'xen', 'kvm', 'hyper-v', 'proxmox']
|
||||
for keyword in vm_keywords:
|
||||
if keyword in output:
|
||||
vm_indicators.append(f"시스템 모델: {keyword}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"시스템 모델 확인 실패: {e}", level=logging.DEBUG)
|
||||
|
||||
# CPU 정보를 통한 VM 감지
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["wmic", "cpu", "get", "name"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
if result.returncode == 0:
|
||||
output = result.stdout.lower()
|
||||
if 'qemu' in output or 'virtual' in output:
|
||||
vm_indicators.append("CPU: 가상화 프로세서 감지")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"CPU 정보 확인 실패: {e}", level=logging.DEBUG)
|
||||
|
||||
# DirectX 정보를 통한 하드웨어 가속 확인
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["dxdiag", "/t", "temp_dxdiag.txt"],
|
||||
capture_output=True,
|
||||
timeout=15,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW
|
||||
)
|
||||
# dxdiag는 파일을 생성하므로 바로 결과를 읽을 수 없음
|
||||
# 이 부분은 간소화하여 생략
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if vm_indicators:
|
||||
self.logger.log(f"🖥️ VM 환경 감지됨: {', '.join(vm_indicators)}", level=logging.INFO)
|
||||
return True
|
||||
else:
|
||||
self.logger.log("💻 물리적 환경으로 판단됨", level=logging.DEBUG)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"VM 환경 감지 실패: {e}", level=logging.DEBUG)
|
||||
return False # 실패 시 안전하게 물리 환경으로 가정
|
||||
|
||||
def _test_directml_actual_performance(self) -> bool:
|
||||
"""실제 DirectML을 사용한 연산 테스트로 GPU 가속 동작 확인"""
|
||||
try:
|
||||
self.logger.log("🧪 DirectML 실제 동작 테스트 시작...", level=logging.DEBUG)
|
||||
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
# 1단계: DirectML provider 초기화 테스트
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# DirectML provider 설정
|
||||
dml_options = {
|
||||
'device_id': 0,
|
||||
'disable_memory_arena': False, # 메모리 아레나 사용
|
||||
}
|
||||
providers = [('DmlExecutionProvider', dml_options), 'CPUExecutionProvider']
|
||||
|
||||
# 세션 옵션 설정
|
||||
session_options = ort.SessionOptions()
|
||||
session_options.log_severity_level = 3 # ERROR만 출력
|
||||
session_options.enable_mem_pattern = False # VM에서 문제가 될 수 있음
|
||||
session_options.enable_cpu_mem_arena = False # 안정성 향상
|
||||
|
||||
# 초기화 시간 확인
|
||||
init_elapsed = time.time() - start_time
|
||||
if init_elapsed > 5.0: # 5초 이상 걸리면 의심스러움
|
||||
self.logger.log(f"⚠️ DirectML 초기화가 비정상적으로 오래 걸림 ({init_elapsed:.1f}초)", level=logging.WARNING)
|
||||
|
||||
# 2단계: 실제 간단한 모델로 추론 테스트
|
||||
success = self._perform_simple_inference_test(providers, session_options)
|
||||
if not success:
|
||||
return False
|
||||
|
||||
# 3단계: 성능 벤치마크 (GPU vs CPU 비교)
|
||||
performance_ok = self._benchmark_directml_vs_cpu()
|
||||
if not performance_ok:
|
||||
self.logger.log("⚠️ DirectML 성능이 CPU보다 현저히 느림 - 실제 GPU 가속이 동작하지 않을 수 있음", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
total_elapsed = time.time() - start_time
|
||||
self.logger.log(f"✅ DirectML 실제 동작 테스트 성공 (총 {total_elapsed:.2f}초)", level=logging.DEBUG)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"❌ DirectML 세션 생성/테스트 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
except ImportError as e:
|
||||
self.logger.log(f"❌ 필요한 패키지 import 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.log(f"❌ DirectML 실제 동작 테스트 중 예외: {e}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
def _perform_simple_inference_test(self, providers, session_options) -> bool:
|
||||
"""간단한 모델로 실제 추론 테스트"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
# 매우 간단한 ONNX 모델 생성 (Add 연산)
|
||||
from onnx import helper, TensorProto
|
||||
import onnx
|
||||
|
||||
# 간단한 덧셈 연산 모델 생성
|
||||
input1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1, 3, 224, 224])
|
||||
input2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1, 3, 224, 224])
|
||||
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 3, 224, 224])
|
||||
|
||||
add_node = helper.make_node('Add', ['input1', 'input2'], ['output'])
|
||||
graph = helper.make_graph([add_node], 'simple_add', [input1, input2], [output])
|
||||
model = helper.make_model(graph)
|
||||
|
||||
# 임시 모델 파일로 저장
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_file:
|
||||
tmp_path = tmp_file.name
|
||||
onnx.save(model, tmp_path)
|
||||
|
||||
try:
|
||||
# DirectML로 세션 생성
|
||||
session = ort.InferenceSession(tmp_path, session_options, providers=providers)
|
||||
|
||||
# 테스트 입력 데이터
|
||||
input_data1 = np.random.rand(1, 3, 224, 224).astype(np.float32)
|
||||
input_data2 = np.random.rand(1, 3, 224, 224).astype(np.float32)
|
||||
|
||||
# 실제 추론 수행
|
||||
start_time = time.time()
|
||||
results = session.run(None, {'input1': input_data1, 'input2': input_data2})
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
# 결과 검증
|
||||
expected = input_data1 + input_data2
|
||||
if np.allclose(results[0], expected, rtol=1e-5):
|
||||
self.logger.log(f"✅ 추론 테스트 성공 ({inference_time:.3f}초)", level=logging.DEBUG)
|
||||
return True
|
||||
else:
|
||||
self.logger.log("❌ 추론 결과가 예상과 다름", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
finally:
|
||||
# 임시 파일 삭제
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"❌ 간단한 추론 테스트 실패: {e}", level=logging.WARNING)
|
||||
return False
|
||||
|
||||
def _benchmark_directml_vs_cpu(self) -> bool:
|
||||
"""DirectML과 CPU 성능 비교로 실제 GPU 가속 확인"""
|
||||
try:
|
||||
self.logger.log("⏱️ DirectML vs CPU 성능 벤치마크 시작...", level=logging.DEBUG)
|
||||
|
||||
# 현재는 간단한 체크만 수행 (더 정교한 벤치마크는 추후 구현)
|
||||
# VM 환경에서는 GPU 가속이 제대로 안 될 가능성이 높으므로
|
||||
# 여기서는 초기화 시간과 기본 동작만 확인
|
||||
|
||||
return True # 일단 통과로 처리
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"성능 벤치마크 중 오류: {e}", level=logging.DEBUG)
|
||||
return True # 벤치마크 실패는 허용
|
||||
|
||||
def get_cuda_status(self) -> Dict[str, Any]:
|
||||
"""현재 GPU 가속 상태 정보 반환 (하위 호환성을 위해 메서드명 유지)"""
|
||||
return {
|
||||
"can_use_cuda": self.can_use_cuda, # DirectML 사용 가능 여부 (호환성)
|
||||
"cuda_available": self.can_use_cuda, # 호환성을 위해 동일한 값
|
||||
"directml_available": self.directml_available,
|
||||
"gpu_info": self.gpu_info.copy(),
|
||||
"initialization_attempted": self.initialization_attempted
|
||||
}
|
||||
|
||||
def force_cpu_mode(self) -> None:
|
||||
"""강제로 CPU 모드로 전환"""
|
||||
self.can_use_cuda = False
|
||||
self.logger.log("강제로 CPU 모드로 전환됨", level=logging.WARNING)
|
||||
|
||||
def get_optimal_onnx_providers(self) -> list:
|
||||
"""DirectML 기반 최적 ONNXRuntime provider 우선순위 리스트 반환"""
|
||||
providers = []
|
||||
|
||||
if self.can_use_cuda: # DirectML 사용 가능 여부
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
available = ort.get_available_providers()
|
||||
|
||||
# DirectML Provider (Windows GPU 가속)
|
||||
if 'DmlExecutionProvider' in available:
|
||||
directml_options = {
|
||||
'device_id': 0, # 기본 GPU 사용
|
||||
}
|
||||
providers.append(('DmlExecutionProvider', directml_options))
|
||||
self.logger.log("⚡ DirectML provider 추가 (범용 GPU 가속 - NVIDIA/AMD/Intel 지원)", level=logging.DEBUG)
|
||||
else:
|
||||
self.logger.log("❌ DirectML provider 사용 불가", level=logging.WARNING)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(f"DirectML Provider 확인 실패: {e}", level=logging.WARNING)
|
||||
|
||||
# 항상 CPU는 폴백으로 추가
|
||||
providers.append(('CPUExecutionProvider', {}))
|
||||
provider_names = [p[0] if isinstance(p, tuple) else p for p in providers]
|
||||
self.logger.log(f"📊 최종 provider 순서: {provider_names}", level=logging.DEBUG)
|
||||
|
||||
return providers
|
||||
|
||||
def log_gpu_memory_usage(self) -> None:
|
||||
"""현재 GPU 메모리 사용량 로깅"""
|
||||
if not self.can_use_cuda:
|
||||
return
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,noheader,nounits"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW if platform.system() == "Windows" else 0
|
||||
)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
lines = result.stdout.strip().split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip():
|
||||
parts = [p.strip() for p in line.split(',')]
|
||||
if len(parts) >= 2:
|
||||
used_mb = int(parts[0])
|
||||
total_mb = int(parts[1])
|
||||
usage_percent = (used_mb / total_mb) * 100
|
||||
self.logger.log(
|
||||
f"GPU {i} 메모리 사용량: {used_mb}MB/{total_mb}MB ({usage_percent:.1f}%)",
|
||||
level=logging.DEBUG
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.log(f"GPU 메모리 사용량 확인 실패: {e}", level=logging.DEBUG)
|
||||
|
||||
|
||||
# 전역 GPU 관리자 인스턴스 (선택적 사용)
|
||||
_global_gpu_manager = None
|
||||
|
||||
def get_global_gpu_manager(logger=None) -> GPUManager:
|
||||
"""전역 GPU 관리자 인스턴스 반환"""
|
||||
global _global_gpu_manager
|
||||
if _global_gpu_manager is None:
|
||||
_global_gpu_manager = GPUManager(logger)
|
||||
return _global_gpu_manager
|
||||
|
||||
|
||||
def check_cuda_simple() -> bool:
|
||||
"""간단한 GPU 가속 사용 가능성 확인 (DirectML, 하위 호환성을 위해 함수명 유지)"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
providers = ort.get_available_providers()
|
||||
return "DmlExecutionProvider" in providers
|
||||
except:
|
||||
return False
|
||||
|
||||
def check_directml_simple() -> bool:
|
||||
"""간단한 DirectML 사용 가능성 확인 (캐시 없음)"""
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
providers = ort.get_available_providers()
|
||||
return "DmlExecutionProvider" in providers
|
||||
except:
|
||||
return False
|
||||
|
|
@ -0,0 +1,401 @@
|
|||
|
||||
|
||||
# src/modules/image_worker.py
|
||||
"""
|
||||
ImageWorker 프로세스 – 별도 프로세스로 구동
|
||||
단 한 번의 READY("OK") 신호만 보내며, OCR 모델 Warm‑up 후 작업 루프에 진입한다.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
import queue
|
||||
import time
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modules.image_processor3 import ImageProcessor3
|
||||
from modules.log_bridge import ImageWorkerLogger
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 로깅 유틸리티 #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
class QueueLogger:
|
||||
"""큐를 통해 메인 프로세스의 로거로 전송하는 로거"""
|
||||
def __init__(self, log_queue, process_name):
|
||||
self.log_queue = log_queue
|
||||
self.process_name = process_name
|
||||
|
||||
def log(self, message, level=logging.INFO, exc_info=False):
|
||||
try:
|
||||
log_record = {
|
||||
'process_name': self.process_name,
|
||||
'level': level,
|
||||
'message': message,
|
||||
'exc_info': exc_info
|
||||
}
|
||||
self.log_queue.put(log_record)
|
||||
except Exception:
|
||||
pass # 로그 전송 실패 시 무시
|
||||
|
||||
def debug(self, msg, *a, **kw): self.log(msg, logging.DEBUG)
|
||||
def info(self, msg, *a, **kw): self.log(msg, logging.INFO)
|
||||
def warning(self, msg, *a, **kw): self.log(msg, logging.WARNING)
|
||||
def error(self, msg, *a, **kw): self.log(msg, logging.ERROR)
|
||||
def critical(self, msg, *a, **kw): self.log(msg, logging.CRITICAL)
|
||||
|
||||
class CompatLogger:
|
||||
"""커스텀 Logger 인터페이스를 표준 logging.Logger 로 매핑"""
|
||||
def __init__(self, py_logger: logging.Logger):
|
||||
self._l = py_logger
|
||||
|
||||
def log(self, message, level=logging.INFO, exc_info=False):
|
||||
self._l.log(level, message, exc_info=exc_info)
|
||||
|
||||
# 편의 메서드
|
||||
def debug(self, msg, *a, **kw): self.log(msg, logging.DEBUG)
|
||||
def info(self, msg, *a, **kw): self.log(msg, logging.INFO)
|
||||
def warning(self, msg, *a, **kw): self.log(msg, logging.WARNING)
|
||||
def error(self, msg, *a, **kw): self.log(msg, logging.ERROR)
|
||||
def critical(self, msg, *a, **kw): self.log(msg, logging.CRITICAL)
|
||||
|
||||
|
||||
def _setup_logging(log_path: str) -> logging.Logger:
|
||||
"""별도 프로세스용 파일 로거 설정"""
|
||||
root = logging.getLogger()
|
||||
root.handlers.clear()
|
||||
root.setLevel(logging.DEBUG)
|
||||
|
||||
fh = logging.FileHandler(log_path, encoding="utf-8")
|
||||
fmt = logging.Formatter(
|
||||
"[%(asctime)s] [%(processName)s] [%(levelname)s] "
|
||||
"[%(module)s:%(funcName)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
fh.setFormatter(fmt)
|
||||
root.addHandler(fh)
|
||||
return root
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 워커 메인 함수 #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def worker_main(
|
||||
task_q,
|
||||
result_q,
|
||||
log_queue, # 추가: 로그 큐
|
||||
log_path: str,
|
||||
base_dir: str,
|
||||
toggle_states: dict,
|
||||
unwanted_words: list[str],
|
||||
authenticated_by_admin: bool = False,
|
||||
):
|
||||
# ── 로깅 초기화 ────────────────────────────────────────────
|
||||
# 큐 로거 사용 (메인 프로세스로 로그 전송)
|
||||
if log_queue:
|
||||
logger = ImageWorkerLogger(log_queue, f"ImageWorker-{os.getpid()}")
|
||||
print(f"[DEBUG] ImageWorker {os.getpid()}: LogQueue 사용됨")
|
||||
else:
|
||||
# 폴백: 기존 파일 로거
|
||||
py_logger = _setup_logging(log_path)
|
||||
logger = CompatLogger(py_logger)
|
||||
print(f"[DEBUG] ImageWorker {os.getpid()}: 파일 로거 사용됨")
|
||||
|
||||
logger.info(
|
||||
f"ImageWorker 프로세스 기동 "
|
||||
f"(PID={os.getpid()}, Name={multiprocessing.current_process().name})"
|
||||
)
|
||||
|
||||
# 워커 기동 시에도 ProgramData/ImgWorker 하위 임시 폴더를 한 번 더 정리
|
||||
try:
|
||||
base_program = os.environ.get("PROGRAMDATA", r"C:\\ProgramData")
|
||||
app_data_dir = os.path.join(base_program, "ImgWorker")
|
||||
def _safe_rmtree_contents(target_dir: str):
|
||||
try:
|
||||
if not target_dir:
|
||||
return
|
||||
base = os.path.abspath(app_data_dir)
|
||||
td = os.path.abspath(target_dir)
|
||||
if os.path.commonpath([base, td]) != base:
|
||||
return
|
||||
os.makedirs(td, exist_ok=True)
|
||||
for name in os.listdir(td):
|
||||
p = os.path.join(td, name)
|
||||
try:
|
||||
if os.path.isdir(p):
|
||||
import shutil as _sh
|
||||
_sh.rmtree(p, ignore_errors=True)
|
||||
else:
|
||||
os.remove(p)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
for d in (os.path.join(app_data_dir, "incoming"), os.path.join(app_data_dir, "work"), os.path.join(app_data_dir, "output"), os.path.join(app_data_dir, "outputs")):
|
||||
_safe_rmtree_contents(d)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# READY는 모델 로딩 완료 후에만 전송
|
||||
|
||||
# ── ImageProcessor 초기화 및 Warm‑up ──────────────────────
|
||||
processor = None
|
||||
try:
|
||||
# 안전한 초기화를 위한 메모리 정리
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
logger.info("🔧 ImageProcessor3 초기화 시작...")
|
||||
processor = ImageProcessor3(
|
||||
logger=logger,
|
||||
page=None,
|
||||
toggle_states=toggle_states,
|
||||
unwanted_words=unwanted_words,
|
||||
authenticated_by_admin=authenticated_by_admin,
|
||||
base_dir=base_dir,
|
||||
papago_translator=None,
|
||||
)
|
||||
|
||||
# OCR 모델 안전한 Warm-up
|
||||
if processor and processor.ocr_module:
|
||||
try:
|
||||
logger.info("🔰 OCR 모듈 Warm-up 시작...")
|
||||
dummy_path = os.path.join(base_dir, "_imgproc_warmup.png")
|
||||
tmp = np.zeros((100, 100, 3), dtype=np.uint8) # 더 현실적인 크기
|
||||
cv2.imwrite(dummy_path, tmp)
|
||||
|
||||
# 타임아웃 설정으로 무한 대기 방지
|
||||
import threading
|
||||
import time
|
||||
|
||||
warmup_success = [False]
|
||||
warmup_error = [None]
|
||||
|
||||
def warmup_ocr():
|
||||
try:
|
||||
processor.ocr_module.detect_text(dummy_path)
|
||||
warmup_success[0] = True
|
||||
except Exception as e:
|
||||
warmup_error[0] = e
|
||||
|
||||
warmup_thread = threading.Thread(target=warmup_ocr)
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
warmup_thread.join(timeout=30) # 30초 타임아웃
|
||||
|
||||
if warmup_success[0]:
|
||||
logger.info("✅ OCR 모듈 Warm-up 성공")
|
||||
elif warmup_error[0]:
|
||||
logger.warning(f"⚠️ OCR 모듈 Warm-up 실패: {warmup_error[0]}")
|
||||
else:
|
||||
logger.warning("⚠️ OCR 모듈 Warm-up 타임아웃")
|
||||
|
||||
try:
|
||||
os.remove(dummy_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"OCR Warm-up 실패: {e}")
|
||||
else:
|
||||
logger.warning("OCR 모듈이 초기화되지 않아 Warm-up 건너뜀")
|
||||
|
||||
logger.info("🔰 ImageProcessor Warm‑up 완료")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ImageProcessor 초기화 실패: {e}", exc_info=True)
|
||||
|
||||
# 초기화 실패 시에도 기본적인 처리가 가능하도록 최소한의 processor 생성 시도
|
||||
try:
|
||||
logger.info("🔄 안전 모드로 재초기화 시도...")
|
||||
|
||||
# GPU 설정을 CPU로 강제 변경
|
||||
safe_toggle_states = toggle_states.copy()
|
||||
safe_toggle_states['use_cuda'] = False
|
||||
safe_toggle_states['optionIMGTrans_type'] = 'CPU'
|
||||
safe_toggle_states['detail_IMGTrans_type'] = 'CPU'
|
||||
safe_toggle_states['thumb_trans_type'] = 'CPU'
|
||||
safe_toggle_states['migan_use_cuda'] = False
|
||||
|
||||
processor = ImageProcessor3(
|
||||
logger=logger,
|
||||
page=None,
|
||||
toggle_states=safe_toggle_states,
|
||||
unwanted_words=unwanted_words,
|
||||
authenticated_by_admin=authenticated_by_admin,
|
||||
base_dir=base_dir,
|
||||
papago_translator=None,
|
||||
)
|
||||
logger.info("✅ 안전 모드로 ImageProcessor 초기화 성공")
|
||||
|
||||
except Exception as e2:
|
||||
logger.error(f"안전 모드 초기화도 실패: {e2}", exc_info=True)
|
||||
processor = None
|
||||
|
||||
# ── READY(OK) 신호 전송 ──────────────────────────────────
|
||||
try:
|
||||
result_q.put({"id": "__READY__", "data": "OK"})
|
||||
logger.info("워커 READY 신호 전송")
|
||||
except Exception:
|
||||
logger.error("READY 신호 전송 실패", exc_info=True)
|
||||
|
||||
# ── rembg 세션 비동기 로딩 ──────────────────────────────
|
||||
def preload_rembg_sessions():
|
||||
"""백그라운드에서 rembg 세션을 미리 로딩"""
|
||||
try:
|
||||
if processor and hasattr(processor, 'background_removal_module'):
|
||||
logger.info("🔄 rembg 세션 백그라운드 로딩 시작...")
|
||||
# 기본 세션들을 미리 로딩
|
||||
processor.background_removal_module._preload_sessions()
|
||||
logger.info("✅ rembg 세션 백그라운드 로딩 완료")
|
||||
except Exception as e:
|
||||
logger.warning(f"rembg 세션 백그라운드 로딩 실패: {e}")
|
||||
|
||||
# rembg 세션 로딩을 별도 Thread에서 실행
|
||||
import threading
|
||||
preload_thread = threading.Thread(target=preload_rembg_sessions, daemon=True)
|
||||
preload_thread.start()
|
||||
|
||||
# ── 작업 루프 ─────────────────────────────────────────────
|
||||
# 컨트롤러가 즉시 pick-up 할 수 있도록 READY 신호를 추가 전송
|
||||
try:
|
||||
result_q.put({"id": "__READY__", "cmd": "__READY__", "kwargs": {}})
|
||||
logger.info("📡 추가 READY 신호 전송 완료")
|
||||
except Exception as e:
|
||||
logger.error(f"추가 READY 신호 전송 실패: {e}")
|
||||
|
||||
idle_log_last = 0.0
|
||||
while True:
|
||||
try:
|
||||
# 주기적 상태 출력을 위해 타임아웃 설정
|
||||
logger.debug(f"큐에서 작업 대기 중... (PID: {os.getpid()})")
|
||||
task = task_q.get(timeout=60) # 60초 타임아웃
|
||||
#logger.info(f"🔥 작업 수신 성공: {task}")
|
||||
logger.info(f"🔥 작업 수신 성공")
|
||||
except queue.Empty:
|
||||
# 유휴 로그는 10분에 한 번만 기록해 로그 스팸을 줄인다.
|
||||
now_ts = time.time()
|
||||
if now_ts - idle_log_last >= 600:
|
||||
idle_log_last = now_ts
|
||||
logger.info("대기 중(유휴)")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"작업 수신 중 오류: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
if task is None:
|
||||
logger.info("Shutdown signal 수신 → 종료")
|
||||
try:
|
||||
# 종료 시 임시 폴더 정리
|
||||
base_program = os.environ.get("PROGRAMDATA", r"C:\\ProgramData")
|
||||
app_data_dir = os.path.join(base_program, "ImgWorker")
|
||||
for d in (os.path.join(app_data_dir, "incoming"), os.path.join(app_data_dir, "work"), os.path.join(app_data_dir, "output"), os.path.join(app_data_dir, "outputs")):
|
||||
try:
|
||||
for n in os.listdir(d):
|
||||
p = os.path.join(d, n)
|
||||
if os.path.isdir(p):
|
||||
import shutil as _sh
|
||||
_sh.rmtree(p, ignore_errors=True)
|
||||
else:
|
||||
os.remove(p)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
|
||||
uid = task["id"]
|
||||
cmd = task["cmd"]
|
||||
kwargs = task["kwargs"]
|
||||
logger.info(f"🚀 작업 처리 시작: cmd={cmd}, uid={uid}")
|
||||
|
||||
# 메타 파라미터 제거 및 실시간 값 반영
|
||||
new_toggle = kwargs.pop("_toggle_states", None)
|
||||
if new_toggle and processor:
|
||||
processor.update_toggle_states(new_toggle)
|
||||
|
||||
_ = kwargs.pop("_base_dir", None) # 필요 없으므로 버림
|
||||
upd_unwanted = kwargs.pop("_update_unwanted_texts", None)
|
||||
if upd_unwanted and processor:
|
||||
processor.update_unwanted_texts(upd_unwanted)
|
||||
|
||||
# 실제 작업 실행
|
||||
try:
|
||||
logger.debug(f"작업 실행 직전: cmd={cmd}")
|
||||
if cmd == "process_single_image":
|
||||
logger.debug("process_single_image 호출 직전")
|
||||
data = asyncio.run(processor.process_single_image(**kwargs))
|
||||
# 성능 지표 포함
|
||||
try:
|
||||
timings = getattr(processor, '_last_timings', None)
|
||||
if isinstance(data, dict) and timings:
|
||||
data['timings'] = timings
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("process_single_image 호출 완료")
|
||||
elif cmd == "remove_background":
|
||||
logger.debug("remove_background 호출 직전")
|
||||
data = asyncio.run(processor.remove_background(**kwargs))
|
||||
logger.debug("remove_background 호출 완료")
|
||||
elif cmd == "reinit_ocr":
|
||||
# 토글 반영 후 OCR 재초기화(프로바이더 캐시 고려)
|
||||
ok = processor.reset_ocr_module()
|
||||
data = {"ok": bool(ok)}
|
||||
elif cmd == "reinit_rembg":
|
||||
# REMBG(배경제거) 모듈 재준비: 현재 Bria 모듈 사용시 세션/프로바이더를 재평가하도록 None 초기화
|
||||
try:
|
||||
# toggle_states의 provider override는 상위에서 반영되어 내려옴
|
||||
# BriaBackgroundRemovalModule은 lazy-load 방식 → 재생성 유도
|
||||
if hasattr(processor, 'background_removal_module'):
|
||||
try:
|
||||
del processor.background_removal_module
|
||||
except Exception:
|
||||
processor.background_removal_module = None
|
||||
from modules.bria_background_removal_module import BriaBackgroundRemovalModule
|
||||
# 경로/매개변수는 processor.toggle_states에서 유추(존재 시)
|
||||
model_path = processor.toggle_states.get('local_rembg_model_path')
|
||||
processor.background_removal_module = BriaBackgroundRemovalModule(
|
||||
logger=logger,
|
||||
default_model=processor.toggle_states.get('local_model_name', 'bria-rmbg-1.4'),
|
||||
gpu_manager=getattr(processor, 'gpu_manager', None),
|
||||
local_rembg_model_path=model_path,
|
||||
)
|
||||
data = {"ok": True}
|
||||
except Exception as e:
|
||||
logger.error(f"REMBG 재초기화 실패: {e}")
|
||||
data = {"ok": False, "error": str(e)}
|
||||
elif cmd == "reset_migan":
|
||||
# MIGAN 재구성(토글상 migan_use_cuda 등 변경 반영)
|
||||
try:
|
||||
from modules.migan_module import build_migan_from_toggle
|
||||
enhanced_toggle_states = processor.toggle_states.copy()
|
||||
# 가속 사용 플래그 명칭 정리(호환): migan_use_cuda -> migan_use_accel
|
||||
if 'migan_use_cuda' in enhanced_toggle_states and 'migan_use_accel' not in enhanced_toggle_states:
|
||||
enhanced_toggle_states['migan_use_accel'] = enhanced_toggle_states['migan_use_cuda']
|
||||
# provider override가 들어왔으면 반영(auto|dml|cpu)
|
||||
prov = kwargs.get('provider')
|
||||
if prov:
|
||||
enhanced_toggle_states['migan_provider_override'] = prov
|
||||
# gpu_manager 상태와 무관하게 토글을 그대로 전달(직접 폴백은 모듈 내부)
|
||||
processor.migan = build_migan_from_toggle(enhanced_toggle_states, logger=logger, gpu_manager=getattr(processor, 'gpu_manager', None))
|
||||
data = {"ok": bool(processor.migan is not None)}
|
||||
except Exception as mm_err:
|
||||
logger.error(f"MIGAN 재설정 실패: {mm_err}")
|
||||
data = {"ok": False, "error": str(mm_err)}
|
||||
elif cmd == "__PING__":
|
||||
# 하트비트 응답
|
||||
data = "__PONG__"
|
||||
else:
|
||||
raise ValueError(f"unknown cmd: {cmd}")
|
||||
|
||||
logger.debug(f"작업 결과 반환 중: uid={uid}")
|
||||
result_q.put({"id": uid, "data": data})
|
||||
logger.debug(f"작업 결과 반환 완료: uid={uid}")
|
||||
except Exception:
|
||||
logger.error(f"작업 처리 중 오류: cmd={cmd}, uid={uid}")
|
||||
logger.error("작업 처리 중 오류", exc_info=True)
|
||||
result_q.put({"id": uid, "error": traceback.format_exc()})
|
||||
|
After Width: | Height: | Size: 102 KiB |
|
After Width: | Height: | Size: 380 KiB |
|
After Width: | Height: | Size: 261 KiB |
|
After Width: | Height: | Size: 112 KiB |
|
After Width: | Height: | Size: 238 KiB |
|
After Width: | Height: | Size: 97 KiB |
|
After Width: | Height: | Size: 165 KiB |
|
After Width: | Height: | Size: 118 KiB |
|
After Width: | Height: | Size: 1.7 KiB |
|
After Width: | Height: | Size: 7.2 KiB |
|
After Width: | Height: | Size: 802 KiB |
|
After Width: | Height: | Size: 479 KiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 569 KiB |
|
After Width: | Height: | Size: 569 KiB |
|
After Width: | Height: | Size: 484 KiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 537 KiB |
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 321 KiB |
|
After Width: | Height: | Size: 611 KiB |
|
After Width: | Height: | Size: 701 KiB |
|
After Width: | Height: | Size: 970 KiB |
|
After Width: | Height: | Size: 384 KiB |