IT_Server/modules/iopaint_models/iop_test.py

32 lines
1.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# iop_tensorrt.py
import torch
from torch2trt import torch2trt, TRTModule
# --- 1. 모델 정의 및 로드 (nn.Module 형태) ---
# MIGAN 예시: 원본 모델 클래스 import 필요
from model_zoo.migan_inference import MIGAN # 실제 경로로 수정하세요
model = MIGAN().eval().cuda()
# 필요 시 weight 로드
model.load_state_dict(torch.load("migan_traced_weights.pth")) # weights 파일 위치
# --- 2. 더미 입력 (4채널: RGB + 마스크) ---
H, W = 512, 512 # 이미지 크기에 맞게 조정
dummy = torch.randn((1, 4, H, W)).cuda().half() # FP16 입력 가능
# --- 3. TensorRT 변환 ---
model_trt = torch2trt(
model, [dummy],
fp16_mode=True,
max_batch_size=1,
default_device_type=None, # DLA 사용 시 trt.DeviceType.DLA 지정 가능
gpu_fallback=True
)
# --- 4. TRTModule wrapping 및 저장 ---
trt_mod = TRTModule()
trt_mod.load_state_dict(model_trt.state_dict())
torch.save(trt_mod.state_dict(), "migan_trt.pth")
print("✅ TensorRT 변환 완료 — migan_trt.pth 저장됨")