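Introduction
In yesterday's article we showed how to train an image classifier in PyTorch on your own images and then use it for image recognition. This article shows how to use a pre-trained classifier to detect multiple objects in an image and track them through a video.
Object detection in images
There are many object detection algorithms; YOLO and SSD are among the most popular today. In this article we will use YOLOv3. We will not discuss YOLO in detail here; if you want to learn more about it, see https://pjreddie.com/darknet/yolo/
Let's get started, beginning as usual with the imports: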

    from models import *
    from utils import *

    import os, sys, time, datetime, random
    import numpy as np  # np is used by the plotting and video code below
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from torch.autograd import Variable

    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from PIL import Image

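Next, load the pre-trained configuration and weights, along with a few predefined values: the image size, the confidence threshold, and the non-maximum suppression threshold.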
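
    config_path = 'config/yolov3.cfg'
    weights_path = 'config/yolov3.weights'
    class_path = 'config/coco.names'
    img_size = 416      # side length of the square network input, in pixels
    conf_thres = 0.8    # confidence threshold
    nms_thres = 0.4     # non-maximum suppression threshold

    # Load model and weights
    model = Darknet(config_path, img_size=img_size)
    model.load_weights(weights_path)
    model.cuda()        # requires a CUDA-capable GPU
    model.eval()        # switch to inference mode
    classes = utils.load_classes(class_path)
    Tensor = torch.cuda.FloatTensor
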
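To give a feel for what the two thresholds do, here is a minimal, class-agnostic sketch in plain Python. It is not the `non_max_suppression` function from the `utils` module used in this article, which works on the network's raw output tensor. The names `iou` and `simple_nms` are made up for illustration, and the sketch assumes each box is an `(x1, y1, x2, y2)` tuple with a single score.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def simple_nms(boxes, scores, conf_thres=0.8, nms_thres=0.4):
    """Return indices of the boxes that survive both thresholds (hypothetical helper)."""
    # 1. drop low-confidence boxes, 2. visit the rest from highest to lowest score
    order = sorted((i for i, s in enumerate(scores) if s >= conf_thres),
                   key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # keep a box only if it does not overlap a better box too much
        if all(iou(boxes[i], boxes[j]) <= nms_thres for j in keep):
            keep.append(i)
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60), (0, 0, 5, 5)]
scores = [0.90, 0.85, 0.95, 0.50]
print(simple_nms(boxes, scores))  # -> [2, 0]
```

Box 3 is removed by the confidence threshold, and box 1 is suppressed because it overlaps the higher-scoring box 0 with an IoU of about 0.68, which is above `nms_thres`.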
The following function returns the detections for a given image.

    def detect_image(img):
        # scale and pad image
        ratio = min(img_size/img.size[0], img_size/img.size[1])
        imw = round(img.size[0] * ratio)
        imh = round(img.size[1] * ratio)
        img_transforms = transforms.Compose([
            transforms.Resize((imh, imw)),
            # padding order is (left, top, right, bottom), filled with gray
            transforms.Pad((max(int((imh-imw)/2), 0),
                            max(int((imw-imh)/2), 0),
                            max(int((imh-imw)/2), 0),
                            max(int((imw-imh)/2), 0)),
                           (128, 128, 128)),
            transforms.ToTensor(),
        ])
        # convert image to Tensor
        image_tensor = img_transforms(img).float()
        image_tensor = image_tensor.unsqueeze_(0)
        input_img = Variable(image_tensor.type(Tensor))
        # run inference on the model and get detections
        with torch.no_grad():
            detections = model(input_img)
            # 80 is the number of classes in coco.names
            detections = utils.non_max_suppression(detections, 80, conf_thres, nms_thres)
        return detections[0]

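The scale-and-pad arithmetic at the top of `detect_image` can be checked without PyTorch. The sketch below repeats the same formulas on plain numbers; `letterbox_geometry` is a hypothetical helper name that does not appear in the original code.

```python
def letterbox_geometry(width, height, img_size=416):
    """Return (new_w, new_h, pad_x, pad_y) using the same formulas as detect_image.

    pad_x is added on the left and right, pad_y on the top and bottom.
    """
    ratio = min(img_size / width, img_size / height)
    new_w = round(width * ratio)
    new_h = round(height * ratio)
    pad_x = max(int((new_h - new_w) / 2), 0)
    pad_y = max(int((new_w - new_h) / 2), 0)
    return new_w, new_h, pad_x, pad_y


print(letterbox_geometry(1280, 720))   # landscape image: padded top and bottom
print(letterbox_geometry(1000, 667))   # odd height after resizing
```

For a 1280x720 frame this gives a 416x234 resize with 91 pixels of padding above and below. In the second case the resized height is 277 and the padding is 69 per side, so the padded height is 277 + 2*69 = 415: because of the integer truncation, an odd difference leaves that side one pixel short of `img_size`.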
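Finally, let's load an image, get the detections, and display the image with a bounding box around each detected object, using a different color for each class.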

    # load image and get detections
    img_path = "images/blueangels.jpg"
    prev_time = time.time()
    img = Image.open(img_path)
    detections = detect_image(img)
    inference_time = datetime.timedelta(seconds=time.time() - prev_time)
    print('Inference Time: %s' % (inference_time))

    # Get bounding-box colors
    cmap = plt.get_cmap('tab20b')
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]

    img = np.array(img)
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)

    # padding that made the image square, measured in network-input pixels
    pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x

    if detections is not None:
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        bbox_colors = random.sample(colors, n_cls_preds)
        # browse detections and draw bounding boxes
        for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
            # map the box from network-input coordinates back to the original image
            box_h = ((y2 - y1) / unpad_h) * img.shape[0]
            box_w = ((x2 - x1) / unpad_w) * img.shape[1]
            y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
            x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
            # one color per class
            color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0][0])]
            bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2,
                                     edgecolor=color, facecolor='none')
            ax.add_patch(bbox)
            plt.text(x1, y1, s=classes[int(cls_pred)], color='white',
                     verticalalignment='top', bbox={'color': color, 'pad': 0})
    plt.axis('off')
    # save image
    plt.savefig(img_path.replace(".jpg", "-det.jpg"), bbox_inches='tight', pad_inches=0.0)
    plt.show()

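The coordinate conversion inside the drawing loop is the least obvious part of this snippet, so here it is isolated as a small function on plain numbers. This is only a sketch: `to_original_coords` is a hypothetical name, and it assumes the box is given in the 416x416 padded network-input space, as the detections above are.

```python
def to_original_coords(box, img_w, img_h, img_size=416):
    """Map (x1, y1, x2, y2) from the padded network input to the original image.

    Returns (x1, y1, box_w, box_h) in original-image pixels, using the same
    formulas as the drawing loop.
    """
    x1, y1, x2, y2 = box
    scale = img_size / max(img_w, img_h)
    pad_x = max(img_h - img_w, 0) * scale   # total horizontal padding
    pad_y = max(img_w - img_h, 0) * scale   # total vertical padding
    unpad_w = img_size - pad_x
    unpad_h = img_size - pad_y
    box_w = (x2 - x1) / unpad_w * img_w
    box_h = (y2 - y1) / unpad_h * img_h
    new_x1 = (x1 - pad_x // 2) / unpad_w * img_w
    new_y1 = (y1 - pad_y // 2) / unpad_h * img_h
    return new_x1, new_y1, box_w, box_h


# An 832x416 image is shrunk to 416x208 and padded with 104 rows above and below,
# so a box covering rows 104..312 of the network input spans the full image height.
print(to_original_coords((104, 104, 312, 312), img_w=832, img_h=416))
```

The example box maps to a top-left corner of (208, 0) and a size of 416x416 in the original image.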
Here are some of our detection results:
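Object tracking in video
Now you know how to detect different objects in an image. If you run the detector on a video frame by frame, you will see the bounding boxes move. But when there are several objects in those frames, how do you know whether an object in one frame is the same as one in the previous frame? This is called object tracking, and it uses multiple detections to identify a specific object.
Several algorithms can do this. This article uses SORT (Simple Online and Realtime Tracking), which uses a Kalman filter to predict the trajectory of previously identified objects and matches those predictions against the new detections. It is convenient to use and fast.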
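To illustrate the "predict, then correct with the new detection" idea, here is a deliberately simplified Kalman filter that follows a single coordinate with a constant-velocity model. It is a sketch only: SORT's own filter tracks a full bounding-box state, and the class name `ConstantVelocityFilter` and the noise values are assumptions chosen for the example.

```python
class ConstantVelocityFilter:
    """1-D Kalman filter with state (position, velocity). Illustrative only."""

    def __init__(self, first_position, meas_var=1.0, process_var=0.01):
        self.p = float(first_position)   # estimated position
        self.v = 0.0                     # estimated velocity, unknown at the start
        # 2x2 covariance: position fairly certain, velocity very uncertain
        self.P = [[meas_var, 0.0], [0.0, 100.0]]
        self.r = meas_var                # measurement noise variance
        self.q = process_var             # process noise variance

    def predict(self, dt=1.0):
        """Advance the state by one frame and return the predicted position."""
        (a, b), (c, d) = self.P
        self.p += dt * self.v
        # P = F P F^T + Q with F = [[1, dt], [0, 1]] and Q = q * I
        self.P = [[a + dt * (b + c) + dt * dt * d + self.q, b + dt * d],
                  [c + dt * d, d + self.q]]
        return self.p

    def update(self, measured_position):
        """Correct the prediction with a newly detected position."""
        (a, b), (c, d) = self.P
        residual = measured_position - self.p
        s = a + self.r               # innovation variance
        k0, k1 = a / s, c / s        # Kalman gain
        self.p += k0 * residual
        self.v += k1 * residual
        # P = (I - K H) P with H = [1, 0]
        self.P = [[(1 - k0) * a, (1 - k0) * b],
                  [c - k1 * a, d - k1 * b]]


# An object whose x coordinate grows by 2 pixels per frame
track = ConstantVelocityFilter(first_position=0.0)
for z in [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0]:
    track.predict()
    track.update(z)
print("velocity estimate:", track.v)
print("predicted next position:", track.predict())
```

After a few frames the velocity estimate settles close to 2 pixels per frame, so the predicted position for the next frame lands near 20. A tracker can then look for a new detection near that predicted position.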
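Now for the code. The first three snippets are the same as in the single-image example, because they deal with getting YOLO detections for a single frame. The difference is in the last part: for each frame's detections we call the update function of the Sort object to get references to the objects in the image. So instead of the regular detections from the previous example (bounding-box coordinates and a class prediction), we get tracked objects, which carry an object ID in addition to those values. We also need OpenCV to read the video and display the frames.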

    videopath = 'video/interp.mp4'

    # Run this in a Jupyter notebook: %pylab is an IPython magic that provides
    # the figure, title, imshow and show functions used below
    %pylab inline
    import cv2
    from IPython.display import clear_output

    cmap = plt.get_cmap('tab20b')
    colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

    # initialize Sort object and video capture
    from sort import *
    vid = cv2.VideoCapture(videopath)
    mot_tracker = Sort()

    #while(True):
    for ii in range(40):
        ret, frame = vid.read()
        if not ret:  # stop when the video ends or a frame cannot be read
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pilimg = Image.fromarray(frame)
        detections = detect_image(pilimg)

        img = np.array(pilimg)
        pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
        pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
        unpad_h = img_size - pad_y
        unpad_w = img_size - pad_x
        if detections is not None:
            # one call per frame; each returned row carries an object ID
            tracked_objects = mot_tracker.update(detections.cpu())

            unique_labels = detections[:, -1].cpu().unique()
            n_cls_preds = len(unique_labels)
            for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
                box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
                box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
                y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
                x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])

                # one color per object ID, converted to the 0-255 range
                color = colors[int(obj_id) % len(colors)]
                color = [i * 255 for i in color]
                cls = classes[int(cls_pred)]
                cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h), color, 4)
                cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60, y1), color, -1)
                cv2.putText(frame, cls + "-" + str(int(obj_id)), (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 3)

        fig = figure(figsize=(12, 8))
        title("Video Stream")
        imshow(frame)
        show()
        clear_output(wait=True)

    vid.release()

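To make the ID assignment done by `mot_tracker.update` more concrete, here is a toy tracker in plain Python that turns rows of `(x1, y1, x2, y2, cls_pred)` into rows of `(x1, y1, x2, y2, obj_id, cls_pred)`, the same layout the loop above unpacks. It is not the `Sort` class: SORT keeps a Kalman filter per track and solves the matching with the Hungarian algorithm, whereas this sketch greedily matches each detection to the previous frame's boxes by IoU. The class name `GreedyIouTracker` and the 0.3 threshold are assumptions made for the example.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


class GreedyIouTracker:
    """Toy stand-in for a tracker: matches boxes frame to frame by overlap."""

    def __init__(self, iou_threshold=0.3):
        self.iou_threshold = iou_threshold
        self.tracks = {}      # obj_id -> box seen in the previous frame
        self.next_id = 1

    def update(self, detections):
        """Take one frame's detections and return them with object IDs added."""
        unmatched = dict(self.tracks)
        new_tracks, results = {}, []
        for x1, y1, x2, y2, cls_pred in detections:
            box = (x1, y1, x2, y2)
            best_id, best_iou = None, self.iou_threshold
            for obj_id, prev_box in unmatched.items():
                overlap = iou(box, prev_box)
                if overlap > best_iou:
                    best_id, best_iou = obj_id, overlap
            if best_id is None:
                # no previous box overlaps enough: start a new track
                best_id = self.next_id
                self.next_id += 1
            else:
                # each previous track can be claimed only once per frame
                del unmatched[best_id]
            new_tracks[best_id] = box
            results.append((x1, y1, x2, y2, best_id, cls_pred))
        self.tracks = new_tracks   # tracks not seen in this frame are dropped
        return results


tracker = GreedyIouTracker()
frame1 = tracker.update([(0, 0, 10, 10, 0), (50, 50, 60, 60, 2)])
frame2 = tracker.update([(52, 51, 62, 61, 2), (1, 0, 11, 10, 0)])
print([row[4] for row in frame1])  # -> [1, 2]
print([row[4] for row in frame2])  # -> [2, 1]
```

Even though the two objects are listed in a different order in the second frame, each keeps the ID it received in the first frame, which is what lets the video loop draw a stable color and label per object.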