22
Object clustering with the k-means algorithm
44
5-
65author: Atsushi Sakai (@Atsushi_twi)
76
87"""
98
9+ import numpy as np
10+ import math
1011import matplotlib .pyplot as plt
1112import random
1213
1314
14- class Cluster :
class Clusters:
    """Container for the points, their labels and the cluster centroids.

    Attributes:
        x, y: the point coordinates, stored as passed in (not copied).
        ndata: number of points, taken from ``len(x)``.
        nlabel: number of clusters.
        labels: one label per point, drawn at random from ``0 .. nlabel - 1``.
        cx, cy: one centroid coordinate per cluster, initialised to 0.0.
    """

    def __init__(self, x, y, nlabel):
        self.x, self.y = x, y
        self.nlabel = nlabel
        self.ndata = len(x)

        # Random initial assignment: one independent draw per point.
        last_label = nlabel - 1
        self.labels = [random.randint(0, last_label)
                       for _ in range(self.ndata)]

        # Centroids start at the origin until they are first computed.
        self.cx = [0.0] * nlabel
        self.cy = [0.0] * nlabel
26+
27+
def init_clusters(rx, ry, nc):
    """Create the initial clustering state for the given points.

    Args:
        rx, ry: point coordinates, forwarded to ``Clusters``.
        nc: number of clusters, forwarded to ``Clusters``.

    Returns:
        A freshly constructed ``Clusters`` instance.
    """
    return Clusters(rx, ry, nc)
33+
34+
def calc_centroid(clusters):
    """Recompute each cluster centroid as the mean of its labeled points.

    Args:
        clusters: object exposing ``x``, ``y``, ``labels``, ``nlabel`` and
            the mutable centroid lists ``cx`` and ``cy``.

    Returns:
        The same ``clusters`` object, with ``cx`` / ``cy`` updated in place.
    """
    for ic in range(clusters.nlabel):
        xs = [px for px, label in zip(clusters.x, clusters.labels)
              if label == ic]
        ys = [py for py, label in zip(clusters.y, clusters.labels)
              if label == ic]

        # A cluster can end up with no points (random initial labels or a
        # reassignment step). Keep its previous centroid instead of
        # dividing by zero.
        if not xs:
            continue

        clusters.cx[ic] = sum(xs) / len(xs)
        clusters.cy[ic] = sum(ys) / len(ys)

    return clusters
1544
16- def __init__ (self ):
17- self .x = []
18- self .y = []
19- self .cx = None
20- self .cy = None
45+
def update_clusters(clusters):
    """Reassign every point to its nearest centroid.

    Args:
        clusters: object exposing ``x``, ``y``, ``cx``, ``cy`` and the
            mutable ``labels`` list.

    Returns:
        A tuple ``(clusters, cost)`` where ``clusters`` is the same object
        with ``labels`` updated in place and ``cost`` is the sum over all
        points of the Euclidean distance to the assigned centroid.
    """
    cost = 0.0

    for ip, (px, py) in enumerate(zip(clusters.x, clusters.y)):
        dlist = [math.hypot(icx - px, icy - py)
                 for icx, icy in zip(clusters.cx, clusters.cy)]
        mind = min(dlist)
        # list.index returns the first match, so ties go to the
        # lowest-numbered cluster.
        clusters.labels[ip] = dlist.index(mind)
        # Accumulate the distance itself, not the label index, so that the
        # cost reflects how well the points fit their centroids.
        cost += mind

    return clusters, cost
2163
2264
def kmean_clustering(rx, ry, nc, max_loop=10, dcost_th=1.0):
    """Cluster the points (rx, ry) into ``nc`` groups.

    Alternates between reassigning labels (``update_clusters``) and
    recomputing centroids (``calc_centroid``).

    Args:
        rx, ry: point coordinates.
        nc: number of clusters.
        max_loop: maximum number of assignment/update iterations.
        dcost_th: stop once the absolute change in the cost returned by
            ``update_clusters`` between two iterations is below this value.

    Returns:
        The clustering state object produced by ``init_clusters`` after the
        final iteration.
    """
    clusters = init_clusters(rx, ry, nc)
    clusters = calc_centroid(clusters)

    # Start from infinity so the first iteration can never satisfy the
    # convergence test by coincidence (a fixed seed such as 100.0 would
    # stop early whenever the first cost happened to land near it).
    pcost = float("inf")
    for loop in range(max_loop):
        print("Loop:", loop)
        clusters, cost = update_clusters(clusters)
        clusters = calc_centroid(clusters)

        if abs(cost - pcost) < dcost_th:
            break
        pcost = cost

    return clusters
3584
@@ -40,17 +89,30 @@ def calc_raw_data():
4089
4190 cx = [0.0 , 5.0 ]
4291 cy = [0.0 , 5.0 ]
43- np = 30
92+ npoints = 30
4493 rand_d = 3.0
4594
4695 for (icx , icy ) in zip (cx , cy ):
47- for _ in range (np ):
96+ for _ in range (npoints ):
4897 rx .append (icx + rand_d * (random .random () - 0.5 ))
4998 ry .append (icy + rand_d * (random .random () - 0.5 ))
5099
51100 return rx , ry
52101
53102
def calc_labeled_points(ic, clusters):
    """Return the coordinates of the points currently labeled ``ic``.

    Args:
        ic: cluster label to select.
        clusters: object exposing ``x``, ``y``, ``labels`` and ``ndata``.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays holding the selected coordinates.
        Both arrays are empty when no point carries the label.
    """
    # Force an integer dtype: an empty index list would otherwise become a
    # float array, which NumPy refuses to use as an index.
    inds = np.array([i for i in range(clusters.ndata)
                     if clusters.labels[i] == ic], dtype=int)

    x = np.array(clusters.x)[inds]
    y = np.array(clusters.y)[inds]

    return x, y
114+
115+
54116def main ():
55117 print (__file__ + " start!!" )
56118
@@ -59,11 +121,10 @@ def main():
59121 ncluster = 2
60122 clusters = kmean_clustering (rx , ry , ncluster )
61123
62- for c in clusters :
63- print (c .cx , c .cy )
64- plt .plot (c .cx , c .cy , "x" )
65-
66- plt .plot (rx , ry , "." )
124+ for ic in range (clusters .nlabel ):
125+ x , y = calc_labeled_points (ic , clusters )
126+ plt .plot (x , y , "x" )
127+ plt .plot (clusters .cx , clusters .cy , "o" )
67128 plt .show ()
68129
69130
0 commit comments