AutoPercenty3/test/paddle2onnx/make_ppocr_onnx_dynamic.py

88 lines
3.1 KiB
Python

# make_ppocr_onnx_dynamic.py
import argparse
import onnx
from onnx import helper, shape_inference
def _set_dim(dim, value=None, param=None):
# value(정수) 또는 param(문자열 심볼) 한쪽만 설정
dim.ClearField('dim_value')
dim.ClearField('dim_param')
if value is not None:
dim.dim_value = int(value)
elif param is not None:
dim.dim_param = str(param)
def _shape_str(vi):
tp = vi.type.tensor_type
if not tp.HasField("shape"):
return "scalar"
dims = []
for d in tp.shape.dim:
if d.HasField("dim_value"):
dims.append(str(d.dim_value))
elif d.HasField("dim_param"):
dims.append(d.dim_param)
else:
dims.append("?")
return "[" + ",".join(dims) + "]"
def patch_ppocr_input(model_path, out_path, model_type):
"""
model_type:
- det / cls : NCHW에서 N,H,W를 동적으로
- rec : NCHW에서 N, W 동적 + H는 고정(원래 값이 있으면 유지, 없으면 48로)
"""
m = onnx.load(model_path)
if len(m.graph.input) == 0:
raise RuntimeError("No graph input found in ONNX model.")
# 보통 입력은 하나(x). 복수여도 첫 번째만 패치(일반적인 PPOCR는 1개)
vi = m.graph.input[0]
tp = vi.type.tensor_type
if not tp.HasField("shape") or len(tp.shape.dim) != 4:
raise RuntimeError(f"Expect 4D NCHW input, got: {_shape_str(vi)}")
n, c, h, w = tp.shape.dim # N,C,H,W
# C(채널)는 3 고정이 일반적이라 놔두고, 나머지만 처리
# det/cls: [N=?, C=3, H=?, W=?]
if model_type in ("det", "cls"):
_set_dim(n, param="N")
# H가 값으로 박혀 있든 없든 동적 처리
_set_dim(h, param="H")
_set_dim(w, param="W")
# rec: [N=?, C=3, H=고정, W=?] (중국어 계열은 보통 H=48, 영문 모바일계열은 32)
elif model_type == "rec":
_set_dim(n, param="N")
# H 유지(이미 값이 있으면 유지), 없으면 48로 설정
if not h.HasField("dim_value") and not h.HasField("dim_param"):
_set_dim(h, value=48) # 기본값 48
# W는 동적
_set_dim(w, param="W")
else:
raise ValueError("model_type must be one of: det / rec / cls")
# 모델 체크 + (가능하면) Shape Inference 수행
onnx.checker.check_model(m)
try:
m = shape_inference.infer_shapes(m)
except Exception as e:
# 동적 치수가 많으면 실패해도 무방. 패치는 이미 적용됨.
print(f"[warn] onnx shape_inference skipped: {e}")
onnx.save(m, out_path)
print(f"[ok] saved: {out_path}")
print(f" input 0 after patch: {_shape_str(m.graph.input[0])}")
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--in", dest="inp", required=True, help="input onnx path")
ap.add_argument("--out", dest="out", required=True, help="output onnx path")
ap.add_argument("--type", dest="typ", required=True, choices=["det","rec","cls"], help="ppocr model type")
args = ap.parse_args()
patch_ppocr_input(args.inp, args.out, args.typ)
if __name__ == "__main__":
main()