This commit is contained in:
K.H.CHOI 2024-03-27 15:47:23 +09:00
commit 4025daeb2b
3 changed files with 23 additions and 3 deletions

View File

@ -8,7 +8,7 @@ from PIL import Image, ImageDraw
import requests
from io import BytesIO
import numpy as np
import os
'''
import torch
@ -28,16 +28,36 @@ image = image.to(device)
'''
# GPU 사용 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 모델 저장 경로 지정
model_path = os.getcwd()
# local 폴더가 없다면 생성
if not os.path.exists('./models'):
os.makedirs('./models')
# 모델 불러오기 또는 다운로드 후 저장
if os.path.isfile(model_path):
# 모델 파일이 이미 존재하면 불러오기
feature_model = torch.load(model_path)
feature_model = feature_model.to(device)
else:
# 모델 파일이 존재하지 않으면 다운로드 후 저장
feature_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT).to(device)
torch.save(feature_model, model_path)
feature_model.eval()
# 객체 탐지 모델 로드 (Mask R-CNN)
# detection_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
# fastrcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
# detection_model = torchvision.models.detection.retinanet_resnet50_fpn(weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT).to(device)
# detection_model.eval()
# detection_model.eval()``
# 사전 학습된 모델을 불러옵니다.
# feature_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device)

BIN
r.txt

Binary file not shown.

Binary file not shown.