首頁  >  文章  >  後端開發  >  使用 YOLO 和 CLIP 來改進檢索

使用 YOLO 和 CLIP 來改進檢索

WBOY
WBOY原創
2024-08-05 21:58:521177瀏覽

在本文中,我們將了解如何使用 YOLO 等物件偵測模型以及 CLIP 等多模態嵌入模型來更好地進行影像檢索。

這個想法是:CLIP 影像檢索的工作原理如下:我們使用 CLIP 模型嵌入我們擁有的影像並將它們儲存在某個地方,例如向量資料庫中。然後,在推理過程中,我們可以使用查詢圖像或提示,將其嵌入,並從可檢索的儲存嵌入中找到最接近的圖像。問題是當嵌入圖像有太多物件或某些物件在背景中時,我們仍然希望我們的系統檢索它們。這是因為 CLIP 將圖像作為一個整體嵌入。可以想像為詞嵌入模型與句子嵌入模型的關係。我們希望能夠搜尋與圖像中的物件等效的單字。因此,解決方案是使用物件偵測模型將影像分解為不同的物件。然後,嵌入這些分解的圖像,但將它們連結到其父圖像。這將使我們能夠檢索作物並獲得作物起源的親本。 讓我們看看它是如何工作的。

安裝相依性並導入它們

!pip install -q ultralytics torch matplotlib numpy pillow zipfile36 transformers

from ultralytics import YOLO
import matplotlib.pyplot as plt
from PIL import pillow
import os
from Zipfile import Zipfile, BadZipFile
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPTextModelWithProjection

下載COCO資料集並解壓縮

!wget http://images.cocodataset.org/zips/val2017.zip -O coco_val2017.zip

def extract_zip_file(extract_path):
    try:
        with ZipFile(extract_path+".zip") as zfile:
            zfile.extractall(extract_path)
        # remove zipfile
        zfileTOremove=f"{extract_path}"+".zip"
        if os.path.isfile(zfileTOremove):
            os.remove(zfileTOremove)
        else:
            print("Error: %s file not found" % zfileTOremove)
    except BadZipFile as e:
        print("Error:", e)

extract_val_path = "./coco_val2017"
extract_zip_file(extract_val_path)

然後我們可以拍攝一些圖像並建立範例清單。

source = ['coco_val2017/val2017/000000000139.jpg', '/content/coco_val2017/val2017/000000000632.jpg', '/content/coco_val2017/val2017/000000000776.jpg', '/content/coco_val2017/val2017/000000001503.jpg', '/content/coco_val2017/val2017/000000001353.jpg', '/content/coco_val2017/val2017/000000003661.jpg']

初始化YOLO模型和CLIP模型

在此範例中,我們將使用最新的 Ultralytics Yolo10x 模型以及 OpenAI Clip-vit-base-patch32 。

device = "cuda"

 # YOLO Model
model = YOLO('yolov10x.pt')

# Clip model
model_id = "openai/clip-vit-base-patch32"
image_model = CLIPVisionModelWithProjection.from_pretrained(model_id, device_map = device)
text_model = CLIPTextModelWithProjection.from_pretrained(model_id, device_map = device)
processor = CLIPProcessor.from_pretrained(model_id)

運行檢測模型

results = model(source=source, device = "cuda")

讓我們用此程式碼片段向我們展示結果

# Visualize the results
fig, ax = plt.subplots(2, 3, figsize=(15, 10))

