Skip to content

Commit 6168494

Browse files
author
linyiqun
committed
聚类k均值算法实现
聚类k均值算法实现
1 parent bba743d commit 6168494

4 files changed

Lines changed: 243 additions & 0 deletions

File tree

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package DataMining_KMeans;
2+
3+
/**
4+
* K-means(K均值)算法调用类
5+
* @author lyq
6+
*
7+
*/
8+
public class Client {
9+
public static void main(String[] args){
10+
String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
11+
//聚类中心数量设定
12+
int classNum = 3;
13+
14+
KMeansTool tool = new KMeansTool(filePath, classNum);
15+
tool.kMeansClustering();
16+
}
17+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package DataMining_KMeans;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.io.IOException;
7+
import java.text.MessageFormat;
8+
import java.util.ArrayList;
9+
import java.util.Collections;
10+
11+
/**
12+
* k均值算法工具类
13+
*
14+
* @author lyq
15+
*
16+
*/
17+
public class KMeansTool {
18+
// 输入数据文件地址
19+
private String filePath;
20+
// 分类类别个数
21+
private int classNum;
22+
// 类名称
23+
private ArrayList<String> classNames;
24+
// 聚类坐标点
25+
private ArrayList<Point> classPoints;
26+
// 所有的数据左边点
27+
private ArrayList<Point> totalPoints;
28+
29+
public KMeansTool(String filePath, int classNum) {
30+
this.filePath = filePath;
31+
this.classNum = classNum;
32+
readDataFile();
33+
}
34+
35+
/**
36+
* 从文件中读取数据
37+
*/
38+
private void readDataFile() {
39+
File file = new File(filePath);
40+
ArrayList<String[]> dataArray = new ArrayList<String[]>();
41+
42+
try {
43+
BufferedReader in = new BufferedReader(new FileReader(file));
44+
String str;
45+
String[] tempArray;
46+
while ((str = in.readLine()) != null) {
47+
tempArray = str.split(" ");
48+
dataArray.add(tempArray);
49+
}
50+
in.close();
51+
} catch (IOException e) {
52+
e.getStackTrace();
53+
}
54+
55+
classPoints = new ArrayList<>();
56+
totalPoints = new ArrayList<>();
57+
classNames = new ArrayList<>();
58+
for (int i = 0, j = 1; i < dataArray.size(); i++) {
59+
if (j <= classNum) {
60+
classPoints.add(new Point(dataArray.get(i)[0],
61+
dataArray.get(i)[1], j + ""));
62+
classNames.add(i + "");
63+
j++;
64+
}
65+
totalPoints
66+
.add(new Point(dataArray.get(i)[0], dataArray.get(i)[1]));
67+
}
68+
}
69+
70+
/**
71+
* K均值聚类算法实现
72+
*/
73+
public void kMeansClustering() {
74+
double tempX = 0;
75+
double tempY = 0;
76+
int count = 0;
77+
double error = Integer.MAX_VALUE;
78+
Point temp;
79+
80+
while (error > 0.01 * classNum) {
81+
for (Point p1 : totalPoints) {
82+
// 将所有的测试坐标点就近分类
83+
for (Point p2 : classPoints) {
84+
p2.computerDistance(p1);
85+
}
86+
Collections.sort(classPoints);
87+
88+
// 取出p1离类坐标点最近的那个点
89+
p1.setClassName(classPoints.get(0).getClassName());
90+
}
91+
92+
error = 0;
93+
// 按照均值重新划分聚类中心点
94+
for (Point p1 : classPoints) {
95+
count = 0;
96+
tempX = 0;
97+
tempY = 0;
98+
for (Point p : totalPoints) {
99+
if (p.getClassName().equals(p1.getClassName())) {
100+
count++;
101+
tempX += p.getX();
102+
tempY += p.getY();
103+
}
104+
}
105+
tempX /= count;
106+
tempY /= count;
107+
108+
error += Math.abs((tempX - p1.getX()));
109+
error += Math.abs((tempY - p1.getY()));
110+
// 计算均值
111+
p1.setX(tempX);
112+
p1.setY(tempY);
113+
114+
}
115+
116+
for (int i = 0; i < classPoints.size(); i++) {
117+
temp = classPoints.get(i);
118+
System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}",
119+
(i + 1), temp.getX(), temp.getY()));
120+
}
121+
System.out.println("----------");
122+
}
123+
124+
System.out.println("结果值收敛");
125+
for (int i = 0; i < classPoints.size(); i++) {
126+
temp = classPoints.get(i);
127+
System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}",
128+
(i + 1), temp.getX(), temp.getY()));
129+
}
130+
131+
}
132+
133+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package DataMining_KMeans;
2+
3+
/**
4+
* 坐标点类
5+
*
6+
* @author lyq
7+
*
8+
*/
9+
public class Point implements Comparable<Point>{
10+
// 坐标点横坐标
11+
private double x;
12+
// 坐标点纵坐标
13+
private double y;
14+
//以此点作为聚类中心的类的类名称
15+
private String className;
16+
// 坐标点之间的欧式距离
17+
private Double distance;
18+
19+
public Point(double x, double y) {
20+
this.x = x;
21+
this.y = y;
22+
}
23+
24+
public Point(String x, String y) {
25+
this.x = Double.parseDouble(x);
26+
this.y = Double.parseDouble(y);
27+
}
28+
29+
public Point(String x, String y, String className) {
30+
this.x = Double.parseDouble(x);
31+
this.y = Double.parseDouble(y);
32+
this.className = className;
33+
}
34+
35+
/**
36+
* 距离目标点p的欧几里得距离
37+
*
38+
* @param p
39+
*/
40+
public void computerDistance(Point p) {
41+
if (p == null) {
42+
return;
43+
}
44+
45+
this.distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
46+
* (this.y - p.y);
47+
}
48+
49+
public double getX() {
50+
return x;
51+
}
52+
53+
public void setX(double x) {
54+
this.x = x;
55+
}
56+
57+
public double getY() {
58+
return y;
59+
}
60+
61+
public void setY(double y) {
62+
this.y = y;
63+
}
64+
65+
public String getClassName() {
66+
return className;
67+
}
68+
69+
public void setClassName(String className) {
70+
this.className = className;
71+
}
72+
73+
public double getDistance() {
74+
return distance;
75+
}
76+
77+
public void setDistance(double distance) {
78+
this.distance = distance;
79+
}
80+
81+
@Override
82+
public int compareTo(Point o) {
83+
// TODO Auto-generated method stub
84+
return this.distance.compareTo(o.distance);
85+
}
86+
87+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
3 3
2+
4 10
3+
9 6
4+
14 8
5+
18 11
6+
21 7

0 commit comments

Comments
 (0)