diff --git a/re/export.py b/re/export.py new file mode 100644 index 0000000..3b0fdc3 --- /dev/null +++ b/re/export.py @@ -0,0 +1,15 @@ +# RETHINED 모델 export.py +import torch +from rethined import RethinedModel # 논문 제공 코드 임포트 + +model = RethinedModel().eval().to('cuda') +H, W = 512, 512 # 또는 원하는 해상도 +dummy = torch.randn(1, 3, H, W, device='cuda') + +torch.onnx.export( + model, dummy, "rethined.onnx", + opset_version=14, + input_names=['input'], + output_names=['output'], + dynamic_axes={'input':{2:'height',3:'width'}, 'output':{2:'height',3:'width'}} +)