Maison  >  Article  >  développement back-end  >  ROBOFLOW - entraîner et tester avec python

ROBOFLOW - entraîner et tester avec python

WBOY
WBOYoriginal
2024-08-27 06:01:32710parcourir

Roboflow est une plate-forme d'annotation d'images à utiliser dans l'IA de détection d'objets.

J'utilise cette plateforme pour le C2SMR c2smr.fr, mon association de vision par ordinateur pour le sauvetage maritime.

Dans cet article, je vous montre comment utiliser cette plateforme et entraîner votre modèle avec python.

Vous pouvez trouver plus d'exemples de code sur mon github : https://github.com/C2SMR/detector


I - Ensemble de données

Pour créer votre ensemble de données, accédez à https://app.roboflow.com/ et commencez à annoter votre image comme indiqué dans l'image suivante.

Dans cet exemple, je détourne tous les nageurs pour prédire leur position dans les futures images.
Pour obtenir un bon résultat, recadrez tous les nageurs et placez le cadre de sélection juste après l'objet pour l'entourer correctement.

ROBOFLOW - train & test with python

Vous pouvez déjà utiliser un ensemble de données Roboflow public, pour cette vérification https://universe.roboflow.com/

II - Formation

Pour la phase de formation, vous pouvez utiliser roboflow directement, mais la troisième fois vous devrez payer, c'est pourquoi je vous montre comment le faire avec votre ordinateur portable.

La première étape consiste à importer votre ensemble de données. Pour ce faire, vous pouvez importer la bibliothèque Roboflow.

pip install roboflow

Pour créer un modèle, vous devez utiliser l'algorithme YOLO, que vous pouvez importer avec la bibliothèque ultralytics.

pip install ultralytics

Dans mon script, j'utilise la commande suivante :

py train.py api-key project-workspace project-name project-version nb-epoch size_model

Vous devez obtenir :

  • la clé d'accès
  • espace de travail
  • Nom du projet Roboflow
  • Version de l'ensemble de données du projet
  • nombre d'époques pour entraîner le modèle
  • taille du réseau neuronal

Initialement, le script télécharge yolov8-obb.pt, le poids yolo par défaut avec les données pré-entraînement, pour faciliter l'entraînement.

import sys
import os
import random
from roboflow import Roboflow
from ultralytics import YOLO
import yaml
import time


class Main:
    rf: Roboflow
    project: object
    dataset: object
    model: object
    results: object
    model_size: str

    def __init__(self):
        self.model_size = sys.argv[6]
        self.import_dataset()
        self.train()

    def import_dataset(self):
        self.rf = Roboflow(api_key=sys.argv[1])
        self.project = self.rf.workspace(sys.argv[2]).project(sys.argv[3])
        self.dataset = self.project.version(sys.argv[4]).download("yolov8-obb")

        with open(f'{self.dataset.location}/data.yaml', 'r') as file:
            data = yaml.safe_load(file)

        data['path'] = self.dataset.location

        with open(f'{self.dataset.location}/data.yaml', 'w') as file:
            yaml.dump(data, file, sort_keys=False)

    def train(self):
        list_of_models = ["n", "s", "m", "l", "x"]
        if self.model_size != "ALL" and self.model_size in list_of_models:

            self.model = YOLO(f"yolov8{self.model_size}-obb.pt")

            self.results = self.model.train(data=f"{self.dataset.location}/"
                                                 f"yolov8-obb.yaml",
                                            epochs=int(sys.argv[5]), imgsz=640)



        elif self.model_size == "ALL":
            for model_size in list_of_models:
                self.model = YOLO(f"yolov8{model_size}.pt")

                self.results = self.model.train(data=f"{self.dataset.location}"
                                                     f"/yolov8-obb.yaml",
                                                epochs=int(sys.argv[5]),
                                                imgsz=640)



        else:
            print("Invalid model size")



if __name__ == '__main__':
    Main()

III - Affichage

Après avoir entraîné le modèle, vous obtenez les fichiers best.py et last.py, qui correspondent au poids.

Avec la bibliothèque ultralytics, vous pouvez également importer YOLO et charger votre poids puis votre vidéo de test.
Dans cet exemple, j'utilise la fonction de suivi pour obtenir un identifiant pour chaque nageur.

import cv2
from ultralytics import YOLO
import sys


def main():
    cap = cv2.VideoCapture(sys.argv[1])

    model = YOLO(sys.argv[2])

    while True:
        ret, frame = cap.read()
        results = model.track(frame, persist=True)
        res_plotted = results[0].plot()
        cv2.imshow("frame", res_plotted)

        if cv2.waitKey(1) == 27:
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()

Pour analyser la prédiction, vous pouvez obtenir le modèle json comme suit.

 results = model.track(frame, persist=True)
 results_json = json.loads(results[0].tojson())

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn