Heim  >  Artikel  >  Backend-Entwicklung  >  Detaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgung

Detaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgung

coldplay.xixi
coldplay.xixinach vorne
2020-12-11 17:18:458866Durchsuche

Python-TutorialSpalte stellt die Verwendung von PyTorch zur Zielerkennung und -verfolgung vor

Einführung Detaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgung

Im gestrigen Artikel haben wir erläutert, wie Sie Ihre eigenen Bilder in PyTorch verwenden, um einen Bildklassifikator zu trainieren und ihn dann für die Bilderkennung zu verwenden. In diesem Artikel wird gezeigt, wie Sie mit einem vorab trainierten Klassifikator mehrere Objekte in Bildern erkennen und in Videos verfolgen können. Zielerkennung in BildernEs gibt viele Algorithmen zur Zielerkennung, YOLO und SSD sind derzeit die beliebtesten Algorithmen. In diesem Artikel verwenden wir YOLOv3. Wir werden YOLO hier nicht im Detail besprechen. Wenn Sie mehr darüber erfahren möchten, können Sie auf den Link unten verweisen ~ (https://pjreddie.com/darknet/yolo/)Lass uns beginnen und das Modul noch importieren startet:

from models import *
from utils import *
import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
Lädt dann die vorab trainierte Konfiguration und die Gewichte sowie einige vordefinierte Werte, darunter: die Abmessungen des Bildes, den Konfidenzschwellenwert und den nicht maximalen Unterdrückungsschwellenwert.

config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4
# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor
Die folgende Funktion gibt die Erkennungsergebnisse für das angegebene Bild zurück.
def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0), 
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80, 
                        conf_thres, nms_thres)
    return detections[0]

Abschließend erhalten wir die Erkennungsergebnisse, indem wir ein Bild laden und es dann mit einem Begrenzungsrahmen um das erkannte Objekt anzeigen. Und verwenden Sie zur Unterscheidung unterschiedliche Farben für verschiedene Klassen.

# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))
# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)
pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()

Hier sind einige unserer Erkennungsergebnisse:

Objektverfolgung im Video

Jetzt wissen Sie, wie Sie verschiedene Objekte in Bildern erkennen. Wenn Sie es Bild für Bild in einem Video betrachten, sehen Sie, wie sich diese Tracking-Boxen bewegen. Wenn diese Videobilder jedoch mehrere Objekte enthalten, woher wissen Sie dann, ob das Objekt in einem Bild mit dem Objekt im vorherigen Bild übereinstimmt? Dies wird als Objektverfolgung bezeichnet und verwendet mehrere Erkennungen, um ein bestimmtes Objekt zu identifizieren.

Hierfür gibt es mehrere Algorithmen. In diesem Artikel habe ich mich für die Verwendung von SORT (Simple Online and Realtime Tracking) entschieden, das einen Kalman-Filter verwendet, um die Flugbahn eines zuvor identifizierten Ziels vorherzusagen und sie mit neuen Erkennungsergebnissen abzugleichen. Sehr praktisch und schnell.

Detaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgungJetzt beginnen wir mit dem Schreiben des Codes. Die ersten drei Codefragmente sind die gleichen wie bei der Einzelbilderkennung, da sie sich mit der YOLO-Erkennung in einem einzelnen Bild befassen. Der Unterschied liegt im letzten Teil: Für jede Erkennung rufen wir die Update-Funktion des Sortierobjekts auf, um einen Verweis auf das Objekt im Bild zu erhalten. Im Gegensatz zur regulären Erkennung im vorherigen Beispiel (einschließlich der Koordinaten des Begrenzungsrahmens und der Klassenvorhersage) erhalten wir also das verfolgte Objekt, einschließlich zusätzlich zu den oben genannten Parametern, einer Objekt-ID. Und Sie müssen OpenCV verwenden, um das Video zu lesen und die Videobilder anzuzeigen.

videopath = 'video/interp.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
    fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)

Detaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgungVerwandte kostenlose Lernempfehlungen:

php-Programmierung

(Video)Detaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgung

Das obige ist der detaillierte Inhalt vonDetaillierte Erläuterung der Verwendung von PyTorch zur Implementierung der Zielerkennung und -verfolgung. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:csdn.net. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen