AutoPercenty3/tools/convert_paddle_to_onnx.py

169 lines
6.2 KiB
Python

#!/usr/bin/env python3
"""
PaddleOCR 모델을 ONNX 형식으로 변환하는 도구
ARM 환경에서 ONNX Runtime으로 실행하기 위해 사용
"""
import os
import sys
import argparse
import logging
from pathlib import Path
def setup_logging():
"""로깅 설정"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
return logging.getLogger(__name__)
def convert_paddle_model_to_onnx(model_dir, output_dir, model_type='det'):
"""
PaddleOCR 모델을 ONNX로 변환
Args:
model_dir (str): PaddleOCR 모델 디렉토리 (.pdmodel, .pdiparams 파일 위치)
output_dir (str): 변환된 ONNX 모델 저장 위치
model_type (str): 모델 타입 ('det', 'rec', 'cls')
"""
logger = setup_logging()
try:
import paddle2onnx
logger.info(f"paddle2onnx 버전: {paddle2onnx.__version__}")
except ImportError:
logger.error("paddle2onnx가 설치되지 않았습니다: pip install paddle2onnx")
return False
# 입력 파일 확인
model_file = os.path.join(model_dir, "inference.pdmodel")
params_file = os.path.join(model_dir, "inference.pdiparams")
if not os.path.exists(model_file):
logger.error(f"모델 파일을 찾을 수 없습니다: {model_file}")
return False
if not os.path.exists(params_file):
logger.error(f"파라미터 파일을 찾을 수 없습니다: {params_file}")
return False
# 출력 디렉토리 생성
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "model.onnx")
try:
logger.info(f"변환 시작: {model_type} 모델")
logger.info(f"입력: {model_file}")
logger.info(f"출력: {output_file}")
# 모델 타입별 변환 설정
if model_type == 'det':
# Detection 모델 설정
onnx_model = paddle2onnx.command.c_paddle_to_onnx(
model_file=model_file,
params_file=params_file,
opset_version=11,
enable_onnx_checker=True,
auto_upgrade_opset=True,
save_file=output_file,
input_shape_dict={"x": [1, 3, 960, 960]} # Detection 입력 크기
)
elif model_type == 'rec':
# Recognition 모델 설정
onnx_model = paddle2onnx.command.c_paddle_to_onnx(
model_file=model_file,
params_file=params_file,
opset_version=11,
enable_onnx_checker=True,
auto_upgrade_opset=True,
save_file=output_file,
input_shape_dict={"x": [1, 3, 32, 100]} # Recognition 입력 크기
)
elif model_type == 'cls':
# Classification 모델 설정
onnx_model = paddle2onnx.command.c_paddle_to_onnx(
model_file=model_file,
params_file=params_file,
opset_version=11,
enable_onnx_checker=True,
auto_upgrade_opset=True,
save_file=output_file,
input_shape_dict={"x": [1, 3, 48, 192]} # Classification 입력 크기
)
else:
logger.error(f"지원하지 않는 모델 타입: {model_type}")
return False
if os.path.exists(output_file):
file_size = os.path.getsize(output_file) / (1024 * 1024) # MB
logger.info(f"✅ 변환 성공: {output_file} ({file_size:.2f} MB)")
return True
else:
logger.error("❌ 변환 실패: 출력 파일이 생성되지 않았습니다")
return False
except Exception as e:
logger.error(f"❌ 변환 중 오류 발생: {e}")
return False
def convert_all_models(base_model_dir, output_base_dir):
"""모든 PaddleOCR 모델을 ONNX로 변환"""
logger = setup_logging()
models = [
('det', 'PP_Models/det'),
('rec', 'PP_Models/rec'),
('cls', 'PP_Models/cls')
]
success_count = 0
total_count = len(models)
for model_type, model_subdir in models:
model_dir = os.path.join(base_model_dir, model_subdir)
output_dir = os.path.join(output_base_dir, model_subdir)
if os.path.exists(model_dir):
logger.info(f"\n🔄 {model_type.upper()} 모델 변환 중...")
if convert_paddle_model_to_onnx(model_dir, output_dir, model_type):
success_count += 1
else:
logger.warning(f"⚠️ {model_type} 모델 변환 실패")
else:
logger.warning(f"⚠️ {model_type} 모델 디렉토리를 찾을 수 없습니다: {model_dir}")
logger.info(f"\n📊 변환 완료: {success_count}/{total_count} 성공")
return success_count == total_count
def main():
parser = argparse.ArgumentParser(description="PaddleOCR 모델을 ONNX로 변환")
parser.add_argument("--model_dir", required=True, help="PaddleOCR 모델 디렉토리")
parser.add_argument("--output_dir", required=True, help="ONNX 모델 출력 디렉토리")
parser.add_argument("--model_type", choices=['det', 'rec', 'cls', 'all'],
default='all', help="변환할 모델 타입")
args = parser.parse_args()
logger = setup_logging()
logger.info("🚀 PaddleOCR → ONNX 변환 시작")
if args.model_type == 'all':
success = convert_all_models(args.model_dir, args.output_dir)
else:
success = convert_paddle_model_to_onnx(args.model_dir, args.output_dir, args.model_type)
if success:
logger.info("✅ 모든 변환 작업 완료!")
logger.info(f"📁 변환된 모델 위치: {args.output_dir}")
logger.info("\n🔧 사용법:")
logger.info("1. ARM 환경에서 requirements_arm.txt 설치")
logger.info("2. 변환된 ONNX 모델을 src/modules/PP_Models/ 디렉토리에 복사")
logger.info("3. 프로그램 실행시 자동으로 ONNX Runtime 백엔드 사용")
else:
logger.error("❌ 변환 작업 실패")
sys.exit(1)
if __name__ == "__main__":
main()