32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
# iop_tensorrt.py
|
||
import torch
|
||
from torch2trt import torch2trt, TRTModule
|
||
|
||
# --- 1. 모델 정의 및 로드 (nn.Module 형태) ---
|
||
# MI‑GAN 예시: 원본 모델 클래스 import 필요
|
||
from model_zoo.migan_inference import MIGAN # 실제 경로로 수정하세요
|
||
|
||
model = MIGAN().eval().cuda()
|
||
# 필요 시 weight 로드
|
||
model.load_state_dict(torch.load("migan_traced_weights.pth")) # weights 파일 위치
|
||
|
||
# --- 2. 더미 입력 (4채널: RGB + 마스크) ---
|
||
H, W = 512, 512 # 이미지 크기에 맞게 조정
|
||
dummy = torch.randn((1, 4, H, W)).cuda().half() # FP16 입력 가능
|
||
|
||
# --- 3. TensorRT 변환 ---
|
||
model_trt = torch2trt(
|
||
model, [dummy],
|
||
fp16_mode=True,
|
||
max_batch_size=1,
|
||
default_device_type=None, # DLA 사용 시 trt.DeviceType.DLA 지정 가능
|
||
gpu_fallback=True
|
||
)
|
||
|
||
# --- 4. TRTModule wrapping 및 저장 ---
|
||
trt_mod = TRTModule()
|
||
trt_mod.load_state_dict(model_trt.state_dict())
|
||
torch.save(trt_mod.state_dict(), "migan_trt.pth")
|
||
|
||
print("✅ TensorRT 변환 완료 — migan_trt.pth 저장됨")
|