# Import the required libraries.
|
|
import torch
|
|
import torchvision
|
|
import torchvision.models as models
|
|
import torchvision.transforms as transforms
|
|
import torchvision.models.detection as detection
|
|
from PIL import Image, ImageDraw
|
|
import requests
|
|
from io import BytesIO
|
|
import numpy as np
|
|
|
|
|
|
'''
|
|
import torch
|
|
import torchvision.models as models
|
|
|
|
# GPU 사용 가능 여부 확인
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
print(f"Using device: {device}")
|
|
|
|
# 모델을 device로 이동
|
|
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
|
|
model.eval()
|
|
|
|
# 데이터도 GPU로 이동
|
|
# 예를 들어, 이미지를 모델에 입력하기 전에
|
|
image = image.to(device)
|
|
|
|
'''
|
|
|
|
|
|
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
# 객체 탐지 모델 로드 (Mask R-CNN)
|
|
# detection_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
|
|
# fastrcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
|
|
# detection_model = torchvision.models.detection.retinanet_resnet50_fpn(weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT).to(device)
|
|
|
|
# detection_model.eval()
|
|
|
|
# Load the pretrained model used for feature extraction.
#
# Commented-out alternatives kept for reference; to switch, replace the
# constructor/weights pair below with one of:
#   models.resnet18        / models.ResNet18_Weights.DEFAULT
#   models.resnet50        / models.ResNet50_Weights.DEFAULT
#   models.resnet101       / models.ResNet101_Weights.DEFAULT
#   models.inception_v3    / models.Inception_V3_Weights.DEFAULT
#   models.efficientnet_b0 / models.EfficientNet_B0_Weights.DEFAULT
#   models.vit_b_16        / models.ViT_B_16_Weights.DEFAULT
feature_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
feature_model = feature_model.to(device)

# Put the model into evaluation mode before it is used for inference.
feature_model.eval()
|
|
|
|
# print("모델을 성공적으로 불러왔습니다.")
|
|
|
|
# model.eval() # 모델을 평가 모드로 설정합니다.
|
|
|
|
# InceptionV3 전처리 예시
|
|
# preprocess = transforms.Compose([
|
|
# transforms.Resize(299),
|
|
# transforms.CenterCrop(299),
|
|
# transforms.ToTensor(),
|
|
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
# ])
|
|
|
|
# Image preprocessing pipeline applied to every downloaded image before it is
# handed to the feature model. The steps run in the order listed.
# NOTE(review): the mean/std values are the ones commonly used with
# ImageNet-pretrained torchvision models; confirm they match the weights of
# whichever feature_model is active.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
# print("이미지 전처리 설정을 완료했습니다.")
|
|
|
|
# 이미지 URL로부터 이미지를 다운로드하고 전처리하는 함수를 정의합니다.
|
|
# 이미지 다운로드 및 전처리
|
|
def download_and_preprocess_image(url, timeout=10):
    """Download the image at *url* and return it as a preprocessed tensor.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.get``. Without a timeout a stalled server would
            block the caller indefinitely.

    Returns:
        The result of the module-level ``preprocess`` pipeline applied to
        the RGB-converted image, moved to ``device``.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here on 4xx/5xx instead of letting PIL choke on an error page.
    response.raise_for_status()

    img = Image.open(BytesIO(response.content)).convert('RGB')
    return preprocess(img).to(device)
|
|
|
|
# 객체 탐지 및 배경 제거
|
|
# def remove_background(image_tensor):
|
|
# with torch.no_grad():
|
|
# prediction = detection_model([image_tensor])[0]
|
|
|
|
# # 가장 확률이 높은 객체의 마스크 추출
|
|
# mask = prediction['masks'][0, 0] > 0.5
|
|
# image_tensor_masked = image_tensor * mask.float()
|
|
# return image_tensor_masked
|
|
|
|
# 이미지 특징을 추출하는 함수를 정의합니다.
|
|
def extract_features(image_tensor):
    """Run ``feature_model`` on one image tensor and return its output.

    The tensor gets a leading batch dimension via ``unsqueeze(0)`` and the
    forward pass runs with gradient tracking disabled.

    Note: the model's output is used as-is for simplicity; it may need to be
    adjusted depending on which feature_model is selected.
    """
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0)
        return feature_model(batch)
|
|
|
|
# 코사인 유사도 계산
|
|
def cosine_similarity(feature1, feature2):
    """Return the cosine similarity of two feature tensors along dimension 1.

    Uses ``eps=1e-6`` to guard against division by zero, so the result has
    one similarity value per row of the inputs.
    """
    return torch.nn.functional.cosine_similarity(
        feature1, feature2, dim=1, eps=1e-6
    )
|
|
|
|
# 이미지 비교 및 가장 비슷한 이미지 찾기
|
|
def find_most_similar_image(source_url, target_urls):
    """Find which target image is most similar to the source image.

    Each image is downloaded, preprocessed and passed through
    ``extract_features``; the targets are then ranked by cosine similarity
    to the source. The score of every target is printed as it is computed.

    Args:
        source_url: URL of the reference image.
        target_urls: Iterable of URLs of the candidate images.

    Returns:
        A ``(position, score)`` tuple where ``position`` is the 1-based
        position of the best match within ``target_urls`` (a plain ``int``)
        and ``score`` is its cosine similarity as a ``float``.

    Raises:
        ValueError: If ``target_urls`` is empty.
    """
    # Materialise once so one-shot iterables work, and reject empty input
    # before any download or model work is done.
    target_urls = list(target_urls)
    if not target_urls:
        raise ValueError("target_urls must contain at least one image URL")

    source_image = download_and_preprocess_image(source_url)
    source_features = extract_features(source_image)
    # Optional background removal (disabled):
    # source_image_masked = remove_background(source_image)
    # source_features = extract_features(source_image_masked)

    similarities = []
    for url in target_urls:
        target_image = download_and_preprocess_image(url)
        # target_image_masked = remove_background(target_image)
        # target_features = extract_features(target_image_masked)
        target_features = extract_features(target_image)

        score = cosine_similarity(source_features, target_features).item()
        similarities.append(score)
        print(f" 이미지 유사도 점수: {score}")

    # int() turns numpy's integer into a plain Python int for callers.
    most_similar_index = int(np.argmax(similarities))
    most_similar_score = similarities[most_similar_index]

    return most_similar_index + 1, most_similar_score
|
|
|
|
# 이미지 비교
|
|
def find_most_similar_image_by_one(source_url, target_urls):
    """Compare the source image against a single target image.

    Args:
        source_url: URL of the reference image.
        target_urls: The target image URL, given either as a string or as
            an iterable containing exactly one URL.

    Returns:
        The tensor returned by ``cosine_similarity`` for the two images'
        features.

    Raises:
        ValueError: If ``target_urls`` is an iterable that does not contain
            exactly one URL.
    """
    # The body previously read an undefined name ``url``; the target is now
    # taken from the ``target_urls`` argument.
    # NOTE(review): requiring exactly one URL is an interpretation of the
    # "by_one" name - confirm against callers.
    if isinstance(target_urls, str):
        target_url = target_urls
    else:
        candidates = list(target_urls)
        if len(candidates) != 1:
            raise ValueError(
                "target_urls must be a single URL or contain exactly one URL, "
                f"got {len(candidates)}"
            )
        target_url = candidates[0]

    source_image = download_and_preprocess_image(source_url)
    source_features = extract_features(source_image)

    target_image = download_and_preprocess_image(target_url)
    target_features = extract_features(target_image)

    return cosine_similarity(source_features, target_features)
|
|
|
|
# # 주어진 이미지 URL 예시와 비교를 시작합니다.
|
|
# try:
|
|
# source_image_url = "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/344e23cb-669c-45cd-8df1-42b85fbf0c93.jpg"
|
|
# target_image_urls = [
|
|
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/28f0dd02-21c2-4314-b863-7a6076c92612.jpg",
|
|
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/99033f24-61d4-4a0e-9ad1-5dea457c33c4.jpg",
|
|
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/ecb80f2f-bc95-4e64-801a-ce8f4cdf13a4.jpg",
|
|
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/a3b9a3c1-ac04-4e55-be6d-f56eaa4187bb.jpg",
|
|
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/7f42c677-0bc7-4e71-9a99-6bd14163d597.jpg",
|
|
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/a39c5c73-b568-42fe-a9de-f364e2d31139.jpg"
|
|
# ]
|
|
|
|
# most_similar_index, similarity_score = find_most_similar_image(source_image_url, target_image_urls)
|
|
# print(f"\n가장 비슷한 이미지: 이미지 {most_similar_index}\n유사도 점수: {similarity_score}")
|
|
# except Exception as e:
|
|
# print(f"에러 발생: {e}")
|