# AutoPercenty/ai/compare.py

# 151 lines
# 6.7 KiB
# Python

# 필요한 라이브러리를 임포트합니다.
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.models.detection as detection
from PIL import Image, ImageDraw
import requests
from io import BytesIO
import numpy as np
# Disabled usage example, kept as a bare string literal: the string is
# evaluated and discarded, so nothing inside it runs. It sketches how to pick
# a device, move a Mask R-CNN detection model onto it, and move an input image
# to the same device.
'''
import torch
import torchvision.models as models
# GPU 사용 가능 여부 확인
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 모델을 device로 이동
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
model.eval()
# 데이터도 GPU로 이동
# 예를 들어, 이미지를 모델에 입력하기 전에
image = image.to(device)
'''
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Object-detection models (Mask R-CNN and alternatives); currently disabled.
# detection_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
# fastrcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
# detection_model = torchvision.models.detection.retinanet_resnet50_fpn(weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT).to(device)
# detection_model.eval()

# Pretrained model used to embed images. VGG16 is the active choice; the
# commented-out lines are alternatives that can be swapped in.
# feature_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device)
# feature_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(device)
# feature_model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(device)
# feature_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT).to(device)
# feature_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT).to(device)
# feature_model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT).to(device)
feature_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
feature_model = feature_model.to(device)
# Put the model into evaluation mode.
feature_model.eval()
# Preprocessing example for InceptionV3 (disabled):
# preprocess = transforms.Compose([
# transforms.Resize(299),
# transforms.CenterCrop(299),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# Active preprocessing pipeline: resize to 256, center-crop to 224, convert to
# a tensor, then normalize each channel with the mean/std given below.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# print("이미지 전처리 설정을 완료했습니다.")
def download_and_preprocess_image(url, timeout=10):
    """Download the image at ``url`` and return it as a preprocessed tensor.

    The response body is decoded with PIL, converted to RGB, passed through
    the module-level ``preprocess`` pipeline and moved to ``device``.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.get``. Without a timeout a stalled server would block
            the caller indefinitely.

    Returns:
        The preprocessed image tensor, located on ``device``.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    # Surface 4xx/5xx responses here instead of handing an error page to PIL,
    # which would only report that it cannot identify the image.
    response.raise_for_status()

    img = Image.open(BytesIO(response.content)).convert('RGB')
    return preprocess(img).to(device)
# 객체 탐지 및 배경 제거
# def remove_background(image_tensor):
# with torch.no_grad():
# prediction = detection_model([image_tensor])[0]
# # 가장 확률이 높은 객체의 마스크 추출
# mask = prediction['masks'][0, 0] > 0.5
# image_tensor_masked = image_tensor * mask.float()
# return image_tensor_masked
def extract_features(image_tensor):
    """Run one image through ``feature_model`` and return the model output.

    A leading batch dimension is added with ``unsqueeze(0)`` before the
    forward pass, and the whole computation runs with gradient tracking
    disabled.
    """
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0)
        # The raw model output is used directly as the feature vector for
        # simplicity; it may need to be adjusted for other models.
        return feature_model(batch)
def cosine_similarity(feature1, feature2):
    """Return the cosine similarity of two feature tensors along dimension 1.

    Uses ``eps=1e-6`` to guard against division by zero.
    """
    return torch.nn.functional.cosine_similarity(
        feature1, feature2, dim=1, eps=1e-6
    )
def find_most_similar_image(source_url, target_urls):
    """Find which target image is most similar to the source image.

    Each image is downloaded, preprocessed and embedded with
    ``extract_features``; targets are scored against the source with
    ``cosine_similarity``. Every score is printed as it is computed.

    Args:
        source_url: URL of the reference image.
        target_urls: Iterable of URLs of the candidate images.

    Returns:
        A ``(position, score)`` tuple. ``position`` is the 1-based position
        of the best-scoring URL in ``target_urls`` (a NumPy integer, as
        produced by ``np.argmax``) and ``score`` is its similarity as a
        float. On ties ``np.argmax`` picks the earliest candidate.

    Raises:
        ValueError: If ``target_urls`` is empty.
    """
    # Materialize first so generators can be checked for emptiness, and fail
    # before any download instead of inside np.argmax afterwards.
    target_urls = list(target_urls)
    if not target_urls:
        raise ValueError("target_urls must contain at least one image URL")

    source_features = extract_features(download_and_preprocess_image(source_url))

    similarities = []
    for url in target_urls:
        target_features = extract_features(download_and_preprocess_image(url))
        # .item() copies the value off the device; do it once per image.
        score = cosine_similarity(source_features, target_features).item()
        similarities.append(score)
        print(f" 이미지 유사도 점수: {score}")

    most_similar_index = np.argmax(similarities)
    # Callers receive a 1-based position, hence the + 1.
    return most_similar_index + 1, similarities[most_similar_index]
def find_most_similar_image_by_one(source_url, target_urls):
    """Return the cosine similarity between the source and one target image.

    Args:
        source_url: URL of the reference image.
        target_urls: URL of the single image to compare against. The
            parameter name is plural, but this function compares exactly one
            pair, so the value is used as one URL.

    Returns:
        The tensor produced by ``cosine_similarity`` for the two feature
        outputs.
    """
    source_features = extract_features(download_and_preprocess_image(source_url))
    # The previous body read an undefined name ``url`` here, so every call
    # raised NameError; the target URL comes from the ``target_urls`` argument.
    target_features = extract_features(download_and_preprocess_image(target_urls))
    return cosine_similarity(source_features, target_features)
# # 주어진 이미지 URL 예시와 비교를 시작합니다.
# try:
# source_image_url = "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/344e23cb-669c-45cd-8df1-42b85fbf0c93.jpg"
# target_image_urls = [
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/28f0dd02-21c2-4314-b863-7a6076c92612.jpg",
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/99033f24-61d4-4a0e-9ad1-5dea457c33c4.jpg",
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/ecb80f2f-bc95-4e64-801a-ce8f4cdf13a4.jpg",
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/a3b9a3c1-ac04-4e55-be6d-f56eaa4187bb.jpg",
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/7f42c677-0bc7-4e71-9a99-6bd14163d597.jpg",
# "https://file.percenty.co.kr/public/652bed8e865b1f32ea62bf1f/products/65a3cb4a8e425f4b290089a1/a39c5c73-b568-42fe-a9de-f364e2d31139.jpg"
# ]
# most_similar_index, similarity_score = find_most_similar_image(source_image_url, target_image_urls)
# print(f"\n가장 비슷한 이미지: 이미지 {most_similar_index}\n유사도 점수: {similarity_score}")
# except Exception as e:
# print(f"에러 발생: {e}")