Implementing Strassen's algorithm for matrices of arbitrary size in Java

  • 2020-05-05 11:13:29
  • OfStack

In this example, the input is two matrices of arbitrary size, m × n and n × m, and the output is their product. Strassen's algorithm is used to multiply matrices of arbitrary size. The author wrote and tested the program; feel free to use it. The basic algorithm is:
1. For a square matrix of side length m, find the largest l such that $l = 2^k$ for some integer k and $l \le m$. Strassen's algorithm is applied to the top-left $l \times l$ block, and the brute-force method is used for the remaining entries of the square matrix.
2. For a non-square matrix, pad it with rows and columns of zeros to make it square.
StrassenMethodTest.java


package matrixalgorithm;
 
import java.util.Scanner;
 
/**
 * Console driver and helpers that multiply matrices of arbitrary compatible sizes.
 *
 * <p>Rectangular operands are zero-padded to a common square size. A square matrix whose side is a
 * power of two is handed straight to {@link StrassenMethod}; any other size runs Strassen on the
 * largest power-of-two top-left block and adds the remaining terms by brute force.
 */
public class StrassenMethodTest {

  private final StrassenMethod strassenMultiply;

  StrassenMethodTest() {
    strassenMultiply = new StrassenMethod();
  }

  /**
   * Reads the sizes and entries of two matrices from standard input, multiplies them, and prints
   * every entry of the product.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    try (Scanner input = new Scanner(System.in)) {
      System.out.println("Input row size of the first matrix: ");
      int arow = input.nextInt();
      System.out.println("Input column size of the first matrix: ");
      int acol = input.nextInt();
      System.out.println("Input row size of the second matrix: ");
      int brow = input.nextInt();
      System.out.println("Input column size of the second matrix: ");
      int bcol = input.nextInt();

      if (arow < 0 || acol < 0 || brow < 0 || bcol < 0) {
        System.out.println("Matrix sizes must not be negative.");
        return;
      }
      if (acol != brow) {
        // Checked up front: printing arow x bcol entries of the {{0}} sentinel would crash.
        System.out.println(
            "Invalid multiplication: the column size of the first matrix must equal"
                + " the row size of the second matrix.");
        return;
      }

      /* In all of the code in this project, r means row while c means column. */
      System.out.println("Input data for matrix A: ");
      double[][] A = readMatrix(input, "A", arow, acol);
      System.out.println("Input data for matrix B: ");
      double[][] B = readMatrix(input, "B", brow, bcol);

      StrassenMethodTest algorithm = new StrassenMethodTest();
      double[][] C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);

      // Display the calculation result:
      System.out.println("Result from matrix C: ");
      for (int r = 0; r < arow; r++) {
        for (int c = 0; c < bcol; c++) {
          System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);
        }
      }
    }
  }

  /**
   * Prompts for and reads a {@code rows x cols} matrix.
   *
   * @param input source of the entries
   * @param name matrix label shown in each prompt
   * @param rows number of rows to read
   * @param cols number of columns to read
   * @return the matrix that was read
   */
  private static double[][] readMatrix(Scanner input, String name, int rows, int cols) {
    double[][] matrix = new double[rows][cols];
    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        System.out.printf("Data of %s[%d][%d]: ", name, r, c);
        matrix[r][c] = input.nextDouble();
      }
    }
    return matrix;
  }

  /**
   * Multiplies an {@code arow x acol} matrix by a {@code brow x bcol} matrix.
   *
   * <p>Both operands are zero-padded to an {@code n x n} square, where {@code n} is the largest of
   * the dimensions. Zero padding leaves the top-left {@code arow x bcol} block of the product
   * unchanged, and that block is what is returned.
   *
   * @param A left operand
   * @param B right operand
   * @param arow rows of {@code A}
   * @param acol columns of {@code A}
   * @param brow rows of {@code B}
   * @param bcol columns of {@code B}
   * @return the {@code arow x bcol} product, or the sentinel {@code {{0}}} when the inner
   *     dimensions differ ({@code acol != brow})
   */
  public double[][] multiplyRectMatrix(double[][] A, double[][] B,
      int arow, int acol, int brow, int bcol) {
    if (acol != brow) {
      // Invalid multiplication: inner dimensions must agree. The original checked arow != bcol.
      return new double[][] {{0}};
    }

    int n = Math.max(arow, Math.max(acol, bcol));
    if (arow == n && acol == n && bcol == n) {
      // Already square: no copying needed.
      return multiplySquareMatrix(A, B, n);
    }

    double[][] product =
        multiplySquareMatrix(padToSquare(A, arow, acol, n), padToSquare(B, brow, bcol, n), n);

    double[][] C = new double[arow][bcol];
    for (int r = 0; r < arow; r++) {
      System.arraycopy(product[r], 0, C[r], 0, bcol);
    }
    return C;
  }

  /**
   * Multiplies the top-left {@code n x n} blocks of two square matrices.
   *
   * <p>When {@code n} is a power of two the whole product comes from {@link StrassenMethod}.
   * Otherwise Strassen computes the partial product of the largest power-of-two top-left block, and
   * every remaining term of every entry is added by brute force.
   *
   * @param A2 left operand, at least {@code n x n}
   * @param B2 right operand, at least {@code n x n}
   * @param n side length to multiply
   * @return a new {@code n x n} product; an empty matrix when {@code n <= 0}
   */
  public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n) {
    if (n <= 0) {
      return new double[0][0];
    }
    if (n == 1) {
      return new double[][] {{A2[0][0] * B2[0][0]}};
    }

    // Largest power of two <= n (the "biggest" Strassen square).
    int exp2k = Integer.highestOneBit(n);
    if (exp2k == n) {
      return strassenMultiply.strassenMultiplyMatrix(A2, B2, n);
    }

    // Partial sums over k < exp2k for the top-left exp2k x exp2k entries.
    double[][] core = strassenMultiply.strassenMultiplyMatrix(
        padToSquare(A2, exp2k, exp2k, exp2k), padToSquare(B2, exp2k, exp2k, exp2k), exp2k);

    double[][] C2 = new double[n][n];
    for (int row = 0; row < n; row++) {
      for (int col = 0; col < n; col++) {
        boolean inCore = row < exp2k && col < exp2k;
        // Core entries already hold the k < exp2k terms; all other entries start from scratch.
        double sum = inCore ? core[row][col] : 0.0;
        for (int k = inCore ? exp2k : 0; k < n; k++) {
          sum += A2[row][k] * B2[k][col];
        }
        C2[row][col] = sum;
      }
    }
    return C2;
  }

  /**
   * Copies the top-left {@code rows x cols} entries of {@code src} into a new zero-filled
   * {@code size x size} matrix.
   */
  private static double[][] padToSquare(double[][] src, int rows, int cols, int size) {
    double[][] out = new double[size][size];
    for (int r = 0; r < rows; r++) {
      System.arraycopy(src[r], 0, out[r], 0, cols);
    }
    return out;
  }
}

StrassenMethod.java


package matrixalgorithm;
 
import java.util.Scanner;
 
/**
 * Strassen matrix multiplication for square matrices.
 *
 * <p>Each recursion level splits both operands into 2 x 2 quadrants and combines seven half-size
 * products (instead of the eight a naive block multiplication needs).
 */
public class StrassenMethod {

  /**
   * Multiplies the top-left {@code n x n} blocks of two matrices with Strassen's algorithm.
   *
   * <p>Any {@code n >= 0} is supported. When a level has an odd size greater than 1, it is padded
   * with one zero row and column before splitting. Powers of two are never padded.
   *
   * @param A2 left operand, at least {@code n x n}
   * @param B2 right operand, at least {@code n x n}
   * @param n side length to multiply
   * @return a new {@code n x n} matrix equal to {@code A2 * B2}; an empty matrix when
   *     {@code n <= 0}
   */
  public double[][] strassenMultiplyMatrix(double[][] A2, double[][] B2, int n) {
    if (n <= 0) {
      // Without this guard n == 0 recursed on n / 2 == 0 forever.
      return new double[0][0];
    }
    if (n == 1) {
      return new double[][] {{A2[0][0] * B2[0][0]}};
    }
    if (n % 2 != 0) {
      // Quadrants need an even side; splitting an odd size would drop the last row/column.
      // Zero padding leaves the top-left n x n block of the product unchanged.
      double[][] padded = strassenMultiplyMatrix(
          copyTopLeft(A2, n, n + 1), copyTopLeft(B2, n, n + 1), n + 1);
      return copyTopLeft(padded, n, n);
    }

    int half = n / 2;

    // Slice both matrices into 2 x 2 quadrants.
    double[][] a11 = quadrant(A2, 0, 0, half);
    double[][] a12 = quadrant(A2, 0, half, half);
    double[][] a21 = quadrant(A2, half, 0, half);
    double[][] a22 = quadrant(A2, half, half, half);
    double[][] b11 = quadrant(B2, 0, 0, half);
    double[][] b12 = quadrant(B2, 0, half, half);
    double[][] b21 = quadrant(B2, half, 0, half);
    double[][] b22 = quadrant(B2, half, half, half);

    // Seven recursive products (P1..P7 in CLRS notation), each built from the S1..S10 sums.
    double[][] p1 = strassenMultiplyMatrix(a11, minusMatrix(b12, b22, half), half);
    double[][] p2 = strassenMultiplyMatrix(addMatrix(a11, a12, half), b22, half);
    double[][] p3 = strassenMultiplyMatrix(addMatrix(a21, a22, half), b11, half);
    double[][] p4 = strassenMultiplyMatrix(a22, minusMatrix(b21, b11, half), half);
    double[][] p5 = strassenMultiplyMatrix(
        addMatrix(a11, a22, half), addMatrix(b11, b22, half), half);
    double[][] p6 = strassenMultiplyMatrix(
        minusMatrix(a12, a22, half), addMatrix(b21, b22, half), half);
    double[][] p7 = strassenMultiplyMatrix(
        minusMatrix(a11, a21, half), addMatrix(b11, b12, half), half);

    double[][] c11 = addMatrix(minusMatrix(addMatrix(p5, p4, half), p2, half), p6, half);
    double[][] c12 = addMatrix(p1, p2, half);
    double[][] c21 = addMatrix(p3, p4, half);
    double[][] c22 = minusMatrix(minusMatrix(addMatrix(p5, p1, half), p3, half), p7, half);

    // Reassemble the quadrants into the n x n result.
    double[][] C2 = new double[n][n];
    for (int r = 0; r < half; r++) {
      System.arraycopy(c11[r], 0, C2[r], 0, half);
      System.arraycopy(c12[r], 0, C2[r], half, half);
      System.arraycopy(c21[r], 0, C2[half + r], 0, half);
      System.arraycopy(c22[r], 0, C2[half + r], half, half);
    }
    return C2;
  }

  /** Copies the {@code size x size} sub-matrix of {@code src} whose top-left corner is (row, col). */
  private static double[][] quadrant(double[][] src, int row, int col, int size) {
    double[][] out = new double[size][size];
    for (int r = 0; r < size; r++) {
      System.arraycopy(src[row + r], col, out[r], 0, size);
    }
    return out;
  }

  /**
   * Returns a new zero-filled {@code size x size} matrix holding the top-left
   * {@code count x count} entries of {@code src}.
   */
  private static double[][] copyTopLeft(double[][] src, int count, int size) {
    double[][] out = new double[size][size];
    for (int r = 0; r < count; r++) {
      System.arraycopy(src[r], 0, out[r], 0, count);
    }
    return out;
  }

  /** Returns the element-wise sum of the top-left {@code n x n} blocks of A and B. */
  private static double[][] addMatrix(double[][] A, double[][] B, int n) {
    double[][] C = new double[n][n];
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        C[r][c] = A[r][c] + B[r][c];
      }
    }
    return C;
  }

  /** Returns the element-wise difference A - B of the top-left {@code n x n} blocks. */
  private static double[][] minusMatrix(double[][] A, double[][] B, int n) {
    double[][] C = new double[n][n];
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        C[r][c] = A[r][c] - B[r][c];
      }
    }
    return C;
  }
}

I hope this article helps you learn Java programming.


Related articles: