搜尋
首頁後端開發Python教學Python实现的Kmeans++算法实例

1、从Kmeans说起

Kmeans是一个非常基础的聚类算法,使用了迭代的思想,关于其原理这里不说了。下面说一下如何在matlab中使用kmeans算法。

创建7个二维的数据点:

复制代码 代码如下:
x=[randn(3,2)*.4;randn(4,2)*.5+ones(4,1)*[4 4]];
使用kmeans函数:
复制代码 代码如下:
class = kmeans(x, 2);
x是数据点,x的每一行代表一个数据;2指定要有2个中心点,也就是聚类结果要有2个簇。 class将是一个具有70个元素的列向量,这些元素依次对应70个数据点,元素值代表着其对应的数据点所处的分类号。某次运行后,class的值是:
复制代码 代码如下:

 2
 2
 2
 1
 1
 1
 1
这说明x的前三个数据点属于簇2,而后四个数据点属于簇1。 kmeans函数也可以像下面这样使用:
复制代码 代码如下:

>> [class, C, sumd, D] = kmeans(x, 2)

class =
     2
     2
     2
     1
     1
     1
     1

C =
    4.0629    4.0845
   -0.1341    0.1201

sumd =
    1.2017
    0.2939

D =
   34.3727    0.0184
   29.5644    0.1858
   36.3511    0.0898
    0.1247   37.4801
    0.7537   24.0659
    0.1979   36.7666
    0.1256   36.2149

class依旧代表着每个数据点的分类;C包含最终的中心点,一行代表一个中心点;sumd代表着每个中心点与所属簇内各个数据点的距离之和;D的每一行也对应一个数据点,行中的数值依次是该数据点与各个中心点之间的距离,Kmeans默认使用的距离是欧几里得距离(参考资料[3])的平方值。kmeans函数使用的距离,也可以是曼哈顿距离(L1-距离),以及其他类型的距离,可以通过添加参数指定。

kmeans有几个缺点(这在很多资料上都有说明):

1、最终簇的类别数目(即中心点或者说种子点的数目)k并不一定能事先知道,所以如何选一个合适的k的值是一个问题。
2、最开始的种子点的选择的好坏会影响到聚类结果。
3、对噪声和离群点敏感。
4、等等。

2、kmeans++算法的基本思路

kmeans++算法的主要工作体现在种子点的选择上,基本原则是使得各个种子点之间的距离尽可能的大,但是又得排除噪声的影响。 以下为基本思路:

1、从输入的数据点集合(要求有k个聚类)中随机选择一个点作为第一个聚类中心
2、对于数据集中的每一个点x,计算它与最近聚类中心(指已选择的聚类中心)的距离D(x)
3、选择一个新的数据点作为新的聚类中心,选择的原则是:D(x)较大的点,被选取作为聚类中心的概率较大
4、重复2和3直到k个聚类中心被选出来
5、利用这k个初始的聚类中心来运行标准的k-means算法

假定数据点集合X有n个数据点,依次用X(1)、X(2)、……、X(n)表示,那么,在第2步中依次计算每个数据点与最近的种子点(聚类中心)的距离,依次得到D(1)、D(2)、……、D(n)构成的集合D。在D中,为了避免噪声,不能直接选取值最大的元素,应该选择值较大的元素,然后将其对应的数据点作为种子点。

如何选择值较大的元素呢,下面是一种思路(暂未找到最初的来源,在资料[2]等地方均有提及,笔者换了一种让自己更好理解的说法): 把集合D中的每个元素D(x)想象为一根线L(x),线的长度就是元素的值。将这些线依次按照L(1)、L(2)、……、L(n)的顺序连接起来,组成长线L。L(1)、L(2)、……、L(n)称为L的子线。根据概率的相关知识,如果我们在L上随机选择一个点,那么这个点所在的子线很有可能是比较长的子线,而这个子线对应的数据点就可以作为种子点。下文中kmeans++的两种实现均是这个原理。

3、python版本的kmeans++

在http://rosettacode.org/wiki/K-means%2B%2B_clustering 中能找到多种编程语言版本的Kmeans++实现。下面的内容是基于python的实现(中文注释是笔者添加的):

复制代码 代码如下:

from math import pi, sin, cos
from collections import namedtuple
from random import random, choice
from copy import copy

try:
    import psyco
    psyco.full()
except ImportError:
    pass

FLOAT_MAX = 1e100

class Point:
    __slots__ = ["x", "y", "group"]
    def __init__(self, x=0.0, y=0.0, group=0):
        self.x, self.y, self.group = x, y, group

def generate_points(npoints, radius):
    points = [Point() for _ in xrange(npoints)]

    # note: this is not a uniform 2-d distribution
    for p in points:
        r = random() * radius
        ang = random() * 2 * pi
        p.x = r * cos(ang)
        p.y = r * sin(ang)

    return points

def nearest_cluster_center(point, cluster_centers):
    """Distance and index of the closest cluster center"""
    def sqr_distance_2D(a, b):
        return (a.x - b.x) ** 2  +  (a.y - b.y) ** 2

    min_index = point.group
    min_dist = FLOAT_MAX

    for i, cc in enumerate(cluster_centers):
        d = sqr_distance_2D(cc, point)
        if min_dist > d:
            min_dist = d
            min_index = i

    return (min_index, min_dist)

'''
points是数据点,nclusters是给定的簇类数目
cluster_centers包含初始化的nclusters个中心点,开始都是对象->(0,0,0)
'''

def kpp(points, cluster_centers):
    cluster_centers[0] = copy(choice(points)) #随机选取第一个中心点
    d = [0.0 for _ in xrange(len(points))]  #列表,长度为len(points),保存每个点离最近的中心点的距离

    for i in xrange(1, len(cluster_centers)):  # i=1...len(c_c)-1
        sum = 0
        for j, p in enumerate(points):
            d[j] = nearest_cluster_center(p, cluster_centers[:i])[1] #第j个数据点p与各个中心点距离的最小值
            sum += d[j]

        sum *= random()

        for j, di in enumerate(d):
            sum -= di
            if sum > 0:
                continue
            cluster_centers[i] = copy(points[j])
            break

    for p in points:
        p.group = nearest_cluster_center(p, cluster_centers)[0]

'''
points是数据点,nclusters是给定的簇类数目
'''
def lloyd(points, nclusters):
    cluster_centers = [Point() for _ in xrange(nclusters)]  #根据指定的中心点个数,初始化中心点,均为(0,0,0)

    # call k++ init
    kpp(points, cluster_centers)   #选择初始种子点

    # 下面是kmeans
    lenpts10 = len(points) >> 10

    changed = 0
    while True:
        # group element for centroids are used as counters
        for cc in cluster_centers:
            cc.x = 0
            cc.y = 0
            cc.group = 0

        for p in points:
            cluster_centers[p.group].group += 1  #与该种子点在同一簇的数据点的个数
            cluster_centers[p.group].x += p.x
            cluster_centers[p.group].y += p.y

        for cc in cluster_centers:    #生成新的中心点
            cc.x /= cc.group
            cc.y /= cc.group

        # find closest centroid of each PointPtr
        changed = 0  #记录所属簇发生变化的数据点的个数
        for p in points:
            min_i = nearest_cluster_center(p, cluster_centers)[0]
            if min_i != p.group:
                changed += 1
                p.group = min_i

        # stop when 99.9% of points are good
        if changed             break

    for i, cc in enumerate(cluster_centers):
        cc.group = i

    return cluster_centers

