AutoPercenty3/test/paddle2onnx/opt_fp16_slim_ppocr.py

168 lines
5.4 KiB
Python

import argparse, os, sys
import onnx
from onnx import shape_inference, checker
import onnxoptimizer
# onnx-simplifier
try:
from onnxsim import simplify as onnx_simplify
except Exception:
onnx_simplify = None
# FP16 변환
try:
from onnxconverter_common import float16
except Exception as e:
print("[WARN] onnxconverter-common 미설치 혹은 구버전: FP16 변환은 생략됩니다.", e)
float16 = None
# (선택) ORT 동적 양자화
try:
from onnxruntime.quantization import quantize_dynamic, QuantType
except Exception:
quantize_dynamic, QuantType = None, None
def load_model(path):
m = onnx.load(path)
checker.check_model(m)
return m
def save_model(m, path):
checker.check_model(m)
onnx.save(m, path)
print(f"[OK] saved -> {path}")
def run_shape_infer(m):
try:
return shape_inference.infer_shapes(m)
except Exception as e:
print("[WARN] onnx.shape_inference 실패 — 동적 shape가 많아도 추론엔 문제없음:", e)
return m
def run_simplify(m, dynamic_input_shape=True):
if onnx_simplify is None:
print("[SKIP] onnx-simplifier(onnxsim) 미사용")
return m
try:
sm, ok = onnx_simplify(m)
if not ok:
print("[WARN] onnxsim 검증 실패 — 원본 그래프 유지")
return m
return sm
except Exception as e:
print("[WARN] onnxsim 에러 — 원본 그래프 유지:", e)
return m
def run_optimize(m):
# 보수적인 기본 패스 세트
passes = [
"eliminate_deadend",
"eliminate_identity",
"eliminate_nop_transpose",
"eliminate_nop_pad",
"eliminate_unused_initializer",
"fuse_consecutive_transposes",
"fuse_transpose_into_gemm",
"fuse_add_bias_into_conv",
"fuse_bn_into_conv",
"fuse_pad_into_conv",
"fuse_matmul_add_bias_into_gemm",
]
avail = set(onnxoptimizer.get_available_passes())
use = [p for p in passes if p in avail]
try:
return onnxoptimizer.optimize(m, use)
except Exception as e:
print("[WARN] onnxoptimizer 실패 — 원본 그래프 유지:", e)
return m
def run_fp16(m, keep_io_fp32=True):
if float16 is None:
print("[SKIP] FP16 변환 모듈 없음")
return m
try:
# 일부 op은 FP16 미지원일 수 있어 자동으로 통과됨
return float16.convert_float_to_float16(
m,
keep_io_types=keep_io_fp32, # 입출력은 FP32 유지(파이프라인 호환)
# op_block_list=set([]), # 문제가 생기면 여기에 차단 op 추가
)
except Exception as e:
print("[WARN] FP16 변환 실패 — 원본 그래프 유지:", e)
return m
def run_dyn_quant(in_path, out_path):
if quantize_dynamic is None:
print("[SKIP] ORT 동적 양자화 모듈 없음")
return
try:
# MatMul/Gemm 위주 — Conv는 동적양자화 지원 제한
quantize_dynamic(
model_input=in_path,
model_output=out_path,
per_channel=False,
weight_type=QuantType.QInt8,
optimize_model=False,
op_types_to_quantize=["MatMul","Gemm"]
)
print(f"[OK] 동적 양자화 saved -> {out_path}")
except Exception as e:
print("[WARN] 동적 양자화 실패:", e)
def pipeline_one(in_path, out_dir, do_fp16=True, do_quant_dynamic=False, tag=""):
os.makedirs(out_dir, exist_ok=True)
base = os.path.splitext(os.path.basename(in_path))[0]
print(f"\n=== [{tag}] {base} ===")
m = load_model(in_path)
# 1) Shape inference
m = run_shape_infer(m)
save_model(m, os.path.join(out_dir, f"{base}.shaped.onnx"))
# 2) Simplify (동적 입력 허용)
m = run_simplify(m, dynamic_input_shape=True)
save_model(m, os.path.join(out_dir, f"{base}.simp.onnx"))
# 3) Graph optimize
m = run_optimize(m)
save_model(m, os.path.join(out_dir, f"{base}.opt.onnx"))
# 4) FP16 (선택) — CPU에선 속도이득은 제한적, 메모리/모델크기 이득
if do_fp16:
mf16 = run_fp16(m, keep_io_fp32=True)
save_model(mf16, os.path.join(out_dir, f"{base}.fp16.onnx"))
# (선택) FP16 + 동적양자화를 같이 쓰진 않는 게 일반적
# 5) (선택) ORT 동적 양자화 (MatMul/Gemm 위주)
if do_quant_dynamic:
run_dyn_quant(
os.path.join(out_dir, f"{base}.opt.onnx"),
os.path.join(out_dir, f"{base}.dq.onnx"),
)
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--det", help="det onnx path")
ap.add_argument("--rec", help="rec onnx path")
ap.add_argument("--cls", help="cls onnx path")
ap.add_argument("--outdir", required=True, help="output dir")
ap.add_argument("--fp16", action="store_true", help="export fp16 model")
ap.add_argument("--dq", action="store_true", help="onnxruntime dynamic quantization")
args = ap.parse_args()
if not any([args.det, args.rec, args.cls]):
print("하나 이상 지정: --det / --rec / --cls")
sys.exit(1)
if args.det:
pipeline_one(args.det, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="det")
if args.rec:
pipeline_one(args.rec, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="rec")
if args.cls:
pipeline_one(args.cls, args.outdir, do_fp16=args.fp16, do_quant_dynamic=args.dq, tag="cls")
if __name__ == "__main__":
main()