首页  >  文章  >  后端开发  >  使用 YOLO 和 CLIP 来改进检索

使用 YOLO 和 CLIP 来改进检索

WBOY
WBOY原创
2024-08-05 21:58:521177浏览

在本文中,我们将了解如何使用 YOLO 等对象检测模型以及 CLIP 等多模态嵌入模型来更好地进行图像检索。

这个想法是:CLIP 图像检索的工作原理如下:我们使用 CLIP 模型嵌入我们拥有的图像并将它们存储在某个地方,例如矢量数据库中。然后,在推理过程中,我们可以使用查询图像或提示,将其嵌入,并从可检索的存储嵌入中找到最接近的图像。问题是当嵌入图像有太多对象或某些对象在背景中时,我们仍然希望我们的系统检索它们。这是因为 CLIP 将图像作为一个整体嵌入。可以将其想象为词嵌入模型与句子嵌入模型的关系。我们希望能够搜索与图像中的对象等效的单词。因此,解决方案是使用对象检测模型将图像分解为不同的对象。然后,嵌入这些分解的图像,但将它们链接到其父图像。这将使我们能够检索作物并获得作物起源的亲本。 让我们看看它是如何工作的。

安装依赖项并导入它们

!pip install -q ultralytics torch matplotlib numpy pillow zipfile36 transformers

from ultralytics import YOLO
import matplotlib.pyplot as plt
from PIL import pillow
import os
from Zipfile import Zipfile, BadZipFile
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPTextModelWithProjection

下载COCO数据集并解压

!wget http://images.cocodataset.org/zips/val2017.zip -O coco_val2017.zip

def extract_zip_file(extract_path):
    try:
        with ZipFile(extract_path+".zip") as zfile:
            zfile.extractall(extract_path)
        # remove zipfile
        zfileTOremove=f"{extract_path}"+".zip"
        if os.path.isfile(zfileTOremove):
            os.remove(zfileTOremove)
        else:
            print("Error: %s file not found" % zfileTOremove)
    except BadZipFile as e:
        print("Error:", e)

extract_val_path = "./coco_val2017"
extract_zip_file(extract_val_path)

然后我们可以拍摄一些图像并创建示例列表。

source = ['coco_val2017/val2017/000000000139.jpg', '/content/coco_val2017/val2017/000000000632.jpg', '/content/coco_val2017/val2017/000000000776.jpg', '/content/coco_val2017/val2017/000000001503.jpg', '/content/coco_val2017/val2017/000000001353.jpg', '/content/coco_val2017/val2017/000000003661.jpg']

初始化YOLO模型和CLIP模型

在此示例中,我们将使用最新的 Ultralytics Yolo10x 模型以及 OpenAI Clip-vit-base-patch32 。

device = "cuda"

 # YOLO Model
model = YOLO('yolov10x.pt')

# Clip model
model_id = "openai/clip-vit-base-patch32"
image_model = CLIPVisionModelWithProjection.from_pretrained(model_id, device_map = device)
text_model = CLIPTextModelWithProjection.from_pretrained(model_id, device_map = device)
processor = CLIPProcessor.from_pretrained(model_id)

运行检测模型

results = model(source=source, device = "cuda")

让我们用此代码片段向我们展示结果

# Visualize the results
fig, ax = plt.subplots(2, 3, figsize=(15, 10))

for i, r in enumerate(results):
    # Plot results image
    im_bgr = r.plot()  # BGR-order numpy array
    im_rgb = Image.fromarray(im_bgr[..., ::-1])  # RGB-order PIL image

    ax[i%2, i//2].imshow(im_rgb)
    ax[i%2, i//2].set_title(f"Image {i+1}")

Using YOLO with CLIP to improve Retrieval

所以我们可以看到YOLO模型在检测图像中的物体方面效果很好。它确实会犯一些错误,将显示器标记为电视。但那很好。 YOLO 分配的实际类并不是那么重要,因为我们将使用 CLIP 进行推理。

定义一些辅助类

class CroppedImage:

  def __init__(self, parent, box, cls):

    self.parent = parent
    self.box = box
    self.cls = cls

  def display(self, ax = None):
    im_rgb = Image.open(self.parent)
    cropped_image = im_rgb.crop(self.box)

    if ax is not None:
      ax.imshow(cropped_image)
      ax.set_title(self.cls)
    else:
      plt.figure(figsize=(10, 10))
      plt.imshow(cropped_image)
      plt.title(self.cls)
      plt.show()

  def get_cropped_image(self):
    im_rgb = Image.open(self.parent)
    cropped_image = im_rgb.crop(self.box)
    return cropped_image

  def __str__(self):
    return f"CroppedImage(parent={self.parent}, boxes={self.box}, cls={self.cls})"

  def __repr__(self):
    return self.__str__()

class YOLOImage:
  def __init__(self, image_path, cropped_images):
    self.image_path = str(image_path)
    self.cropped_images = cropped_images

  def get_image(self):
    return Image.open(self.image_path)

  def get_caption(self):
    cls  =[]
    for cropped_image in self.cropped_images:
      cls.append(cropped_image.cls)

    unique_cls = set(cls)
    count_cls = {cls: cls.count(cls) for cls in unique_cls}

    count_string = " ".join(f"{count} {cls}," for cls, count in count_cls.items())
    return "this image contains " + count_string

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    cls  =[]
    for cropped_image in self.cropped_images:
      cls.append(cropped_image.cls)

    return f"YOLOImage(image={self.image_path}, cropped_images={cls})"

class ImageEmbedding:
  def __init__(self, image_path, embedding, cropped_image = None):
    self.image_path = image_path
    self.cropped_image = cropped_image
    self.embedding = embedding

裁剪图像类

CroppedImage 类表示从较大的父图像中裁剪出的图像的一部分。它使用父图像的路径、定义裁剪区域的边界框和类标签(例如“猫”或“狗”)进行初始化。此类包含显示裁剪图像并将其作为图像对象检索的方法。该显示方法允许在提供的轴上或通过创建新图形来可视化裁剪部分,使其适用于不同的用例。此外,还实现了 __str__ 和 __repr__ 方法,以便轻松且信息丰富地表示对象的字符串。

YOLO图像类

YOLOImage 类旨在处理使用 YOLO 对象检测模型处理的图像。它获取原始图像的路径和代表图像中检测到的对象的 CroppedImage 实例列表。该类提供了打开和显示完整图像以及生成总结图像中检测到的对象的标题的方法。标题方法聚合并计算裁剪图像中的唯一类标签,提供图像内容的简洁描述。此类对于管理和解释对象检测任务的结果特别有用。

图像嵌入类

ImageEmbedding 类具有图像及其关联的嵌入,它是图像特征的数字表示。可以使用图像的路径、嵌入向量以及可选的 CroppedImage 实例(如果嵌入对应于图像的特定裁剪部分)来初始化此类。 ImageEmbedding 类对于涉及图像相似性、分类和检索的任务至关重要,因为它提供了一种结构化方法来存储和访问图像数据及其计算特征。这种集成促进了高效的图像处理和机器学习工作流程。

裁剪每个图像并创建 YOLOImage 对象列表

yolo_images: list[YOLOImage]= []

names= model.names

for i, r in enumerate(results):
  crops:list[CroppedImage] = []
  boxes = r.boxes
  classes = r.boxes.cls
  for j, box in enumerate(r.boxes):
    box = tuple(box.xyxy.flatten().cpu().numpy())
    cropped_image = CroppedImage(parent = r.path, box = box, cls = names[classes[j].int().item()])
    crops.append(cropped_image)
  yolo_images.append(YOLOImage(image_path=r.path, cropped_images=crops))

使用 CLIP 嵌入图像

image_embeddings = []

for image in yolo_images:
  input = processor.image_processor(images= image.get_image(), return_tensors = 'pt')
  input.to(device)
  embeddings = image_model(pixel_values = input.pixel_values).image_embeds
  embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings
  image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings)
  image_embeddings.append(image_embedding)

  for cropped_image in image.cropped_images:
    input = processor.image_processor(images= cropped_image.get_cropped_image(), return_tensors = 'pt')
    input.to(device)
    embeddings = image_model(pixel_values = input.pixel_values).image_embeds
    embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings

    image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings, cropped_image = cropped_image)
    image_embeddings.append(image_embedding)

   **image_embeddings_tensor = torch.stack([image_embedding.embedding for image_embedding in image_embeddings]).squeeze()**

如果愿意,我们现在可以将这些图像嵌入存储在矢量数据库中。但在这个例子中,我们将仅使用内点积技术来检查相似性并检索图像。

检索

query = "image of a flowerpot"

text_embedding = processor.tokenizer(query, return_tensors="pt").to(device)
text_embedding = text_model(**text_embedding).text_embeds

similarities = (torch.matmul(text_embedding, image_embeddings_tensor.T)).flatten().detach().cpu().numpy()

# get the top 5 similar images
k = 5
top_k_indices = similarities.argsort()[-k:]

# Display the top 5 results
fig, ax = plt.subplots(2, 5, figsize=(20, 5))
for i, index in enumerate(top_k_indices):
  if image_embeddings[index].cropped_image is not None:
    image_embeddings[index].cropped_image.display(ax = ax[0][i])
  else:
  ax[0][i].imshow(Image.open(image_embeddings[index].image_path))
  ax[1][i].imshow(Image.open(image_embeddings[index].image_path))
  ax[0][i].axis('off')
  ax[1][i].axis('off')
  ax[1][i].set_title("Original Image")
plt.show()

Using YOLO with CLIP to improve Retrieval

Using YOLO with CLIP to improve Retrieval
Using YOLO with CLIP to improve Retrieval
Using YOLO with CLIP to improve Retrieval

您可以看到,我们甚至能够检索隐藏在背景中的小植物。有时它也会拉出原始图像作为结果,因为我们也嵌入了它。

这是一项非常强大的技术。您还可以微调您自己的图像的检测和嵌入模型,并进一步提高性能。

一个缺点是我们必须对检测到的所有对象运行 CLIP 模型。缓解这种情况的一种方法是限制 YOLO 生产的盒子数量。

您可以通过此链接查看 Colab 上的代码。

Using YOLO with CLIP to improve Retrieval


想要连接吗?

?我的网站

?我的推特

?我的 LinkedIn

以上是使用 YOLO 和 CLIP 来改进检索的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn