検索
ホームページバックエンド開発Python チュートリアルPythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

Oct 27, 2018 pm 02:21 PM
pythonscikit-learn機械学習アルゴリズム

この記事で紹介する内容は、Python での K 最近傍アルゴリズムの原理と実装に関するものです (ソース コードが添付されています)。一定の参考価値があります。必要な友人が参考にしていただければ幸いです。あなたに役立ちます。

k 最近傍アルゴリズムは、異なる特徴値間の距離を測定することによって分類を実行します。

k-最近傍アルゴリズムの原理

ラベル付きのトレーニング サンプル セットの場合、ラベルなしの新しいデータを入力した後、新しいデータの各特徴がサンプルに集中します。データに対応する特徴が比較され、アルゴリズムに従ってサンプル データセット内の最も類似した上位 k 個のデータが選択され、k 個の最も類似したデータの中で最も多く出現した分類が新しいデータの分類として選択されます。 。

k-最近傍アルゴリズムの実装

ここでは 1 つの新しいデータの予測のみを示し、同時に複数の新しいデータを予測する場合は後ほど説明します。

トレーニング サンプル セット X_train (X_train.shape=(10, 2))、対応するマーク y_train (y_train.shape=(10,)、0、1 を含む) があると仮定し、matplotlib を使用します。描画する pyplot 次のように表されます (緑の点はマーク 0 を表し、赤の点はマーク 1 を表します):

PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

新しいデータがあります: x (x = np.array([3.18557125, 6.03119673]))、描画表現は次のとおりです (青い点):

PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

まず、次を使用します。オイラー距離 この式は、x から各サンプルまでの距離を計算します。元のデータの順序に影響します:

import math

distances = [math.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]

3 番目に、k 個の最も近いサンプルに対応するマークを取得します: <pre class="brush:php;toolbar:false">import numpy as np nearest = np.argsort(distances)</pre>最後に、k 個の最も近いサンプルに対応するマークを取得します。統計を作成し、マークの割合が最も大きい予測分類 (x です) を見つけます。この例の予測分類は 0 です。

topK_y = [y_train[i] for i in nearest[:k]]

上記のコードを次のようにカプセル化します。メソッド:

from collections import Counter

votes = Counter(topK_y)
votes.most_common(1)[0][0]

Scikit Learn の K 近傍アルゴリズム

典型的な機械学習アルゴリズムのプロセスでは、トレーニング データ セットを使用してモデルをトレーニング (適合) します。機械学習アルゴリズムを通じて、このモデルを使用して入力サンプルの結果を予測します。

k 最近傍アルゴリズムは、モデルを持たない特殊なアルゴリズムですが、その学習データセットをモデルとみなします。これは、Scikit Learn でどのように処理されるかです。 PythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)

Scikit Learn の k 近傍アルゴリズムは、

Scikit Learn の k 近傍アルゴリズムは、neighbors モジュール内にあります。初期化時のパラメータ n_neighbors は、 6、これは上記です。 k:

import numpy as np
import math

from collections import Counter


def kNN(k, X_train, y_train, x):
    distances = [math.sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
    nearest = np.argsort(distances)

    topK_y = [y_train[i] for i in nearest[:k]]
    votes = Counter(topK_y)
    return votes.most_common(1)[0][0]

fit()

メソッドは、トレーニング データ セットに基づいて分類子を「トレーニング」し、分類子自体を返します。

from sklearn.neighbors import KNeighborsClassifier

kNN_classifier = KNeighborsClassifier(n_neighbors=6)

predict( ) メソッドは入力の結果を予測します。このメソッドでは渡されるパラメーターの型が行列である必要があります。したがって、

reshape

操作は最初に x に対して実行されます。<pre class="brush:php;toolbar:false">kNN_classifier.fit(X_train, y_train)</pre>y_predict 値は 0 であり、以前に実装された kNN メソッドの結果と一致します。

Scikit Learn で KNeighborsClassifier 分類器を実装する

KNNClassifier クラスを定義し、そのコンストラクター メソッドは、予測中に選択された最も類似したデータを表すパラメーター k を渡します。番号:

X_predict = x.reshape(1, -1)
y_predict = kNN_classifier.predict(X_predict)

fit()

このメソッドは分類子をトレーニングし、分類子自体を返します:

class KNNClassifier:
    def __init__(self, k):
        self.k = k
        self._X_train = None
        self._y_train = None

predict() メソッドはデータですテスト対象に設定 予測を行うために、パラメーター X_predict のタイプは行列です。このメソッドは、リスト分析を使用して X_predict を走査し、テスト対象のデータごとに

_predict()

メソッドを 1 回呼び出します。 <pre class="brush:php;toolbar:false">def fit(self, X_train, y_train):     self._X_train = X_train     self._y_train = y_train     return self</pre>アルゴリズムの精度

モデルの問題

モデルはトレーニング サンプル セットを通じてトレーニングされましたが、このモデルがどれほど優れているかはわかりませんが、問題が 2 つあります。

モデルが悪い場合、予測結果は期待どおりではありません。同時に、実際の状況では、実際のラベルを取得してモデルをテストすることは困難です。

  1. モデルをトレーニングする場合、トレーニング サンプルにはすべてのマーカーが含まれているわけではありません。

  2. 最初の質問では、通常、サンプル セット内のデータの一定の割合 (20% など) がテスト データとして使用され、残りのデータがトレーニング データとして使用されます。

  3. Scikit Learn で提供される虹彩データを例に挙げます。これには 150 個のサンプルが含まれています。
def predict(self, X_predict):
    y_predict = [self._predict(x) for x in X_predict]
    return np.array(y_predict)

def _predict(self, x):
    distances = [math.sqrt(np.sum((x_train - x) ** 2))
                 for x_train in self._X_train]
    nearest = np.argsort(distances)

    topK_y = [self._y_train[i] for i in nearest[:self.k]]
    votes = Counter(topK_y)

    return votes.most_common(1)[0][0]

次に、サンプルを 20% のサンプル テスト データと 80% の比率トレーニング データに分割します。

import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train と y_train をモデルをトレーニングするためのトレーニング データとして使用し、X_test と y_test をモデルのテスト データとして使用します。モデルの精度の検証。

2 番目の質問では、Scikit Learn で提供されている虹彩データを例として、y マークの内容は次のとおりです。

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

发现0、1、2是以顺序存储的,在将样本划分为训练数据和测试数据过程中,如果训练数据中才对标记只包含0、1,这样的训练数据对于模型的训练将是致命的。以此,应将样本数据先进行随机处理。

np.random.permutation() 方法传入一个整数 n,会返回一个区间在 [0, n) 且随机排序的一维数组。将 X 的长度作为参数传入,返回 X 索引的随机数组:

shuffle_indexes = np.random.permutation(len(X))

将随机化的索引数组分为训练数据的索引与测试数据的索引两部分:

test_ratio = 0.2
test_size = int(len(X) * test_ratio)

test_indexes = shuffle_indexes[:test_size]
train_indexes = shuffle_indexes[test_size:]

再通过两部分的索引将样本数据分为训练数据和测试数据:

X_train = X[train_indexes]
y_train = y[train_indexes]

X_test = X[test_indexes]
y_test = y[test_indexes]

可以将两个问题的解决方案封装到一个方法中,seed 表示随机数种子,作用在 np.random 中:

import numpy as np

def train_test_split(X, y, test_ratio=0.2, seed=None):
    if seed:
        np.random.seed(seed)
    shuffle_indexes = np.random.permutation(len(X))

    test_size = int(len(X) * test_ratio)
    test_indexes = shuffle_indexes[:test_size]
    train_indexes = shuffle_indexes[test_size:]

    X_train = X[train_indexes]
    y_train = y[train_indexes]

    X_test = X[test_indexes]
    y_test = y[test_indexes]

    return X_train, X_test, y_train, y_test

Scikit Learn 中封装了 train_test_split() 方法,放在了 model_selection 模块中:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

算法正确率

通过 train_test_split() 方法对样本数据进行了预处理后,开始训练模型,并且对测试数据进行验证:

from sklearn.neighbors import KNeighborsClassifier

kNN_classifier = KNeighborsClassifier(n_neighbors=6)
kNN_classifier.fit(X_train, y_train)
y_predict = kNN_classifier.predict(X_test)

y_predict 是对测试数据 X_test 的预测结果,其中与 y_test 相等的个数除以 y_test 的个数就是该模型的正确率,将其和 y_test 进行比较可以算出模型的正确率:

def accuracy_score(y_true, y_predict):
    return sum(y_predict == y_true) / len(y_true)

调用该方法,返回一个小于等于1的浮点数:

accuracy_score(y_test, y_predict)

同样在 Scikit Learn 的 metrics 模块中封装了 accuracy_score() 方法:

from sklearn.metrics import accuracy_score

accuracy_score(y_test, y_predict)

Scikit Learn 中的 KNeighborsClassifier 类的父类 ClassifierMixin 中有一个 score() 方法,里面就调用了 accuracy_score() 方法,将测试数据 X_test 和 y_test 作为参数传入该方法中,可以直接计算出算法正确率。

class ClassifierMixin(object):
    def score(self, X, y, sample_weight=None):
        from .metrics import accuracy_score
        return accuracy_score(y, self.predict(X), sample_weight=sample_weight)

超参数

前文中提到的 k 是一种超参数,超参数是在算法运行前需要决定的参数。 Scikit Learn 中 k-近邻算法包含了许多超参数,在初始化构造函数中都有指定:

def __init__(self, n_neighbors=5,
             weights='uniform', algorithm='auto', leaf_size=30,
             p=2, metric='minkowski', metric_params=None, n_jobs=None,
             **kwargs):
    # code here

这些超参数的含义在源代码和官方文档[scikit-learn.org]中都有说明。

算法优缺点

k-近邻算法是一个比较简单的算法,有其优点但也有缺点。

优点是思想简单,但效果强大, 天然的适合多分类问题。

缺点是效率低下,比如一个训练集有 m 个样本,n 个特征,则预测一个新的数据的算法复杂度为 O(m*n);同时该算法可能产生维数灾难,当维数很大时,两个点之间的距离可能也很大,如 (0,0,0,...,0) 和 (1,1,1,...,1)(10000维)之间的距离为100。

源码地址

Github | ML-Algorithms-Action


以上がPythonによるK最近傍アルゴリズムの原理と実装(ソースコード添付)の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明
この記事はsegmentfaultで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。
どのデータ型をPythonアレイに保存できますか?どのデータ型をPythonアレイに保存できますか?Apr 27, 2025 am 12:11 AM

Pythonlistscanstoreanydatatype,arraymodulearraysstoreonetype,andNumPyarraysarefornumericalcomputations.1)Listsareversatilebutlessmemory-efficient.2)Arraymodulearraysarememory-efficientforhomogeneousdata.3)NumPyarraysareoptimizedforperformanceinscient

Pythonアレイに間違ったデータ型の値を保存しようとするとどうなりますか?Pythonアレイに間違ったデータ型の値を保存しようとするとどうなりますか?Apr 27, 2025 am 12:10 AM

heouttemptemptostoreavure ofthewrongdatatypeinapythonarray、yure counteractypeerror.thisduetothearraymodule'sstricttypeeencultionyを使用します

Python Standard Libraryの一部はどれですか:リストまたは配列はどれですか?Python Standard Libraryの一部はどれですか:リストまたは配列はどれですか?Apr 27, 2025 am 12:03 AM

PythonListSarePartOfThestAndardarenot.liestareBuilting-in、versatile、forStoringCollectionsのpythonlistarepart。

スクリプトが間違ったPythonバージョンで実行されるかどうかを確認する必要がありますか?スクリプトが間違ったPythonバージョンで実行されるかどうかを確認する必要がありますか?Apr 27, 2025 am 12:01 AM

theScriptisrunningwithwrongthonversionduetorectRectDefaultEntertersettings.tofixthis:1)CheckthedededefaultHaulthonsionsingpython - versionorpython3-- version.2)usevirtualenvironmentsbycreatingonewiththon3.9-mvenvmyenv、andverixe

Pythonアレイで実行できる一般的な操作は何ですか?Pythonアレイで実行できる一般的な操作は何ですか?Apr 26, 2025 am 12:22 AM

PythonArraysSupportVariousoperations:1)SlicingExtractsSubsets、2)Appending/ExtendingAdddesements、3)inSertingSelementSatspecificpositions、4)remvingingDeletesements、5)sorting/verversingsorder、and6)listenionsionsionsionsionscreatenewlistsebasedexistin

一般的に使用されているnumpy配列はどのようなアプリケーションにありますか?一般的に使用されているnumpy配列はどのようなアプリケーションにありますか?Apr 26, 2025 am 12:13 AM

numpyarraysAressertialentionsionceivationsefirication-efficientnumericalcomputations andDatamanipulation.theyarecrucialindatascience、mashineelearning、物理学、エンジニアリング、および促進可能性への適用性、scaledatiencyを効率的に、forexample、infinancialanalyyy

Pythonのリスト上の配列を使用するのはいつですか?Pythonのリスト上の配列を使用するのはいつですか?Apr 26, 2025 am 12:12 AM

UseanArray.ArrayOverAlistinPythonは、Performance-criticalCode.1)homogeneousdata:araysavememorywithpedelements.2)Performance-criticalcode:Araysofterbetterbetterfornumerumerumericaleperations.3)interf

すべてのリスト操作は配列でサポートされていますか?なぜまたはなぜですか?すべてのリスト操作は配列でサポートされていますか?なぜまたはなぜですか?Apr 26, 2025 am 12:05 AM

いいえ、notallistoperationSaresuptedbyarrays、andviceversa.1)arraysdonotsupportdynamicoperationslikeappendorintorintorinsertizizing、whosimpactsporformance.2)リスト

See all articles

ホットAIツール

Undresser.AI Undress

Undresser.AI Undress

リアルなヌード写真を作成する AI 搭載アプリ

AI Clothes Remover

AI Clothes Remover

写真から衣服を削除するオンライン AI ツール。

Undress AI Tool

Undress AI Tool

脱衣画像を無料で

Clothoff.io

Clothoff.io

AI衣類リムーバー

Video Face Swap

Video Face Swap

完全無料の AI 顔交換ツールを使用して、あらゆるビデオの顔を簡単に交換できます。

ホットツール

VSCode Windows 64 ビットのダウンロード

VSCode Windows 64 ビットのダウンロード

Microsoft によって発売された無料で強力な IDE エディター

SublimeText3 Linux 新バージョン

SublimeText3 Linux 新バージョン

SublimeText3 Linux 最新バージョン

メモ帳++7.3.1

メモ帳++7.3.1

使いやすく無料のコードエディター

SublimeText3 中国語版

SublimeText3 中国語版

中国語版、とても使いやすい

mPDF

mPDF

mPDF は、UTF-8 でエンコードされた HTML から PDF ファイルを生成できる PHP ライブラリです。オリジナルの作者である Ian Back は、Web サイトから「オンザフライ」で PDF ファイルを出力し、さまざまな言語を処理するために mPDF を作成しました。 HTML2FPDF などのオリジナルのスクリプトよりも遅く、Unicode フォントを使用すると生成されるファイルが大きくなりますが、CSS スタイルなどをサポートし、多くの機能強化が施されています。 RTL (アラビア語とヘブライ語) や CJK (中国語、日本語、韓国語) を含むほぼすべての言語をサポートします。ネストされたブロックレベル要素 (P、DIV など) をサポートします。