Refactor SimpleLamaInpainter: streamline inpainting process by removing unnecessary device checks and enhance postprocessing with NaN/Inf safety handling.
This commit is contained in:
parent
907d28c8bf
commit
249422c227
|
|
@ -156,11 +156,7 @@ class SimpleLamaInpainter:
|
|||
# 성능 최적화: AMP + cuDNN benchmark
|
||||
torch.backends.cudnn.benchmark = True
|
||||
with torch.no_grad():
|
||||
if self._device.type == 'cuda':
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
inpainted_batch = self._model.model(image_batch, mask_batch)
|
||||
else:
|
||||
inpainted_batch = self._model.model(image_batch, mask_batch)
|
||||
inpainted_batch = self._model.model(image_batch, mask_batch)
|
||||
|
||||
# 후처리
|
||||
result_images = []
|
||||
|
|
@ -190,8 +186,10 @@ class SimpleLamaInpainter:
|
|||
def _postprocess(self, tensor: torch.Tensor, original_size: Tuple[int, int], original_image: Image.Image, original_mask: Image.Image) -> Image.Image:
|
||||
"""모델 출력 텐서를 PIL 이미지로 후처리하고 원본에 합성합니다."""
|
||||
# 텐서를 PIL 이미지로 변환
|
||||
result_np = tensor.permute(1, 2, 0).cpu().numpy()
|
||||
result_np = np.clip(result_np * 255, 0, 255).astype(np.uint8)
|
||||
result_np = tensor.permute(1, 2, 0).detach().float().cpu().numpy()
|
||||
# NaN/Inf 안전 처리 후 범위 클램프
|
||||
result_np = np.nan_to_num(result_np, nan=0.0, posinf=1.0, neginf=0.0)
|
||||
result_np = (np.clip(result_np, 0.0, 1.0) * 255.0).astype(np.uint8)
|
||||
inpainted_image_512 = Image.fromarray(result_np)
|
||||
|
||||
# 원본 크기로 리사이즈
|
||||
|
|
|
|||
Loading…
Reference in New Issue