def print_eps(points, cluster_centers, W=400, H=400):
    Color = namedtuple("Color", "r g b");

    colors = []
    for i in xrange(len(cluster_centers)):
        colors.append(Color((3 * (i + 1) % 11) / 11.0,
                            (7 * i % 11) / 11.0,
                            (9 * i % 11) / 11.0))

    max_x = max_y = -FLOAT_MAX
    min_x = min_y = FLOAT_MAX

    for p in points:
        if max_x         if min_x > p.x: min_x = p.x
        if max_y         if min_y > p.y: min_y = p.y

    scale = min(W / (max_x - min_x),
                H / (max_y - min_y))
    cx = (max_x + min_x) / 2
    cy = (max_y + min_y) / 2

    print "%%!PS-Adobe-3.0\n%%%%BoundingBox: -5 -5 %d %d" % (W + 10, H + 10)

    print ("/l {rlineto} def /m {rmoveto} def\n" +
           "/c { .25 sub exch .25 sub exch .5 0 360 arc fill } def\n" +
           "/s { moveto -2 0 m 2 2 l 2 -2 l -2 -2 l closepath " +
           "   gsave 1 setgray fill grestore gsave 3 setlinewidth" +
           " 1 setgray stroke grestore 0 setgray stroke }def")

    for i, cc in enumerate(cluster_centers):
        print ("%g %g %g setrgbcolor" %
               (colors[i].r, colors[i].g, colors[i].b))

        for p in points:
            if p.group != i:
                continue
            print ("%.3f %.3f c" % ((p.x - cx) * scale + W / 2,
                                    (p.y - cy) * scale + H / 2))

        print ("\n0 setgray %g %g s" % ((cc.x - cx) * scale + W / 2,
                                        (cc.y - cy) * scale + H / 2))

    print "\n%%%%EOF"

def main():
    npoints = 30000
    k = 7 # # clusters

    points = generate_points(npoints, 10)
    cluster_centers = lloyd(points, k)
    print_eps(points, cluster_centers)

main()

上述代码实现的算法是针对二维数据的,所以Point对象有三个属性,分别是在x轴上的值、在y轴上的值、以及所属的簇的标识。函数lloyd是kmeans++算法的整体实现,其先是通过kpp函数选取合适的种子点,然后对数据集实行kmeans算法进行聚类。kpp函数的实现完全符合上述kmeans++的基本思路的2、3、4步。


4、matlab版本的kmeans++

复制代码 代码如下:

function [L,C] = kmeanspp(X,k)
%KMEANS Cluster multivariate data using the k-means++ algorithm.
%   [L,C] = kmeans_pp(X,k) produces a 1-by-size(X,2) vector L with one class
%   label per column in X and a size(X,1)-by-k matrix C containing the
%   centers corresponding to each class.

%   Version: 2013-02-08
%   Authors: Laurent Sorber (Laurent.Sorber@cs.kuleuven.be)

L = [];
L1 = 0;

