首頁  >  文章  >  後端開發  >  詳解使用PyTorch實現目標檢測與跟踪

詳解使用PyTorch實現目標檢測與跟踪

coldplay.xixi
coldplay.xixi轉載
2020-12-11 17:18:458858瀏覽

python教學欄位介紹使用PyTorch實現目標偵測與追蹤

詳解使用PyTorch實現目標檢測與跟踪

大量免費學習推薦,敬請造訪python教學(影片)

引言

在昨天的文章中,我們介紹瞭如何在PyTorch中使用您自己的圖像來訓練圖像分類器,然後使用它來進行圖像識別。本文將展示如何使用預先訓練的分類器來偵測影像中的多個對象,並在影片中追蹤它們。

影像中的目標偵測

目標偵測的演算法很多,YOLO跟SSD是現下最受歡迎的演算法。在本文中,我們將使用YOLOv3。在這裡我們不會詳細討論YOLO,如果想對它有更多了解,可以參考下面的鏈接哦~(https://pjreddie.com/darknet/yolo/)

下面讓我們開始吧,依然從導入模組開始:

from models import *
from utils import *
import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

然後載入預訓練的配置和權重,以及一些預先定義的值,包括:影像的尺寸、置信度閾值和非最大抑制閾值。

config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4
# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor

下面的函數將傳回指定影像的偵測結果。

def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0), 
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80, 
                        conf_thres, nms_thres)
    return detections[0]

最後,讓我們透過載入一個影像,取得偵測結果,然後用偵測到的物件周圍的包圍框來顯示它。並為不同的類別使用不同的顏色來區分。

# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))
# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)
pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()

以下是我們的一些偵測結果:

詳解使用PyTorch實現目標檢測與跟踪

詳解使用PyTorch實現目標檢測與跟踪

詳解使用PyTorch實現目標檢測與跟踪

影片中的目標追蹤

現在你知道如何在影像中偵測不同的物件。當你在一個影片中一幀一幀地看時,你會看到那些追蹤框在移動。但是如果這些視訊幀中有多個對象,你如何知道一個幀中的對像是否與前一個幀中的對象相同?這被稱為目標跟踪,它使用多次檢測來識別一個特定的物件。

有多種演算法可以做到這一點,在本文中決定使用SORT(Simple Online and Realtime Tracking),它使用Kalman濾波器預測先前識別的目標的軌跡,並將其與新的檢測結果進行匹配,非常方便且速度很快。

現在開始編寫程式碼,前3個程式碼片段將與單幅圖像偵測中的程式碼片段相同,因為它們處理的是在單幀上獲得 YOLO 檢測。差異在最後一部分出現,對於每個偵測,我們呼叫 Sort 物件的 Update 函數,以獲得對影像中物件的參考。因此,與前面範例中的常規檢測(包括邊界框的座標和類別預測)不同,我們將獲得追蹤的對象,除了上面的參數,還包括一個對象 ID。並且需要使用OpenCV來讀取影片並顯示視訊幀。

videopath = 'video/interp.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
    fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)

相關免費學習推薦:php程式設計(影片)

以上是詳解使用PyTorch實現目標檢測與跟踪的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:csdn.net。如有侵權,請聯絡admin@php.cn刪除