inpaint377/test1.py

80 lines
2.6 KiB
Python

import cv2, os
import numpy as np
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
import paddle.fluid as fluid
from paddle.fluid.core import AnalysisConfig
# Build a Paddle inference predictor directly from model files on disk
# (specifying the model path explicitly instead of going through PaddleHub).
def create_predictor(model_dir, use_gpu=True, gpu_mem_mb=100, gpu_id=0,
                     model_filename='model.pdmodel',
                     params_filename='model.pdiparams'):
    """Create a low-level Paddle predictor for the model stored in *model_dir*.

    Args:
        model_dir: Directory holding the serialized program and parameters.
        use_gpu: When True (the original hard-coded behaviour) the predictor
            runs on GPU via ``enable_use_gpu``; when False that call is
            skipped.
        gpu_mem_mb: Initial GPU memory pool size in MB passed to
            ``enable_use_gpu``.
        gpu_id: GPU device index passed to ``enable_use_gpu``.
        model_filename: Program file name inside *model_dir*.
            NOTE(review): verify it matches the files actually on disk.
        params_filename: Parameter file name inside *model_dir*.

    Returns:
        The object returned by ``fluid.core.create_paddle_predictor``. This is
        a raw inference predictor; it does not provide PaddleHub helpers such
        as ``recognize_text``.
    """
    config = AnalysisConfig(
        os.path.join(model_dir, model_filename),
        os.path.join(model_dir, params_filename),
    )
    if use_gpu:
        config.enable_use_gpu(gpu_mem_mb, gpu_id)
    # With feed/fetch ops switched off, Paddle expects the predictor to be
    # driven through its zero-copy tensor API.
    config.switch_use_feed_fetch_ops(False)
    return fluid.core.create_paddle_predictor(config)
# Load the OCR module.
def load_ocr_module(module_dir=r'C:\modules\chinese_ocr_db_crnn_mobile'):
    """Load the locally installed ``chinese_ocr_db_crnn_mobile`` PaddleHub module.

    The previous implementation built a raw predictor from the
    ``inference_model\\character_rec`` sub-directory. That object has no
    ``recognize_text`` method, which is what ``detect_text_and_create_mask``
    calls. The directory also appears to hold only the recognition
    sub-model. Loading the whole module with ``hub.Module(directory=...)``
    keeps the "load from a local path" intent and returns an object that
    exposes ``recognize_text``.

    Args:
        module_dir: Root directory of the locally installed module.

    Returns:
        The loaded module, or ``None`` if loading failed (the error is printed).
    """
    try:
        ocr_module = hub.Module(directory=module_dir)
    except Exception as e:  # best-effort loader: the caller checks for None
        print(f"Failed to load OCR module: {e}")
        return None
    print("OCR module loaded successfully.")
    return ocr_module
# Load the inpainting module.
def load_inpaint_module(name='deepfill_v2'):
    """Load an inpainting module from PaddleHub by name.

    Args:
        name: PaddleHub module name.
            NOTE(review): confirm a module called ``deepfill_v2`` is
            available from the configured PaddleHub source and that it
            exposes the ``Inpainting`` method used by ``inpaint_image``.

    Returns:
        The loaded module, or ``None`` if loading failed (the error is printed).
    """
    try:
        inpaint_model = hub.Module(name=name)
    except Exception as e:  # best-effort loader: the caller checks for None
        print(f"Failed to load Inpaint module: {e}")
        return None
    print("Inpaint module loaded successfully.")
    return inpaint_model
# Detect text in the image and build a mask covering it.
def detect_text_and_create_mask(image, predictor, use_gpu=False,
                                output_dir='ocr_result', visualization=True):
    """Run OCR on *image* and return a mask of every detected text box.

    Args:
        image: Image as returned by ``cv2.imread`` (BGR channel order).
        predictor: Object exposing PaddleHub's ``recognize_text`` API, e.g.
            the ``chinese_ocr_db_crnn_mobile`` module.
        use_gpu: Forwarded to ``recognize_text`` (original value: False).
        output_dir: Forwarded to ``recognize_text`` (original: 'ocr_result').
        visualization: Forwarded to ``recognize_text`` (original: True); when
            enabled the module writes annotated images into *output_dir*.

    Returns:
        A ``uint8`` array with the height and width of *image*. It is 255
        inside each detected text polygon and 0 elsewhere.
    """
    # recognize_text documents its ``images`` input as BGR ndarrays, which is
    # exactly what cv2.imread produces; the former BGR->RGB conversion fed the
    # model swapped channels.
    result = predictor.recognize_text(
        images=[image],
        use_gpu=use_gpu,
        output_dir=output_dir,
        visualization=visualization,
    )
    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    if not result:
        return mask
    for line in result[0]['data']:
        points = np.array(line['text_box_position'], dtype=np.int32)
        # One fillPoly call per box: passing all boxes in a single call could
        # leave holes where boxes overlap.
        cv2.fillPoly(mask, [points], 255)
    return mask
# Run image inpainting.
def inpaint_image(image, mask, inpaint_module, output_dir='inpainting_output',
                  use_gpu=False):
    """Fill the masked region of *image* using *inpaint_module*.

    Args:
        image: Source image as loaded by ``cv2.imread``.
        mask: Single-channel mask; in this script it is 255 over text and 0
            elsewhere. Presumably non-zero marks pixels to fill — confirm
            against the module's documentation.
        inpaint_module: Loaded inpainting module exposing ``Inpainting``.
        output_dir: Forwarded to ``Inpainting`` (original: 'inpainting_output').
        use_gpu: Forwarded to ``Inpainting`` (original: False).

    Returns:
        The inpainted image, read back from the path the module reports.

    Raises:
        FileNotFoundError: if the reported result file cannot be read.
    """
    input_dict = {'image': [image], 'mask': [mask]}
    # NOTE(review): the ``Inpainting`` method name and the results[0]['data']
    # layout (a file path) are assumptions about the module's API.
    results = inpaint_module.Inpainting(data=input_dict, output_dir=output_dir,
                                        use_gpu=use_gpu)
    result_path = results[0]['data']
    inpainted = cv2.imread(result_path)
    # cv2.imread returns None instead of raising on failure; surface it here
    # rather than letting cv2.imshow fail later with an opaque error.
    if inpainted is None:
        raise FileNotFoundError(
            f"Could not read inpainted image from {result_path!r}")
    return inpainted
# Run the whole text-removal pipeline on one image.
def process_image(image_path):
    """Remove text from the image at *image_path* and display the results.

    The function works in four steps:
    1. Read the image and load the OCR and inpainting modules.
    2. Build a text mask.
    3. Inpaint the masked area.
    4. Show the original, the mask and the result until a key is pressed.

    Failures to read the image or load the modules are reported with
    ``print``, and the function returns early.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None (no exception) for missing/unreadable files.
    if image is None:
        print(f"Failed to read image: {image_path}")
        return

    ocr_module = load_ocr_module()
    inpaint_module = load_inpaint_module()
    # Both loaders return None on failure; test that explicitly instead of
    # relying on the truthiness of the module objects.
    if ocr_module is None or inpaint_module is None:
        print("Failed to load modules.")
        return

    mask = detect_text_and_create_mask(image, ocr_module)
    inpainted_image = inpaint_image(image, mask, inpaint_module)

    cv2.imshow('Original Image', image)
    cv2.imshow('Mask', mask)
    cv2.imshow('Inpainted Image', inpainted_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
# Path of the image to process.
image_path = 'src/img/1.jpg'

# Only run the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
    process_image(image_path)