1
0
Fork 0
AutoPercenty2/edit/compare.py

70 lines
2.7 KiB
Python

# 필요한 라이브러리를 임포트합니다.
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
import numpy as np
# 사전 학습된 모델(예: ResNet)을 불러옵니다.
model = models.resnet18(pretrained=True)
model.eval() # 모델을 평가 모드로 설정합니다.
# 이미지 전처리를 위한 변환을 정의합니다.
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 이미지 URL로부터 이미지를 다운로드하고 전처리하는 함수를 정의합니다.
def download_and_preprocess_image(url):
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')
img_preprocessed = preprocess(img)
return img_preprocessed
# 이미지 특징을 추출하는 함수를 정의합니다.
def extract_features(image):
with torch.no_grad():
features = model(image.unsqueeze(0))
return features
# 두 이미지 특징 간의 코사인 유사도를 계산하는 함수를 정의합니다.
def cosine_similarity(feature1, feature2):
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
similarity = cos(feature1, feature2)
return similarity
# 주어진 '원본이미지' URL과 여러 '대상이미지' URL을 비교하여 가장 비슷한 이미지를 찾는 함수를 정의합니다.
def find_most_similar_image(source_url, target_urls):
source_image = download_and_preprocess_image(source_url)
source_features = extract_features(source_image)
similarities = []
for url in target_urls:
target_image = download_and_preprocess_image(url)
target_features = extract_features(target_image)
similarity = cosine_similarity(source_features, target_features)
similarities.append(similarity.item())
# 가장 높은 유사도를 가진 이미지의 인덱스를 찾습니다.
most_similar_index = np.argmax(similarities)
most_similar_score = similarities[most_similar_index]
return target_urls[most_similar_index], most_similar_score
# 주어진 이미지 URL 예시 (실제 실행 시 주석 처리할 부분)
#source_image_url = "https://example.com/source.jpg"
#target_image_urls = [
# "https://example.com/target1.jpg",
# "https://example.com/target2.jpg",
# "https://example.com/target3.jpg",
#]
# 가장 비슷한 이미지를 찾습니다. (실제 실행 시 주석 처리할 부분)
#most_similar_url, similarity_score = find_most_similar_image(source_image_url, target_image_urls)
#print(f"Most similar image URL: {most_similar_url}, Similarity score: {similarity_score}")