Maison  >  Article  >  développement back-end  >  Explication détaillée de l'implémentation python de l'algorithme kMeans

Explication détaillée de l'implémentation python de l'algorithme kMeans

小云云
小云云original
2017-12-22 09:03:198526parcourir

Le clustering est une sorte d'apprentissage non supervisé. Placer des objets similaires dans le même cluster est un peu comme une classification entièrement automatique. Plus les objets du cluster sont similaires, plus la différence entre les objets des clusters est grande, et mieux c'est. L'effet de regroupement sera bon. Cet article présente principalement en détail l'implémentation de l'algorithme kMeans en python, qui a une certaine valeur de référence. Les amis intéressés peuvent s'y référer et espérer qu'il pourra aider tout le monde.

1. L'algorithme de clustering k-means

Le clustering k-means divise les données en k clusters, et chaque cluster passe son centroïde, c'est-à-dire le centre du cluster Décrivez le centre de tous les points. Tout d’abord, k points initiaux sont déterminés aléatoirement comme centroïdes, puis l’ensemble de données est attribué au cluster le plus proche. Le centroïde de chaque cluster est ensuite mis à jour pour correspondre à la moyenne de tous les ensembles de données. Divisez ensuite l'ensemble de données une seconde fois jusqu'à ce que les résultats du clustering ne changent plus.

Le pseudo-code est

Créez aléatoirement k centroïdes de cluster
Lorsque l'affectation de cluster d'un point change :
Pour chaque point de l'ensemble de données Données points :
Pour chaque centroïde :
Calculer la distance entre l'ensemble de données et le centroïde
Attribuer l'ensemble de données au cluster correspondant au centroïde le plus proche
Pour chaque cluster, calculer la moyenne de tous les points dans le cluster Et utilisez la moyenne comme centroïde

implémentation de Python


import numpy as np
import matplotlib.pyplot as plt

def loadDataSet(fileName): 
 dataMat = [] 
 with open(fileName) as f:
  for line in f.readlines():
   line = line.strip().split('\t')
   dataMat.append(line)
 dataMat = np.array(dataMat).astype(np.float32)
 return dataMat


def distEclud(vecA,vecB):
 return np.sqrt(np.sum(np.power((vecA-vecB),2)))
def randCent(dataSet,k):
 m = np.shape(dataSet)[1]
 center = np.mat(np.ones((k,m)))
 for i in range(m):
  centmin = min(dataSet[:,i])
  centmax = max(dataSet[:,i])
  center[:,i] = centmin + (centmax - centmin) * np.random.rand(k,1)
 return center
def kMeans(dataSet,k,distMeans = distEclud,createCent = randCent):
 m = np.shape(dataSet)[0]
 clusterAssment = np.mat(np.zeros((m,2)))
 centroids = createCent(dataSet,k)
 clusterChanged = True
 while clusterChanged:
  clusterChanged = False
  for i in range(m):
   minDist = np.inf
   minIndex = -1
   for j in range(k):
    distJI = distMeans(dataSet[i,:],centroids[j,:])
    if distJI < minDist:
     minDist = distJI
     minIndex = j
   if clusterAssment[i,0] != minIndex:
    clusterChanged = True
   clusterAssment[i,:] = minIndex,minDist**2
  for cent in range(k):
   ptsInClust = dataSet[np.nonzero(clusterAssment[:,0].A == cent)[0]]
   centroids[cent,:] = np.mean(ptsInClust,axis = 0)
 return centroids,clusterAssment



data = loadDataSet(&#39;testSet.txt&#39;)
muCentroids, clusterAssing = kMeans(data,4)
fig = plt.figure(0)
ax = fig.add_subplot(111)
ax.scatter(data[:,0],data[:,1],c = clusterAssing[:,0].A)
plt.show()

print(clusterAssing)

2. Algorithme de bisection k moyenne

L'algorithme K-means peut converger vers un minimum local plutôt que vers un minimum global. Une mesure utilisée pour mesurer l’efficacité du clustering est la somme des erreurs quadratiques (SSE). Parce que le carré est pris, l’accent est davantage mis sur le point situé au centre du principe. Afin de surmonter le problème de la convergence de l'algorithme des k-moyennes vers un minimum local, quelqu'un a proposé l'algorithme de bisection des k-means.
Traitez d'abord tous les points comme un cluster, puis divisez le cluster en deux, puis sélectionnez le cluster parmi tous les clusters qui peuvent minimiser la valeur SSE jusqu'à ce que le nombre spécifié de clusters soit atteint.

Pseudocode

Considérer tous les points comme un cluster
Calculer SSE
tandis que Lorsque le nombre de clusters est inférieur à k :
pour chaque cluster :
Calculer l'erreur totale
Effectuer un clustering k-means (k=2) sur un cluster donné
Calculer l'erreur totale de division du cluster en deux
Sélectionner le cluster qui minimise l'erreur Effectuer le partitionnement opération

Implémentation Python


import numpy as np
import matplotlib.pyplot as plt

def loadDataSet(fileName): 
 dataMat = [] 
 with open(fileName) as f:
  for line in f.readlines():
   line = line.strip().split(&#39;\t&#39;)
   dataMat.append(line)
 dataMat = np.array(dataMat).astype(np.float32)
 return dataMat


def distEclud(vecA,vecB):
 return np.sqrt(np.sum(np.power((vecA-vecB),2)))
def randCent(dataSet,k):
 m = np.shape(dataSet)[1]
 center = np.mat(np.ones((k,m)))
 for i in range(m):
  centmin = min(dataSet[:,i])
  centmax = max(dataSet[:,i])
  center[:,i] = centmin + (centmax - centmin) * np.random.rand(k,1)
 return center
def kMeans(dataSet,k,distMeans = distEclud,createCent = randCent):
 m = np.shape(dataSet)[0]
 clusterAssment = np.mat(np.zeros((m,2)))
 centroids = createCent(dataSet,k)
 clusterChanged = True
 while clusterChanged:
  clusterChanged = False
  for i in range(m):
   minDist = np.inf
   minIndex = -1
   for j in range(k):
    distJI = distMeans(dataSet[i,:],centroids[j,:])
    if distJI < minDist:
     minDist = distJI
     minIndex = j
   if clusterAssment[i,0] != minIndex:
    clusterChanged = True
   clusterAssment[i,:] = minIndex,minDist**2
  for cent in range(k):
   ptsInClust = dataSet[np.nonzero(clusterAssment[:,0].A == cent)[0]]
   centroids[cent,:] = np.mean(ptsInClust,axis = 0)
 return centroids,clusterAssment

def biKmeans(dataSet,k,distMeans = distEclud):
 m = np.shape(dataSet)[0]
 clusterAssment = np.mat(np.zeros((m,2)))
 centroid0 = np.mean(dataSet,axis=0).tolist()
 centList = [centroid0]
 for j in range(m):
  clusterAssment[j,1] = distMeans(dataSet[j,:],np.mat(centroid0))**2
 while (len(centList)<k):
  lowestSSE = np.inf
  for i in range(len(centList)):
   ptsInCurrCluster = dataSet[np.nonzero(clusterAssment[:,0].A == i)[0],:]
   centroidMat,splitClustAss = kMeans(ptsInCurrCluster,2,distMeans)
   sseSplit = np.sum(splitClustAss[:,1])
   sseNotSplit = np.sum(clusterAssment[np.nonzero(clusterAssment[:,0].A != i)[0],1])
   if (sseSplit + sseNotSplit) < lowestSSE:
    bestCentToSplit = i
    bestNewCents = centroidMat.copy()
    bestClustAss = splitClustAss.copy()
    lowestSSE = sseSplit + sseNotSplit
  print(&#39;the best cent to split is &#39;,bestCentToSplit)
#  print(&#39;the len of the bestClust&#39;)
  bestClustAss[np.nonzero(bestClustAss[:,0].A == 1)[0],0] = len(centList)
  bestClustAss[np.nonzero(bestClustAss[:,0].A == 0)[0],0] = bestCentToSplit

  clusterAssment[np.nonzero(clusterAssment[:,0].A == bestCentToSplit)[0],:] = bestClustAss.copy()
  centList[bestCentToSplit] = bestNewCents[0,:].tolist()[0]
  centList.append(bestNewCents[1,:].tolist()[0])
 return np.mat(centList),clusterAssment

data = loadDataSet(&#39;testSet2.txt&#39;)
muCentroids, clusterAssing = biKmeans(data,3)
fig = plt.figure(0)
ax = fig.add_subplot(111)
ax.scatter(data[:,0],data[:,1],c = clusterAssing[:,0].A,cmap=plt.cm.Paired)
ax.scatter(muCentroids[:,0],muCentroids[:,1])
plt.show()

print(clusterAssing)
print(muCentroids)

Téléchargement du code et de l'ensemble de données : K-means

Recommandations associées :

Laissez Mahout KMeans analyse de clustering s'exécuter sur Hadoop

cvKMeans2 analyse de clustering moyen + analyse de code + image couleur en niveaux de gris clustering

Exemple détaillé de simple capture d'image de page Web à l'aide de Python

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn