Maison >développement back-end >Tutoriel Python >ROBOFLOW - entraîner et tester avec python
Roboflow est une plate-forme d'annotation d'images à utiliser dans l'IA de détection d'objets.
J'utilise cette plateforme pour le C2SMR c2smr.fr, mon association de vision par ordinateur pour le sauvetage maritime.
Dans cet article, je vous montre comment utiliser cette plateforme et entraîner votre modèle avec python.
Vous pouvez trouver plus d'exemples de code sur mon github : https://github.com/C2SMR/detector
Pour créer votre ensemble de données, accédez à https://app.roboflow.com/ et commencez à annoter votre image comme indiqué dans l'image suivante.
Dans cet exemple, je détourne tous les nageurs pour prédire leur position dans les futures images.
Pour obtenir un bon résultat, recadrez tous les nageurs et placez le cadre de sélection juste après l'objet pour l'entourer correctement.
Vous pouvez déjà utiliser un ensemble de données Roboflow public, pour cette vérification https://universe.roboflow.com/
Pour la phase de formation, vous pouvez utiliser roboflow directement, mais la troisième fois vous devrez payer, c'est pourquoi je vous montre comment le faire avec votre ordinateur portable.
La première étape consiste à importer votre ensemble de données. Pour ce faire, vous pouvez importer la bibliothèque Roboflow.
pip install roboflow
Pour créer un modèle, vous devez utiliser l'algorithme YOLO, que vous pouvez importer avec la bibliothèque ultralytics.
pip install ultralytics
Dans mon script, j'utilise la commande suivante :
py train.py api-key project-workspace project-name project-version nb-epoch size_model
Vous devez obtenir :
Initialement, le script télécharge yolov8-obb.pt, le poids yolo par défaut avec les données pré-entraînement, pour faciliter l'entraînement.
import sys import os import random from roboflow import Roboflow from ultralytics import YOLO import yaml import time class Main: rf: Roboflow project: object dataset: object model: object results: object model_size: str def __init__(self): self.model_size = sys.argv[6] self.import_dataset() self.train() def import_dataset(self): self.rf = Roboflow(api_key=sys.argv[1]) self.project = self.rf.workspace(sys.argv[2]).project(sys.argv[3]) self.dataset = self.project.version(sys.argv[4]).download("yolov8-obb") with open(f'{self.dataset.location}/data.yaml', 'r') as file: data = yaml.safe_load(file) data['path'] = self.dataset.location with open(f'{self.dataset.location}/data.yaml', 'w') as file: yaml.dump(data, file, sort_keys=False) def train(self): list_of_models = ["n", "s", "m", "l", "x"] if self.model_size != "ALL" and self.model_size in list_of_models: self.model = YOLO(f"yolov8{self.model_size}-obb.pt") self.results = self.model.train(data=f"{self.dataset.location}/" f"yolov8-obb.yaml", epochs=int(sys.argv[5]), imgsz=640) elif self.model_size == "ALL": for model_size in list_of_models: self.model = YOLO(f"yolov8{model_size}.pt") self.results = self.model.train(data=f"{self.dataset.location}" f"/yolov8-obb.yaml", epochs=int(sys.argv[5]), imgsz=640) else: print("Invalid model size") if __name__ == '__main__': Main()
Après avoir entraîné le modèle, vous obtenez les fichiers best.py et last.py, qui correspondent au poids.
Avec la bibliothèque ultralytics, vous pouvez également importer YOLO et charger votre poids puis votre vidéo de test.
Dans cet exemple, j'utilise la fonction de suivi pour obtenir un identifiant pour chaque nageur.
import cv2 from ultralytics import YOLO import sys def main(): cap = cv2.VideoCapture(sys.argv[1]) model = YOLO(sys.argv[2]) while True: ret, frame = cap.read() results = model.track(frame, persist=True) res_plotted = results[0].plot() cv2.imshow("frame", res_plotted) if cv2.waitKey(1) == 27: break cap.release() cv2.destroyAllWindows() if __name__ == "__main__": main()
Pour analyser la prédiction, vous pouvez obtenir le modèle json comme suit.
results = model.track(frame, persist=True) results_json = json.loads(results[0].tojson())
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!