while length(unique(L)) ~= k

    % The k-means++ initialization.
    C = X(:,1+round(rand*(size(X,2)-1))); %size(X,2)是数据集合X的数据点的数目,C是中心点的集合
    L = ones(1,size(X,2));
    for i = 2:k
        D = X-C(:,L); %-1
        D = cumsum(sqrt(dot(D,D,1))); %将每个数据点与中心点的距离,依次累加
        if D(end) == 0, C(:,i:k) = X(:,ones(1,k-i+1)); return; end
        C(:,i) = X(:,find(rand         [~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).')); %碉堡了,这句,将每个数据点进行分类。
    end

    % The k-means algorithm.
    while any(L ~= L1)
        L1 = L;
        for i = 1:k, l = L==i; C(:,i) = sum(X(:,l),2)/sum(l); end
        [~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).'),[],1);
    end

end

这个函数的实现有些特殊,参数X是数据集,但是是将每一列看做一个数据点,参数k是指定的聚类数。返回值L标记了每个数据点的所属分类,返回值C保存了最终形成的中心点(一列代表一个中心点)。测试一下:

复制代码 代码如下:

>> x=[randn(3,2)*.4;randn(4,2)*.5+ones(4,1)*[4 4]]
x =
   -0.0497    0.5669
    0.5959    0.2686
    0.5636   -0.4830
    4.3586    4.3634
    4.8151    3.8483
    4.2444    4.1469
    4.5173    3.6064

>> [L, C] = kmeanspp(x',2)
L =
     2     2     2     1     1     1     1
C =
    4.4839    0.3699
    3.9913    0.1175

好了,现在开始一点点理解这个实现,顺便巩固一下matlab知识。

unique函数用来获取一个矩阵中的不同的值,示例:

复制代码 代码如下:

>> unique([1 3 3 4 4 5])
ans =
     1     3     4     5
>> unique([1 3 3 ; 4 4 5])
ans =
     1
     3
     4
     5
所以循环 while length(unique(L)) ~= k 以得到了k个聚类为结束条件,不过一般情况下,这个循环一次就结束了,因为重点在这个循环中。

rand是返回在(0,1)这个区间的一个随机数。在注释%-1所在行,C被扩充了,被扩充的方法类似于下面:

复制代码 代码如下:

>> C =[];
>> C(1,1) = 1
C =
     1
>> C(2,1) = 2
C =
     1
     2
>> C(:,[1 1 1 1])
ans =
     1     1     1     1
     2     2     2     2
>> C(:,[1 1 1 1 2])
Index exceeds matrix dimensions.


C中第二个参数的元素1,其实是代表C的第一列数据,之所以在值2时候出现Index exceeds matrix dimensions.的错误,是因为C本身没有第二列。如果C有第二列了:

复制代码 代码如下:

>> C(2,2) = 3;
>> C(2,2) = 4;
>> C(:,[1 1 1 1 2])
ans =
     1     1     1     1     3
     2     2     2     2     4
dot函数是将两个矩阵点乘,然后把结果在某一维度相加:
复制代码 代码如下:

>> TT = [1 2 3 ; 4 5 6];
>> dot(TT,TT)
ans =
    17    29    45
>> dot(TT,TT,1 )
ans =
    17    29    45
cumsum是累加函数:
复制代码 代码如下:

>> cumsum([1 2 3])
ans =
     1     3     6
>> cumsum([1 2 3; 4 5 6])
ans =
     1     2     3
     5     7     9
max函数可以返回两个值,第二个代表的是max数的索引位置:
复制代码 代码如下:

>> [~, L] = max([1 2 3])
L =
     3
>> [~,L] = max([1 2 3;2 3 4])
L =
     2     2     2
其中~是占位符。

关于bsxfun函数,官方文档指出:

复制代码 代码如下:

C = bsxfun(fun,A,B) applies the element-by-element binary operation specified by the function handle fun to arrays A and B, with singleton expansion enabled
其中参数fun是函数句柄,关于函数句柄见资料[9]。下面是bsxfun的一个示例:
复制代码 代码如下:
>> A= [1 2 3;2 3 4]
A =
     1     2     3
     2     3     4
>> B=[6;7]
B =
     6
     7
>> bsxfun(@minus,A,B)
ans =
    -5    -4    -3
    -5    -4    -3
对于:
复制代码 代码如下:

[~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).'));
max的参数是这样一个矩阵,矩阵有n列,n也是数据点的个数,每一列代表着对应的数据点与各个中心点之间的距离的相反数。不过这个距离有些与众不同,算是欧几里得距离的变形。

假定数据点是2维的,某个数据点为(x1,y1),某个中心点为(c1,d1),那么通过bsxfun(@minus,2real(C'X),dot(C,C,1).')的计算,数据点与中心点的距离为2c1x1 + 2d1y1 -c1.^2 - c2.^2,可以变换为x1.^2 + y1.^2 - (c1-x1).^2 - (d1-y1).^2。对于每一列而言,由于是数据点与各个中心点之间的计算,所以可以忽略x1.^2 + y1.^2,最终计算结果是欧几里得距离的平方的相反数。这也说明了使用max的合理性,因为一个数据点的所属簇取决于与其距离最近的中心点,若将距离取相反数,则应该是值最大的那个点。

陳述
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
Python和時間:充分利用您的學習時間Python和時間:充分利用您的學習時間Apr 14, 2025 am 12:02 AM

要在有限的時間內最大化學習Python的效率,可以使用Python的datetime、time和schedule模塊。 1.datetime模塊用於記錄和規劃學習時間。 2.time模塊幫助設置學習和休息時間。 3.schedule模塊自動化安排每週學習任務。

Python:遊戲,Guis等Python:遊戲,Guis等Apr 13, 2025 am 12:14 AM

Python在遊戲和GUI開發中表現出色。 1)遊戲開發使用Pygame,提供繪圖、音頻等功能,適合創建2D遊戲。 2)GUI開發可選擇Tkinter或PyQt,Tkinter簡單易用,PyQt功能豐富,適合專業開發。

Python vs.C:申請和用例Python vs.C:申請和用例Apr 12, 2025 am 12:01 AM

Python适合数据科学、Web开发和自动化任务,而C 适用于系统编程、游戏开发和嵌入式系统。Python以简洁和强大的生态系统著称,C 则以高性能和底层控制能力闻名。

2小時的Python計劃:一種現實的方法2小時的Python計劃:一種現實的方法Apr 11, 2025 am 12:04 AM

2小時內可以學會Python的基本編程概念和技能。 1.學習變量和數據類型,2.掌握控制流(條件語句和循環),3.理解函數的定義和使用,4.通過簡單示例和代碼片段快速上手Python編程。

Python:探索其主要應用程序Python:探索其主要應用程序Apr 10, 2025 am 09:41 AM

Python在web開發、數據科學、機器學習、自動化和腳本編寫等領域有廣泛應用。 1)在web開發中,Django和Flask框架簡化了開發過程。 2)數據科學和機器學習領域,NumPy、Pandas、Scikit-learn和TensorFlow庫提供了強大支持。 3)自動化和腳本編寫方面,Python適用於自動化測試和系統管理等任務。

您可以在2小時內學到多少python?您可以在2小時內學到多少python?Apr 09, 2025 pm 04:33 PM

兩小時內可以學到Python的基礎知識。 1.學習變量和數據類型,2.掌握控制結構如if語句和循環,3.了解函數的定義和使用。這些將幫助你開始編寫簡單的Python程序。

如何在10小時內通過項目和問題驅動的方式教計算機小白編程基礎?如何在10小時內通過項目和問題驅動的方式教計算機小白編程基礎?Apr 02, 2025 am 07:18 AM

如何在10小時內教計算機小白編程基礎?如果你只有10個小時來教計算機小白一些編程知識,你會選擇教些什麼�...

如何在使用 Fiddler Everywhere 進行中間人讀取時避免被瀏覽器檢測到?如何在使用 Fiddler Everywhere 進行中間人讀取時避免被瀏覽器檢測到?Apr 02, 2025 am 07:15 AM

使用FiddlerEverywhere進行中間人讀取時如何避免被檢測到當你使用FiddlerEverywhere...

See all articles

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

AI Hentai Generator

AI Hentai Generator

免費產生 AI 無盡。

熱門文章

R.E.P.O.能量晶體解釋及其做什麼(黃色晶體)
4 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳圖形設置
3 週前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您聽不到任何人,如何修復音頻
4 週前By尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解鎖Myrise中的所有內容
1 個月前By尊渡假赌尊渡假赌尊渡假赌

熱工具

SublimeText3 英文版

SublimeText3 英文版

推薦:為Win版本,支援程式碼提示!

Safe Exam Browser

Safe Exam Browser

Safe Exam Browser是一個安全的瀏覽器環境,安全地進行線上考試。該軟體將任何電腦變成一個安全的工作站。它控制對任何實用工具的訪問,並防止學生使用未經授權的資源。

ZendStudio 13.5.1 Mac

ZendStudio 13.5.1 Mac

強大的PHP整合開發環境

Dreamweaver Mac版

Dreamweaver Mac版

視覺化網頁開發工具

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具