ホームページ  >  記事  >  バックエンド開発  >  ROBOFLOW - Python を使用したトレーニングとテスト

ROBOFLOW - Python を使用したトレーニングとテスト

WBOY
WBOYオリジナル
2024-08-27 06:01:32708ブラウズ

Roboflow は、物体検出 AI で使用する画像に注釈を付けるためのプラットフォームです。

私はこのプラットフォームを C2SMR c2smr.fr (海難救助のためのコンピューター ビジョン協会) に使用しています。

この記事では、このプラットフォームを使用し、Python でモデルをトレーニングする方法を説明します。

私の github でさらにサンプル コードを見つけることができます: https://github.com/C2SMR/detector


I - データセット

データセットを作成するには、https://app.roboflow.com/ にアクセスし、次の画像に示すように画像に注釈を付け始めます。

この例では、将来の画像内での水泳選手の位置を予測するために、すべての水泳選手を迂回させます。
良好な結果を得るには、すべてのスイマーをトリミングし、オブジェクトの直後に境界ボックスを配置してオブジェクトを正しく囲みます。

ROBOFLOW - train & test with python

パブリック roboflow データセットはすでに使用できます。このためには https://universe.roboflow.com/

を確認してください。

II - トレーニング

トレーニング段階では roboflow を直接使用できますが、3 回目以降は料金が発生するため、ラップトップで行う方法を説明します。

最初のステップは、データセットをインポートすることです。これを行うには、Roboflow ライブラリをインポートします。

pip install roboflow

モデルを作成するには、Ultralytics ライブラリでインポートできる YOLO アルゴリズムを使用する必要があります。

pip install ultralytics

私のスクリプトでは、次のコマンドを使用します:

py train.py api-key project-workspace project-name project-version nb-epoch size_model

次のものを入手する必要があります:

  • アクセスキー
  • ワークスペース
  • roboflow プロジェクト名
  • プロジェクト データセットのバージョン
  • モデルをトレーニングするエポック数
  • ニューラルネットワークのサイズ

最初に、スクリプトはトレーニングを容易にするために、ワークアウト前のデータを含むデフォルトの yolo ウェイトである yolov8-obb​​.pt をダウンロードします。

import sys
import os
import random
from roboflow import Roboflow
from ultralytics import YOLO
import yaml
import time


class Main:
    rf: Roboflow
    project: object
    dataset: object
    model: object
    results: object
    model_size: str

    def __init__(self):
        self.model_size = sys.argv[6]
        self.import_dataset()
        self.train()

    def import_dataset(self):
        self.rf = Roboflow(api_key=sys.argv[1])
        self.project = self.rf.workspace(sys.argv[2]).project(sys.argv[3])
        self.dataset = self.project.version(sys.argv[4]).download("yolov8-obb")

        with open(f'{self.dataset.location}/data.yaml', 'r') as file:
            data = yaml.safe_load(file)

        data['path'] = self.dataset.location

        with open(f'{self.dataset.location}/data.yaml', 'w') as file:
            yaml.dump(data, file, sort_keys=False)

    def train(self):
        list_of_models = ["n", "s", "m", "l", "x"]
        if self.model_size != "ALL" and self.model_size in list_of_models:

            self.model = YOLO(f"yolov8{self.model_size}-obb.pt")

            self.results = self.model.train(data=f"{self.dataset.location}/"
                                                 f"yolov8-obb.yaml",
                                            epochs=int(sys.argv[5]), imgsz=640)



        elif self.model_size == "ALL":
            for model_size in list_of_models:
                self.model = YOLO(f"yolov8{model_size}.pt")

                self.results = self.model.train(data=f"{self.dataset.location}"
                                                     f"/yolov8-obb.yaml",
                                                epochs=int(sys.argv[5]),
                                                imgsz=640)



        else:
            print("Invalid model size")



if __name__ == '__main__':
    Main()

III - ディスプレイ

モデルをトレーニングした後、重みに対応するファイル best.py と last.py を取得します。

ultralytics ライブラリを使用すると、YOLO をインポートし、体重をロードしてからテストビデオをロードすることもできます。
この例では、追跡機能を使用して各水泳選手の ID を取得しています。

import cv2
from ultralytics import YOLO
import sys


def main():
    cap = cv2.VideoCapture(sys.argv[1])

    model = YOLO(sys.argv[2])

    while True:
        ret, frame = cap.read()
        results = model.track(frame, persist=True)
        res_plotted = results[0].plot()
        cv2.imshow("frame", res_plotted)

        if cv2.waitKey(1) == 27:
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()

予測を分析するには、次のようにモデル json を取得できます。

 results = model.track(frame, persist=True)
 results_json = json.loads(results[0].tojson())

以上がROBOFLOW - Python を使用したトレーニングとテストの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。