Implementation of the k-means clustering algorithm in Java

Disclaimer: this page is a Chinese-English translation of a popular Stack Overflow question, provided under the CC BY-SA 4.0 license. If you use it, you must follow the same CC BY-SA license, cite the original URL and author information, and attribute it to the original authors (not me). Stack Overflow original: http://stackoverflow.com/questions/21111070/

Date: 2020-08-13 06:39:16  Source: igfitidea


Tags: java, algorithm, data-mining, cluster-analysis, k-means

Question by chinu

In my program, I'm using k = 2 for the k-means algorithm, i.e. I want only two clusters. I have implemented it in a very simple and straightforward way, but I still can't understand why my program gets into an infinite loop. Can anyone please tell me where I'm making a mistake?


For simplicity, I have hard-coded the input in the program itself. Here is my code:


import java.io.*;
import java.lang.*;
class Kmean
{
public static void main(String args[])
{
int N=9;
int arr[]={2,4,10,12,3,20,30,11,25};    // initial data
int i,m1,m2,a,b,n=0;
boolean flag=true;
float sum1=0,sum2=0;
a=arr[0];b=arr[1];
m1=a; m2=b;
int cluster1[]=new int[9],cluster2[]=new int[9];
for(i=0;i<9;i++)
    System.out.print(arr[i]+ "\t");
System.out.println();

do
{
 n++;
 int k=0,j=0;
 for(i=0;i<9;i++)
 {
    if(Math.abs(arr[i]-m1)<=Math.abs(arr[i]-m2))
    {   cluster1[k]=arr[i];
        k++;
    }
    else
    {   cluster2[j]=arr[i];
        j++;
    }
 }
    System.out.println();
    for(i=0;i<9;i++)
        sum1=sum1+cluster1[i];
    for(i=0;i<9;i++)
        sum2=sum1+cluster2[i];
    a=m1;
    b=m2;
    m1=Math.round(sum1/k);
    m2=Math.round(sum2/j);
    if(m1==a && m2==b)
        flag=false;
    else
        flag=true;

    System.out.println("After iteration "+ n +" , cluster 1 :\n");    //printing the clusters of each iteration
    for(i=0;i<9;i++)
        System.out.print(cluster1[i]+ "\t");

    System.out.println("\n");
    System.out.println("After iteration "+ n +" , cluster 2 :\n");
    for(i=0;i<9;i++)
        System.out.print(cluster2[i]+ "\t");

}while(flag);

    System.out.println("Final cluster 1 :\n");            // final clusters
    for(i=0;i<9;i++)
        System.out.print(cluster1[i]+ "\t");

    System.out.println();
    System.out.println("Final cluster 2 :\n");
    for(i=0;i<9;i++)
        System.out.print(cluster2[i]+ "\t");
 }
}

Accepted answer by Vincent van der Weele

You have a bunch of errors:


  1. At the start of your do loop you should reset sum1 and sum2 to 0.

  2. When calculating sum1 and sum2 you should loop only up to k and j respectively (or clear cluster1 and cluster2 at the start of your do loop).

  3. In the calculation of sum2 you accidentally use sum1.


When I make those fixes the code runs fine, yielding the output:


Final cluster 1 :   
2   4   10   12  3   11  0   0   0

Final cluster 2 :
20  30  25   0   0   0   0   0   0
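Applying the three fixes to the question's loop could look like the minimal sketch below. The class name KmeanFixed and the helper twoMeans are hypothetical. The sketch both clears the cluster arrays and sums only the first k and j slots, keeps the question's integer rounding of the means, and trims the result with Arrays.copyOf so the padding zeros are not printed. It assumes neither cluster becomes empty, which holds for this data.

```java
import java.util.Arrays;

public class KmeanFixed {

    /** Runs the question's two-cluster loop with the three fixes; returns {cluster1, cluster2}. */
    static int[][] twoMeans(int[] arr) {
        int m1 = arr[0], m2 = arr[1];            // initial means, as in the question
        int[] cluster1, cluster2;
        int k, j;
        boolean flag;
        do {
            cluster1 = new int[arr.length];      // fix 2 (alternative): start each pass with empty clusters
            cluster2 = new int[arr.length];
            k = 0;
            j = 0;
            for (int value : arr) {
                if (Math.abs(value - m1) <= Math.abs(value - m2)) {
                    cluster1[k++] = value;
                } else {
                    cluster2[j++] = value;
                }
            }
            float sum1 = 0, sum2 = 0;            // fix 1: sums are reset on every pass
            for (int i = 0; i < k; i++) {        // fix 2: only the first k slots are real members
                sum1 += cluster1[i];
            }
            for (int i = 0; i < j; i++) {        // fix 3: sum2 accumulates sum2, not sum1
                sum2 += cluster2[i];
            }
            int a = m1, b = m2;
            m1 = Math.round(sum1 / k);
            m2 = Math.round(sum2 / j);
            flag = !(m1 == a && m2 == b);        // keep looping while a mean still moves
        } while (flag);
        return new int[][] {Arrays.copyOf(cluster1, k), Arrays.copyOf(cluster2, j)};
    }

    public static void main(String[] args) {
        int[] arr = {2, 4, 10, 12, 3, 20, 30, 11, 25};
        int[][] clusters = twoMeans(arr);
        System.out.println("Final cluster 1: " + Arrays.toString(clusters[0]));
        System.out.println("Final cluster 2: " + Arrays.toString(clusters[1]));
    }
}
```

For the question's data this prints [2, 4, 10, 12, 3, 11] and [20, 30, 25], the same members as the output above without the trailing zeros.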

My general advice: learn how to use a debugger. Stack Overflow is not meant for questions like this; you are expected to find your own bugs and only come here when everything else fails.


Answer by Tim B

The only possible infinite loop is the do-while.


if(m1==a && m2==b)
    flag=false;
else
    flag=true;

The loop only exits when flag is false. Put a breakpoint on this if statement and look at why flag never gets set to false. Maybe add some debug print statements as well.
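One way to follow this advice without a debugger is to print the old and new means on every pass and to cap the number of passes, so that a faulty update step shows up as a message instead of a hang. The sketch below is an illustration, not code from the answer: the class BoundedKmean, the method iterationsToConverge and the limit MAX_ITERATIONS = 100 are hypothetical, and it accumulates the sums directly instead of filling cluster arrays.

```java
public class BoundedKmean {

    // hypothetical safety limit, not part of the original program
    static final int MAX_ITERATIONS = 100;

    /** Returns the pass on which the means stopped changing, or -1 if the limit was hit. */
    static int iterationsToConverge(int[] arr, boolean verbose) {
        int m1 = arr[0], m2 = arr[1];                // initial means, as in the question
        for (int n = 1; n <= MAX_ITERATIONS; n++) {
            float sum1 = 0, sum2 = 0;                // reset on every pass
            int k = 0, j = 0;
            for (int value : arr) {
                if (Math.abs(value - m1) <= Math.abs(value - m2)) {
                    sum1 += value;
                    k++;
                } else {
                    sum2 += value;
                    j++;
                }
            }
            int a = m1, b = m2;
            m1 = Math.round(sum1 / k);
            m2 = Math.round(sum2 / j);
            if (verbose) {                           // debug print: old and new means side by side
                System.out.println("pass " + n + ": (" + a + ", " + b + ") -> (" + m1 + ", " + m2 + ")");
            }
            if (m1 == a && m2 == b) {
                return n;                            // the point where the question sets flag = false
            }
        }
        System.out.println("no convergence after " + MAX_ITERATIONS + " passes; check the update step");
        return -1;
    }

    public static void main(String[] args) {
        int[] arr = {2, 4, 10, 12, 3, 20, 30, 11, 25};
        System.out.println("converged on pass " + iterationsToConverge(arr, true));
    }
}
```

With the question's data and a correct update step this prints five passes, (2, 4) -> (3, 16), (3, 16) -> (3, 18), (3, 18) -> (5, 20), (5, 20) -> (7, 25) and (7, 25) -> (7, 25), and then reports convergence on pass 5.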


Answer by Manindar

public class KMeansClustering {

public static void main(String args[]) {
    int arr[] = {2, 4, 10, 12, 3, 20, 30, 11, 25};    // initial data
    int i, m1, m2, a, b, n = 0;
    boolean flag;
    float sum1, sum2;
    a = arr[0];
    b = arr[1];
    m1 = a;
    m2 = b;
    int cluster1[] = new int[arr.length], cluster2[] = new int[arr.length];
    do {
        sum1 = 0;
        sum2 = 0;
        cluster1 = new int[arr.length];
        cluster2 = new int[arr.length];
        n++;
        int k = 0, j = 0;
        for (i = 0; i < arr.length; i++) {
            if (Math.abs(arr[i] - m1) <= Math.abs(arr[i] - m2)) {
                cluster1[k] = arr[i];
                k++;
            } else {
                cluster2[j] = arr[i];
                j++;
            }
        }
        System.out.println();
        for (i = 0; i < k; i++) {
            sum1 = sum1 + cluster1[i];
        }
        for (i = 0; i < j; i++) {
            sum2 = sum2 + cluster2[i];
        }
        // print the current centroids (means)
        System.out.println("m1=" + m1 + "   m2=" + m2);
        a = m1;
        b = m2;
        m1 = Math.round(sum1 / k);
        m2 = Math.round(sum2 / j);
        flag = !(m1 == a && m2 == b);

        System.out.println("After iteration " + n + " , cluster 1 :\n");    //printing the clusters of each iteration
        for (i = 0; i < cluster1.length; i++) {
            System.out.print(cluster1[i] + "\t");
        }

        System.out.println("\n");
        System.out.println("After iteration " + n + " , cluster 2 :\n");
        for (i = 0; i < cluster2.length; i++) {
            System.out.print(cluster2[i] + "\t");
        }

    } while (flag);

    System.out.println("Final cluster 1 :\n");            // final clusters
    for (i = 0; i < cluster1.length; i++) {
        System.out.print(cluster1[i] + "\t");
    }

    System.out.println();
    System.out.println("Final cluster 2 :\n");
    for (i = 0; i < cluster2.length; i++) {
        System.out.print(cluster2[i] + "\t");
    }
}

}


This is working code.
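The code above still works with integer-rounded means and a fixed k = 2. A minimal sketch of the same loop for any k could keep the means as double and collect each cluster in a list, so no zero padding appears. The names KMeans1D and cluster are hypothetical. As in the question, the first k values serve as the initial means, ties go to the lower-numbered cluster (matching the question's <=), and the loop stops once no value changes cluster. Keeping an empty cluster's previous mean is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class KMeans1D {

    /** Groups the values into k clusters, seeding the means with the first k values. */
    static List<List<Integer>> cluster(int[] data, int k) {
        double[] means = new double[k];
        for (int c = 0; c < k; c++) {
            means[c] = data[c];
        }
        int[] label = new int[data.length];
        Arrays.fill(label, -1);                       // -1 = not assigned yet
        boolean changed = true;
        while (changed) {
            changed = false;
            // assignment step: nearest mean, ties go to the lower index
            for (int i = 0; i < data.length; i++) {
                int best = 0;
                for (int c = 1; c < k; c++) {
                    if (Math.abs(data[i] - means[c]) < Math.abs(data[i] - means[best])) {
                        best = c;
                    }
                }
                if (label[i] != best) {
                    label[i] = best;
                    changed = true;
                }
            }
            // update step: each mean becomes the average of its members
            double[] sum = new double[k];
            int[] count = new int[k];
            for (int i = 0; i < data.length; i++) {
                sum[label[i]] += data[i];
                count[label[i]]++;
            }
            for (int c = 0; c < k; c++) {
                if (count[c] > 0) {                   // an empty cluster keeps its old mean
                    means[c] = sum[c] / count[c];
                }
            }
        }
        List<List<Integer>> clusters = new ArrayList<>();
        for (int c = 0; c < k; c++) {
            clusters.add(new ArrayList<>());
        }
        for (int i = 0; i < data.length; i++) {
            clusters.get(label[i]).add(data[i]);
        }
        return clusters;
    }

    public static void main(String[] args) {
        int[] arr = {2, 4, 10, 12, 3, 20, 30, 11, 25};
        System.out.println(cluster(arr, 2));
    }
}
```

With k = 2 on the question's data this prints [[2, 4, 10, 12, 3, 11], [20, 30, 25]]. With k = 3 on {1, 2, 10, 11, 20, 21}, seeding with the first three values stops at [[1], [2], [10, 11, 20, 21]], a reminder that k-means converges to a local optimum that depends on the initial means.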


Answer by anoojkv varghese

package k;

/**
 *
 * @author Anooj.k.varghese
 */

import java.io.FileNotFoundException;
import java.io.File;
import java.util.Scanner;
public class K {


    /**
     * @param args the command line arguments
     */
    // GLOBAL VARIABLES
    // data_set[][] ------------ the dataset is stored in the data_set[][] array
    // intial_centroid[][] ----- the k initial centroids (k is entered by the user) are stored in intial_centroid[][];
    //                           the values are assigned in the first_itration() function
    private static double[][] arr;
    static int num = 0;
    static Double data_set[][]=new Double[20000][100];
    static Double diff[][]=new Double[20000][100];
    static Double intial_centroid[][]=new Double[300][400];
    static Double center_mean[][]=new Double[20000][100];
    static Double total_mean[]=new Double[200000];
    static int cnum;
    static int it=1;
    static int checker=1;
    static int row=4; // number of numeric columns (attributes) per record; 4 for the Iris dataset used here
     /////////////////////////////////reading the file/////////////////////////////////////
     // readFile(): reads the dataset text file line by line
    private static void readFile() throws FileNotFoundException
        {
        Scanner scanner = new Scanner(new File("E:/aa.txt"));//Dataset path
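        // nextLine() recognises both \n and \r\n, unlike a delimiter fixed to the platform's line.separator
        int lineNo = 0;
            while (scanner.hasNextLine())
             {
                String line = scanner.nextLine();
                if (line.trim().isEmpty())
                    continue;                 // a blank line (e.g. at the end of the file) would make nextDouble() throw
                parseLine(line,lineNo);
                lineNo++;
                System.out.println();
             }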
             // System.out.println("total"+num); PRINT THE TOTAL
     scanner.close();
        }
    // parseLine(): copies the first row (= 4) comma-separated values of one line into data_set[num]
    public static void parseLine(String line,int lineNo)
      { 
        Scanner lineScanner = new Scanner(line);
        lineScanner.useDelimiter(",");
          for(int col=0;col<row;col++)
              {
                  Double arry=lineScanner.nextDouble();
                  data_set[num][col]=arry;                          // store the parsed value in data_set
               }
         num++;

        }
      public static void first_itration()
    {   double a = 0;
         System.out.println("ENTER CLUSTER NUMBER");
         Scanner sc=new Scanner(System.in);      
         cnum=sc.nextInt();   // the number of centroids (k) entered by the user

         int result[]=new int[cnum];
        double re=0;

         System.out.println("centroid");
         for(int i=0;i<cnum;i++)
         {
            for(int j=0;j<row;j++)
                {
                    intial_centroid[i][j]=data_set[i][j];                  // the first cnum records are used as the initial centroids
                    System.out.print(intial_centroid[i][j]);      
                }
            System.out.println();
         }
       System.out.println("------------");

       int counter1=0;
       for(int i=0;i<num;i++)
       {
            for(int j=0;j<row;j++)
                {
                      //System.out.println("hii");
                 System.out.print(data_set[i][j]);

                 }
       counter1++;
       System.out.println();
       }
           System.out.println("total="+counter1);                             //print the total number of data
           //----------------------------------

           ///////////////////EUCLIDEAN DISTANCE////////////////////////////////////
                                                                                /// find the Euclidean Distance
        for(int i=0;i<num;i++)
        {
                for(int j=0;j<cnum;j++)       
                {
                    re=0;
                     for(int k=0;k<row;k++)
                     {
                            a= (intial_centroid[j][k]-data_set[i][k]);
                            //System.out.println(a);
                             a=a*a;
                             re=re+a;                                                 // store the row sum

                        }

                         diff[i][j]= Math.sqrt(re);// take the square root

        }
        }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////FIND THE SMALLEST VALUE////////////////////////////////////////////////
   double aaa;
   double counter;
     int ccc=1;
   for(int i=0;i<num;i++)
   {
         int c=1;
         counter=c;
         aaa=diff[i][0];
         for(int j=0;j<cnum;j++)
         {
          //System.out.println(diff[i][j]);

            if(aaa>=diff[i][j] )                                                //change
                {
                    aaa=diff[i][j];
                    counter=j;


                    // the index j of the nearest centroid so far is stored in counter
               //   System.out.println(counter);
               }


         }

            data_set[i][row]=counter;                                        //assign the counter to last position of data set

            //System.out.println("--");
      }                                                                  //print the first iteration
            System.out.println("**FIRST ITERATION**");

      for(int i=0;i<num;i++)
              {
                  for(int j=0;j<=row;j++)
                      {
                      //System.out.println("hii");
                              System.out.print(data_set[i][j]+ " ");
                       }
                  System.out.println();
              }

    it++;
    }


    public static void calck_mean()
    { 
        for(int i=0;i<20000;i++)
        {
            for(int j=0;j<100;j++)
            {
                center_mean[i][j]=0.0;
            }
        }


  double c = 0; 
     int a=0;
     int p;
     int abbb = 0;
        if(it%2==0)
         {
             abbb=row;
         }
        else if(it%2==1)
         {
             abbb=row+1;
          }
        for(int k=0;k<cnum;k++)
            {
                    double counter = 0;    
                    for(int i=0;i<num;i++)
                     {
                        for(int j=0;j<=row;j++)
                        {               
                            if(data_set[i][abbb]==a)
                            {
                            System.out.print(data_set[i][j]);
                            center_mean[k][j] += data_set[i][j];

                            }

                          }
                        System.out.println();
                      if(data_set[i][abbb]==a)
                        {
                            counter++;
                        }
                  System.out.println();
              }

         a++;
         total_mean[k]=counter;

         }
         for(int i=0;i<cnum;i++)
            {
            System.out.println("\n");
            for(int j=0;j<row;j++)
            {
              if(total_mean[i]==0)
              {
                   center_mean[i][j]=0.0;
              }
              else
              {
                center_mean[i][j]=center_mean[i][j]/total_mean[i];
              }
              }
        }
        for(int k=0;k<cnum;k++)
        {
            for(int j=0;j<row;j++)
            {
              //System.out.print(center_mean[k][j]);
            }
            System.out.println();

        }
       /* for(int j=0;j<cnum;j++)
        {
            System.out.println(total_mean[j]);
        }*/

    }
public static void kmeans1()
    {
       double  a = 0;
       int result[]=new int[cnum];
       double re=0;

  // copy the new centroids (the means computed in calck_mean()) into intial_centroid
         System.out.println(" new centroid");
            for(int i=0;i<cnum;i++)
            {
                for(int j=0;j<row;j++)
                {
                    intial_centroid[i][j]=center_mean[i][j];
                    System.out.print(intial_centroid[i][j]);
                }
             System.out.println();
            }

   //----------------------------------------------JUST PRINT THE data_set

           //----------------------------------
        for(int i=0;i<num;i++)
        {
            for(int j=0;j<cnum;j++)
            {
             re=0;
             for(int k=0;k<row;k++)
             {

               a=(intial_centroid[j][k]-data_set[i][k]);
                 //System.out.println(a);
                a=a*a;        
               re=re+a;

                }

             diff[i][j]= Math.sqrt(re);
             //System.out.println(diff[i][j]);
            }
        }
   double aaa;
    double counter;
     for(int i=0;i<num;i++)
     {

         int c=1;
         counter=c;
          aaa=diff[i][0];
         for(int j=0;j<cnum;j++)
         {
            // System.out.println(diff[i][j]);
            if(aaa>=diff[i][j])                                                  //change
            {
               aaa=diff[i][j];
                counter=j;
               //   System.out.println(counter);
            }


         }


         if(it%2==0)
            {
        // abbb=4;
                data_set[i][row+1]=counter;
            }
         else if(it%2==1)
            {
                data_set[i][row]=counter;
      //   abbb=4;
            }


        //System.out.println("--");
     }
     System.out.println(it+" ITERATION**");

      for(int i=0;i<num;i++)
              {
                  for(int j=0;j<=row+1;j++)
                  {
                      //System.out.println("hii");
                      System.out.print(data_set[i][j]+" ");
                  }
                  System.out.println();
              }

    it++;
    }
public static void check()
{
    checker=0;
    for(int i=0;i<num;i++)
    {
         //System.out.println("hii");
        if(Double.compare(data_set[i][row],data_set[i][row+1]) != 0)
        {
            checker=1;
            //System.out.println("hii " + i  + " " + data_set[i][4]+ " "+data_set[i][4]);
            break;
        }
        System.out.println();
    }

}
public static void dispaly()
{

      System.out.println(it+" ITERATION**");

      for(int i=0;i<num;i++)
              {
                  for(int j=0;j<=row+1;j++)
                  {
                      //System.out.println("hii");
                      System.out.print(data_set[i][j]+" ");
                  }
                  System.out.println();
              }
}


 public static void print()
    {
        System.out.println();
         System.out.println();
          System.out.println();
        System.out.println("----OUTPUT----");
        int c=0;
        int a=0;
        for(int i=0;i<cnum;i++)
        {
            System.out.println("---------CLUSTER-"+i+"-----");
         a=0;
            for(int j=0;j<num;j++)
            {
                 if(data_set[j][row]==i)
                 {a++;
                for(int k=0;k<row;k++)
                {

                    System.out.print(data_set[j][k]+"  ");
                }
                c++;
                System.out.println();
                }
                 //System.out.println(num);

            }
               System.out.println("CLUSTER INSTANCES="+a);


        }
        System.out.println("TOTAL INSTANCES="+c);
    }


    public static void main(String[] args) throws FileNotFoundException 
    {
    readFile();
    first_itration();

    while(checker!=0)
            {
            calck_mean();
            kmeans1();
            check();
            } 
  dispaly();
  print();
    }




}


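The program above keeps the labels of two consecutive passes in the extra columns data_set[i][row] and data_set[i][row+1], alternating between them by the parity of it, and check() stops the loop once the two columns agree. A more compact sketch of the same procedure (Euclidean nearest-centroid assignment, mean update, stop when no label changes) is below. It is not the answerer's code: the names KMeansND, parseRecord, squaredDistance and kMeans are hypothetical. It compares squared distances (the nearest centroid is the same without the square root), breaks ties toward the lower index (the original's >= picks the higher one), keeps an empty cluster's previous centroid (the original resets it to 0.0), and parses lines with String.split, assuming comma-separated numeric columns with any trailing label column ignored.

```java
import java.util.Arrays;

public class KMeansND {

    /** Parses the first `columns` comma-separated numbers of a line, ignoring any extra columns. */
    static double[] parseRecord(String line, int columns) {
        String[] parts = line.trim().split(",");
        double[] record = new double[columns];
        for (int col = 0; col < columns; col++) {
            record[col] = Double.parseDouble(parts[col].trim());
        }
        return record;
    }

    /** Squared Euclidean distance; enough for picking the nearest centroid. */
    static double squaredDistance(double[] p, double[] q) {
        double sum = 0;
        for (int d = 0; d < p.length; d++) {
            double diff = p[d] - q[d];
            sum += diff * diff;
        }
        return sum;
    }

    /** Returns the cluster label of every record; the first k records seed the centroids. */
    static int[] kMeans(double[][] data, int k) {
        int dims = data[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = data[c].clone();
        }
        int[] label = new int[data.length];
        Arrays.fill(label, -1);                                  // -1 = not assigned yet
        boolean changed = true;
        while (changed) {
            changed = false;
            for (int i = 0; i < data.length; i++) {              // assignment step
                int best = 0;
                for (int c = 1; c < k; c++) {
                    if (squaredDistance(data[i], centroids[c]) < squaredDistance(data[i], centroids[best])) {
                        best = c;
                    }
                }
                if (label[i] != best) {
                    label[i] = best;
                    changed = true;
                }
            }
            double[][] sum = new double[k][dims];                // update step
            int[] count = new int[k];
            for (int i = 0; i < data.length; i++) {
                count[label[i]]++;
                for (int d = 0; d < dims; d++) {
                    sum[label[i]][d] += data[i][d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (count[c] > 0) {                              // an empty cluster keeps its old centroid
                    for (int d = 0; d < dims; d++) {
                        centroids[c][d] = sum[c][d] / count[c];
                    }
                }
            }
        }
        return label;
    }

    public static void main(String[] args) {
        String[] lines = {"1,1", "1,2", "8,8", "9,8", "1,0", "8,9"};   // made-up demo data
        double[][] data = new double[lines.length][];
        for (int i = 0; i < lines.length; i++) {
            data[i] = parseRecord(lines[i], 2);
        }
        System.out.println(Arrays.toString(kMeans(data, 2)));
    }
}
```

The demo data in main is made up for illustration and prints [0, 0, 1, 1, 0, 1]. For a file such as the Iris data, each line could be passed through parseRecord with columns = 4 before calling kMeans.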