K均值聚类及其Java实现

import java.util.ArrayList;
import java.util.Random;

/**
 * K-means clustering over a small built-in data set of float vectors.
 *
 * <p>Each iteration assigns every point to its nearest center (Euclidean distance), records the
 * sum of squared errors (SSE) for that assignment, and then moves every non-empty cluster's center
 * to the mean of its members. The run stops once the SSE stops decreasing.
 */
public class Kmeans {

    private int k;                                 // number of clusters
    private int m;                                 // iteration counter; also the index of the latest SSE entry
    private int dataSetLength;                     // number of points in the data set
    private ArrayList<float[]> dataSet;            // points to cluster
    private ArrayList<float[]> center;             // current centers; center.get(i) belongs to cluster.get(i)
    private ArrayList<ArrayList<float[]>> cluster; // points assigned to each center in the current iteration
    private ArrayList<Float> jc;                   // SSE recorded once per iteration
    private Random random;                         // used for initial center selection and tie-breaking

    /**
     * Creates a K-means runner with an unseeded random generator.
     *
     * @param k number of clusters; must be between 1 and the data-set size (checked by {@link #initCenters()})
     */
    public Kmeans(int k) {
        this(k, new Random());
    }

    /**
     * Creates a K-means runner whose random choices (initial centers, tie-breaking) are reproducible.
     *
     * @param k    number of clusters; must be between 1 and the data-set size
     * @param seed seed for the internal {@link Random}
     */
    public Kmeans(int k, long seed) {
        this(k, new Random(seed));
    }

    private Kmeans(int k, Random random) {
        this.k = k;
        // Created here (not only in init) so minDistance is usable before init() is called.
        this.random = random;
    }

    /**
     * Resets all state: loads the data set, picks initial centers, allocates empty clusters and
     * clears the SSE history.
     *
     * @throws IllegalArgumentException if k is not between 1 and the data-set size
     */
    public void init() {
        m = 0;
        dataSet = initDataSet();
        dataSetLength = dataSet.size();
        center = initCenters();
        cluster = initCluster();
        jc = new ArrayList<Float>();
    }

    /**
     * Builds the hard-coded 2-D data set used by this demo.
     *
     * @return a new list of 15 points
     */
    public ArrayList<float[]> initDataSet() {
        float[][] dataSetArray = new float[][]{
            {8, 2}, {3, 4}, {2, 5}, {4, 2}, {7, 3},
            {6, 2}, {4, 7}, {6, 3}, {5, 3}, {6, 3},
            {6, 9}, {1, 6}, {3, 9}, {4, 1}, {8, 6}
        };
        ArrayList<float[]> points = new ArrayList<float[]>(dataSetArray.length);
        for (float[] point : dataSetArray) {
            points.add(point);
        }
        return points;
    }

    /**
     * Picks k distinct data-set positions uniformly at random and uses those points as the
     * initial centers. Requires {@code dataSet} and {@code dataSetLength} to be set.
     *
     * @return a new list of k centers (the point arrays are shared with the data set, never mutated)
     * @throws IllegalArgumentException if k is not between 1 and the data-set size; the former
     *         rejection-sampling loop spun forever when k exceeded the data-set size
     */
    public ArrayList<float[]> initCenters() {
        if (k < 1 || k > dataSetLength) {
            throw new IllegalArgumentException(
                    "k must be between 1 and the data set size " + dataSetLength + ", got " + k);
        }
        int[] indices = new int[dataSetLength];
        for (int i = 0; i < dataSetLength; i++) {
            indices[i] = i;
        }
        ArrayList<float[]> centers = new ArrayList<float[]>(k);
        // Partial Fisher-Yates shuffle: each step draws a not-yet-used position, so the
        // k positions are distinct without any retry loop.
        for (int i = 0; i < k; i++) {
            int j = i + random.nextInt(dataSetLength - i);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
            centers.add(dataSet.get(indices[i]));
        }
        return centers;
    }

    /**
     * Allocates k empty clusters.
     *
     * @return a new list holding k empty point lists
     */
    public ArrayList<ArrayList<float[]>> initCluster() {
        ArrayList<ArrayList<float[]>> clusters = new ArrayList<ArrayList<float[]>>(k);
        for (int i = 0; i < k; i++) {
            clusters.add(new ArrayList<float[]>());
        }
        return clusters;
    }

    /**
     * Euclidean distance between a point and a center.
     *
     * @param element the point
     * @param center  the center; must have at least {@code element.length} components
     * @return the square root of {@link #errorSquare(float[], float[])}
     */
    public float distance(float[] element, float[] center) {
        return (float) Math.sqrt(errorSquare(element, center));
    }

    /**
     * Returns the index of the smallest value. When several positions share the minimum, one of
     * them is chosen uniformly at random.
     *
     * @param distance candidate distances; must be non-empty
     * @return index of a minimal entry
     */
    public int minDistance(float[] distance) {
        float minDistance = distance[0];
        int minLocation = 0;
        int ties = 1; // how many positions currently share minDistance
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] < minDistance) {
                minDistance = distance[i];
                minLocation = i;
                ties = 1;
            } else if (distance[i] == minDistance) {
                // Reservoir sampling: keep the newest tie with probability 1/ties so every tied
                // position is equally likely (a fixed 50% chance favoured later indices).
                ties++;
                if (random.nextInt(ties) == 0) {
                    minLocation = i;
                }
            }
        }
        return minLocation;
    }

    /**
     * Assigns every data-set point to the cluster of its nearest center, appending it to
     * {@code cluster}. Expects {@code cluster} to start empty.
     */
    public void clusterSet() {
        float[] distances = new float[k];
        for (float[] point : dataSet) {
            for (int j = 0; j < k; j++) {
                distances[j] = distance(point, center.get(j));
            }
            cluster.get(minDistance(distances)).add(point);
        }
    }

    /**
     * Squared Euclidean distance between a point and a center.
     *
     * @param element the point
     * @param center  the center; must have at least {@code element.length} components
     * @return the sum over all components of (element - center)^2
     */
    public float errorSquare(float[] element, float[] center) {
        float sum = 0.0f;
        for (int d = 0; d < element.length; d++) {
            float diff = element[d] - center[d];
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Computes the sum of squared errors of the current assignment against the current centers
     * and appends it to the SSE history.
     */
    public void countRule() {
        float total = 0;
        for (int i = 0; i < cluster.size(); i++) {
            float[] c = center.get(i);
            for (float[] point : cluster.get(i)) {
                total += errorSquare(point, c);
            }
        }
        jc.add(total);
    }

    /**
     * Moves each non-empty cluster's center to the mean of its members. The center of an empty
     * cluster is left unchanged.
     */
    public void findNewCenter() {
        for (int i = 0; i < k; i++) {
            ArrayList<float[]> members = cluster.get(i);
            int n = members.size();
            if (n == 0) {
                continue;
            }
            float[] newCenter = new float[members.get(0).length];
            for (float[] point : members) {
                for (int d = 0; d < newCenter.length; d++) {
                    newCenter[d] += point[d];
                }
            }
            for (int d = 0; d < newCenter.length; d++) {
                newCenter[d] /= n;
            }
            // A fresh array is stored so the data-set points shared with the initial centers are never mutated.
            center.set(i, newCenter);
        }
    }

    /**
     * Prints each point as {@code print: name[i] = {a, b, ...}}, followed by a blank line.
     *
     * @param dataArray      points to print
     * @param dataArrayName  label used in each line
     */
    public void printDataArray(ArrayList<float[]> dataArray, String dataArrayName) {
        for (int i = 0; i < dataArray.size(); i++) {
            float[] point = dataArray.get(i);
            StringBuilder line = new StringBuilder("print: ")
                    .append(dataArrayName).append('[').append(i).append("] = {");
            for (int d = 0; d < point.length; d++) {
                if (d > 0) {
                    line.append(", ");
                }
                line.append(point[d]);
            }
            System.out.println(line.append('}').toString());
        }
        System.out.println();
    }

    /**
     * Runs K-means to convergence, printing the data set, the clusters, the SSE and the centers
     * of every iteration.
     *
     * @throws IllegalArgumentException if k is not between 1 and the data-set size
     */
    public void kmeans() {
        init();
        printDataArray(dataSet, "initDataSet");
        printDataArray(center, "initCenter");
        while (true) {
            clusterSet();
            for (int i = 0; i < cluster.size(); i++) {
                printDataArray(cluster.get(i), "cluster[" + i + "]");
            }
            countRule();
            System.out.println("count:" + "Jc[" + m + "] = " + jc.get(m));
            System.out.println();
            // In exact arithmetic the SSE never increases between iterations, so "no decrease"
            // means convergence. Testing >= instead of exact equality also stops on a tiny
            // float-rounding increase, which an equality test could keep looping past forever.
            if (m != 0 && jc.get(m) >= jc.get(m - 1)) {
                break;
            }
            findNewCenter();
            printDataArray(center, "newCenter");
            m++;
            cluster = initCluster(); // fresh, empty clusters for the next assignment pass
        }
        System.out.println("note: the times of repeat: m = " + m);
    }

    /**
     * Entry point: clusters the built-in data set into 3 groups and prints the elapsed time.
     */
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        System.out.println("note: program begins.");
        Kmeans myKmeans = new Kmeans(3);
        myKmeans.kmeans();
        long endTime = System.currentTimeMillis();
        System.out.println("note: running time = " + (endTime - startTime) + "ms.");
        System.out.println("note: program ends.");
    }
}

相关主题
相关文档
最新文档