for i, r in enumerate(results):
    # Plot results image
    im_bgr = r.plot()  # BGR-order numpy array
    im_rgb = Image.fromarray(im_bgr[..., ::-1])  # RGB-order PIL image

    ax[i%2, i//2].imshow(im_rgb)
    ax[i%2, i//2].set_title(f"Image {i+1}")

Using YOLO with CLIP to improve Retrieval

所以我們可以看到YOLO模型在偵測影像中的物體方面效果很好。它確實會犯一些錯誤,將顯示器標記為電視。但那很好。 YOLO 分配的實際類別並不是那麼重要,因為我們將使用 CLIP 進行推理。

定義一些輔助類

class CroppedImage:

  def __init__(self, parent, box, cls):

    self.parent = parent
    self.box = box
    self.cls = cls

  def display(self, ax = None):
    im_rgb = Image.open(self.parent)
    cropped_image = im_rgb.crop(self.box)

    if ax is not None:
      ax.imshow(cropped_image)
      ax.set_title(self.cls)
    else:
      plt.figure(figsize=(10, 10))
      plt.imshow(cropped_image)
      plt.title(self.cls)
      plt.show()

  def get_cropped_image(self):
    im_rgb = Image.open(self.parent)
    cropped_image = im_rgb.crop(self.box)
    return cropped_image

  def __str__(self):
    return f"CroppedImage(parent={self.parent}, boxes={self.box}, cls={self.cls})"

  def __repr__(self):
    return self.__str__()

class YOLOImage:
  def __init__(self, image_path, cropped_images):
    self.image_path = str(image_path)
    self.cropped_images = cropped_images

  def get_image(self):
    return Image.open(self.image_path)

  def get_caption(self):
    cls  =[]
    for cropped_image in self.cropped_images:
      cls.append(cropped_image.cls)

    unique_cls = set(cls)
    count_cls = {cls: cls.count(cls) for cls in unique_cls}

    count_string = " ".join(f"{count} {cls}," for cls, count in count_cls.items())
    return "this image contains " + count_string

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    cls  =[]
    for cropped_image in self.cropped_images:
      cls.append(cropped_image.cls)

    return f"YOLOImage(image={self.image_path}, cropped_images={cls})"

class ImageEmbedding:
  def __init__(self, image_path, embedding, cropped_image = None):
    self.image_path = image_path
    self.cropped_image = cropped_image
    self.embedding = embedding

裁切影像類

CroppedImage 類別表示從較大的父圖像中裁剪出的圖像的一部分。它使用父圖像的路徑、定義裁剪區域的邊界框和類別標籤(例如“貓”或“狗”)進行初始化。此類別包含顯示裁剪影像並將其作為影像物件檢索的方法。此顯示方法允許在提供的軸上或透過建立新圖形來視覺化裁剪部分,使其適用於不同的用例。此外,還實作了 __str__ 和 __repr__ 方法,以便輕鬆且資訊豐富地表示物件的字串。

YOLO圖像類

YOLOImage 類別旨在處理使用 YOLO 物件偵測模型處理的影像。它取得原始影像的路徑和代表影像中偵測到的物件的 CroppedImage 實例清單。該類別提供了打開和顯示完整圖像以及生成總結圖像中檢測到的物件的標題的方法。標題方法聚合並計算裁剪圖像中的唯一類別標籤,提供圖像內容的簡潔描述。此類對於管理和解釋對象檢測任務的結果特別有用。

影像嵌入類

ImageEmbedding 類別具有影像及其關聯的嵌入,它是影像特徵的數位表示。可以使用影像的路徑、嵌入向量以及可選的 CroppedImage 實例(如果嵌入對應於影像的特定裁剪部分)來初始化此類。 ImageEmbedding 類別對於涉及影像相似性、分類和檢索的任務至關重要,因為它提供了一種結構化方法來儲存和存取影像資料及其計算特徵。這種整合促進了高效的影像處理和機器學習工作流程。

裁剪每個圖像並建立 YOLOImage 物件列表

yolo_images: list[YOLOImage]= []

names= model.names

for i, r in enumerate(results):
  crops:list[CroppedImage] = []
  boxes = r.boxes
  classes = r.boxes.cls
  for j, box in enumerate(r.boxes):
    box = tuple(box.xyxy.flatten().cpu().numpy())
    cropped_image = CroppedImage(parent = r.path, box = box, cls = names[classes[j].int().item()])
    crops.append(cropped_image)
  yolo_images.append(YOLOImage(image_path=r.path, cropped_images=crops))

使用 CLIP 嵌入圖像

image_embeddings = []

for image in yolo_images:
  input = processor.image_processor(images= image.get_image(), return_tensors = 'pt')
  input.to(device)
  embeddings = image_model(pixel_values = input.pixel_values).image_embeds
  embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings
  image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings)
  image_embeddings.append(image_embedding)

  for cropped_image in image.cropped_images:
    input = processor.image_processor(images= cropped_image.get_cropped_image(), return_tensors = 'pt')
    input.to(device)
    embeddings = image_model(pixel_values = input.pixel_values).image_embeds
    embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings

    image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings, cropped_image = cropped_image)
    image_embeddings.append(image_embedding)

   **image_embeddings_tensor = torch.stack([image_embedding.embedding for image_embedding in image_embeddings]).squeeze()**

如果願意,我們現在可以將這些圖像嵌入儲存在向量資料庫中。但在這個例子中,我們將僅使用內點積技術來檢查相似性並檢索影像。

檢索

query = "image of a flowerpot"

text_embedding = processor.tokenizer(query, return_tensors="pt").to(device)
text_embedding = text_model(**text_embedding).text_embeds

similarities = (torch.matmul(text_embedding, image_embeddings_tensor.T)).flatten().detach().cpu().numpy()

# get the top 5 similar images
k = 5
top_k_indices = similarities.argsort()[-k:]

# Display the top 5 results
fig, ax = plt.subplots(2, 5, figsize=(20, 5))
for i, index in enumerate(top_k_indices):
  if image_embeddings[index].cropped_image is not None:
    image_embeddings[index].cropped_image.display(ax = ax[0][i])
  else:
  ax[0][i].imshow(Image.open(image_embeddings[index].image_path))
  ax[1][i].imshow(Image.open(image_embeddings[index].image_path))
  ax[0][i].axis('off')
  ax[1][i].axis('off')
  ax[1][i].set_title("Original Image")
plt.show()

Using YOLO with CLIP to improve Retrieval

Using YOLO with CLIP to improve Retrieval
Using YOLO with CLIP to improve Retrieval
Using YOLO with CLIP to improve Retrieval

您可以看到,我們甚至能夠檢索隱藏在背景中的小植物。有時它也會拉出原始圖像作為結果,因為我們也嵌入了它。

這是一項非常強大的技術。您還可以微調您自己的影像的檢測和嵌入模型,並進一步提高效能。

一個缺點是我們必須對所有偵測到的物件執行 CLIP 模型。緩解這種情況的一種方法是限制 YOLO 生產的盒子數量。

您可以透過此連結查看 Colab 上的程式碼。

Using YOLO with CLIP to improve Retrieval


想要連接嗎?

?我的網站

?我的推特

?我的 LinkedIn

以上是使用 YOLO 和 CLIP 來改進檢索的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn