Implementing Strassen's algorithm for matrices of arbitrary size in Java
- 2020-05-05 11:13:29
- OfStack
In this example, the input is two matrices of arbitrary size, $m \times n$ and $n \times m$, and the output is their product. Strassen's algorithm is used to multiply matrices of arbitrary size. I wrote and tested the code myself; feel free to use it. The basic algorithm is:
1. For a square matrix of side length $m$, find the largest $l = 2^k$ ($k$ an integer) such that $l \le m$. Strassen's algorithm is applied to the leading $l \times l$ sub-matrix, and the brute-force method is used for the remaining rows and columns of the square matrix.
2. For a non-square matrix, pad it with zero rows and columns to make it square.
StrassenMethodTest.java
package matrixalgorithm;
import java.util.Scanner;
/**
 * Multiplies matrices of arbitrary size.
 *
 * <p>Rectangular operands are zero-padded to a common square size. Square products use
 * {@link StrassenMethod} on the largest power-of-two leading block. Ordinary dot products
 * handle the remaining rows and columns.
 */
public class StrassenMethodTest {
    /** Strassen engine used for power-of-two sized square blocks. */
    private final StrassenMethod strassenMultiply;

    StrassenMethodTest() {
        strassenMultiply = new StrassenMethod();
    }//end cons

    /**
     * Reads the sizes and entries of two matrices from standard input, multiplies them,
     * and prints every entry of the product.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try (Scanner input = new Scanner(System.in)) {
            System.out.println("Input row size of the first matrix: ");
            int arow = input.nextInt();
            System.out.println("Input column size of the first matrix: ");
            int acol = input.nextInt();
            System.out.println("Input row size of the second matrix: ");
            int brow = input.nextInt();
            System.out.println("Input column size of the second matrix: ");
            int bcol = input.nextInt();
            // A * B is only defined when A's column count equals B's row count.
            // Reject bad sizes before asking the user for any data.
            if (acol != brow) {
                System.out.println("Invalid sizes: the column size of the first matrix "
                        + "must equal the row size of the second matrix.");
                return;
            }
            /*In all of the codes later in this project,
              r means row while c means column.
            */
            double[][] A = readMatrix(input, "A", arow, acol);
            double[][] B = readMatrix(input, "B", brow, bcol);
            StrassenMethodTest algorithm = new StrassenMethodTest();
            double[][] C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);
            //Display the calculation result:
            System.out.println("Result from matrix C: ");
            for (int r = 0; r < arow; r++) {
                for (int c = 0; c < bcol; c++) {
                    System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);
                }//end inner loop
            }//end outer loop
        }
    }//end main

    /**
     * Prompts for and reads a {@code rows x cols} matrix, one entry at a time.
     *
     * @param input scanner to read from
     * @param name  matrix name shown in the prompts
     * @param rows  number of rows
     * @param cols  number of columns
     * @return the matrix that was read
     */
    private static double[][] readMatrix(Scanner input, String name, int rows, int cols) {
        System.out.println("Input data for matrix " + name + ": ");
        double[][] m = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                System.out.printf("Data of %s[%d][%d]: ", name, r, c);
                m[r][c] = input.nextDouble();
            }//end inner loop
        }//end loop
        return m;
    }//end method

    /**
     * Multiplies an {@code arow x acol} matrix by a {@code brow x bcol} matrix.
     *
     * <p>Both operands are zero-padded to an N x N square, where N = max(arow, acol, bcol).
     * The padded matrices are multiplied, and the {@code arow x bcol} top-left corner is
     * returned. Zero padding does not change the entries of that corner.
     *
     * @return the {@code arow x bcol} product. When {@code acol != brow} the product is
     *         undefined and the legacy sentinel {@code {{0}}} is returned instead.
     */
    public double[][] multiplyRectMatrix(double[][] A, double[][] B,
            int arow, int acol, int brow, int bcol) {
        if (acol != brow) { //Invalid multiplication: inner dimensions differ
            return new double[][]{{0}};
        }
        int n = Math.max(arow, Math.max(acol, bcol));
        if (arow == n && acol == n && bcol == n) {
            // Already square and conformable: no padding needed.
            return multiplySquareMatrix(A, B, n);
        }
        double[][] C2 = multiplySquareMatrix(padToSquare(A, arow, acol, n),
                padToSquare(B, brow, bcol, n), n);
        double[][] C = new double[arow][bcol];
        for (int r = 0; r < arow; r++) {
            System.arraycopy(C2[r], 0, C[r], 0, bcol);
        }
        return C;
    }//end method

    /**
     * Copies the {@code rows x cols} top-left corner of {@code src} into a new
     * zero-filled {@code size x size} matrix.
     */
    private static double[][] padToSquare(double[][] src, int rows, int cols, int size) {
        double[][] out = new double[size][size]; // Java zero-initialises the array
        for (int r = 0; r < rows; r++) {
            System.arraycopy(src[r], 0, out[r], 0, cols);
        }
        return out;
    }//end method

    /**
     * Multiplies the leading {@code n x n} blocks of two square matrices.
     *
     * <p>Let {@code l} be the largest power of two with {@code l <= n}. The l x l leading
     * block of the product starts as Strassen's result for the l x l leading blocks of the
     * operands. The terms with inner index {@code k >= l} are then added to it. Every
     * remaining entry (columns {@code >= l}, or rows {@code >= l}) is computed as a full
     * dot product.
     *
     * @return a new {@code n x n} matrix holding A2 * B2
     */
    public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n) {
        double[][] C2 = new double[n][n];
        if (n == 0) {
            return C2;
        }
        if (n == 1) {
            C2[0][0] = A2[0][0] * B2[0][0];
            return C2;
        }//end if
        // exp2k = largest power of two that is <= n (n >= 2 here).
        int exp2k = 2;
        while (exp2k <= n / 2) {
            exp2k *= 2;
        }//end loop
        if (exp2k == n) {
            return strassenMultiply.strassenMultiplyMatrix(A2, B2, n);
        }//end if
        //The "biggest" strassen matrix: A11 * B11 on the leading exp2k x exp2k blocks.
        double[][] A11 = new double[exp2k][exp2k];
        double[][] B11 = new double[exp2k][exp2k];
        for (int r = 0; r < exp2k; r++) {
            System.arraycopy(A2[r], 0, A11[r], 0, exp2k);
            System.arraycopy(B2[r], 0, B11[r], 0, exp2k);
        }//end loop
        double[][] P = strassenMultiply.strassenMultiplyMatrix(A11, B11, exp2k);
        for (int r = 0; r < exp2k; r++) {
            System.arraycopy(P[r], 0, C2[r], 0, exp2k);
        }//end loop
        // Finish the leading block: add the inner-index terms k >= exp2k that
        // Strassen did not see.
        for (int row = 0; row < exp2k; row++) {
            for (int col = 0; col < exp2k; col++) {
                for (int k = exp2k; k < n; k++) {
                    C2[row][col] += A2[row][k] * B2[k][col];
                }//end loop
            }//end inner loop
        }//end outer loop
        //Use brute force for the right strip (all rows, columns exp2k..n-1):
        for (int row = 0; row < n; row++) {
            for (int col = exp2k; col < n; col++) {
                for (int k = 0; k < n; k++) {
                    C2[row][col] += A2[row][k] * B2[k][col];
                }//end loop
            }//end inner loop
        }//end outer loop
        //...and for the bottom strip (rows exp2k..n-1, columns 0..exp2k-1):
        for (int row = exp2k; row < n; row++) {
            for (int col = 0; col < exp2k; col++) {
                for (int k = 0; k < n; k++) {
                    C2[row][col] += A2[row][k] * B2[k][col];
                }//end loop
            }//end inner loop
        }//end outer loop
        return C2;
    }//end method
}//end class
StrassenMethod.java
package matrixalgorithm;
import java.util.Scanner;
/**
 * Square matrix multiplication with Strassen's divide-and-conquer algorithm.
 *
 * <p>Each level splits the operands into 2 x 2 blocks. The product is then formed from
 * seven block multiplications instead of eight, using the S1..S10 / P1..P7 scheme. An
 * odd-sized level is zero-padded by one row and column so that it splits evenly.
 */
public class StrassenMethod {

    /**
     * Multiplies the leading {@code n x n} blocks of two matrices.
     *
     * @param A2 left operand; only rows/columns {@code 0..n-1} are read
     * @param B2 right operand; only rows/columns {@code 0..n-1} are read
     * @param n  side length of the blocks to multiply; 0 yields an empty matrix
     * @return a new {@code n x n} matrix holding A2 * B2
     * @throws NegativeArraySizeException if {@code n} is negative
     */
    public double[][] strassenMultiplyMatrix(double[][] A2, double[][] B2, int n) {
        double[][] C2 = new double[n][n];
        if (n == 0) {
            // Without this guard the recursion below would never reach its base case.
            return C2;
        }
        if (n == 1) {
            C2[0][0] = A2[0][0] * B2[0][0];
            return C2;
        }
        if (n % 2 != 0) {
            // Halving an odd size would silently drop the last row/column. Pad with one
            // zero row and column instead; the padding does not affect the leading n x n
            // block of the product.
            int m = n + 1;
            double[][] padded = strassenMultiplyMatrix(padTo(A2, n, m), padTo(B2, n, m), m);
            for (int r = 0; r < n; r++) {
                System.arraycopy(padded[r], 0, C2[r], 0, n);
            }
            return C2;
        }
        int h = n / 2;
        //Slice both matrices into 2 x 2 blocks of size h x h:
        double[][] A11 = block(A2, 0, 0, h);
        double[][] A12 = block(A2, 0, h, h);
        double[][] A21 = block(A2, h, 0, h);
        double[][] A22 = block(A2, h, h, h);
        double[][] B11 = block(B2, 0, 0, h);
        double[][] B12 = block(B2, 0, h, h);
        double[][] B21 = block(B2, h, 0, h);
        double[][] B22 = block(B2, h, h, h);

        double[][] S1 = minusMatrix(B12, B22, h);
        double[][] S2 = addMatrix(A11, A12, h);
        double[][] S3 = addMatrix(A21, A22, h);
        double[][] S4 = minusMatrix(B21, B11, h);
        double[][] S5 = addMatrix(A11, A22, h);
        double[][] S6 = addMatrix(B11, B22, h);
        double[][] S7 = minusMatrix(A12, A22, h);
        double[][] S8 = addMatrix(B21, B22, h);
        double[][] S9 = minusMatrix(A11, A21, h);
        double[][] S10 = addMatrix(B11, B12, h);

        // The seven recursive products.
        double[][] P1 = strassenMultiplyMatrix(A11, S1, h);
        double[][] P2 = strassenMultiplyMatrix(S2, B22, h);
        double[][] P3 = strassenMultiplyMatrix(S3, B11, h);
        double[][] P4 = strassenMultiplyMatrix(A22, S4, h);
        double[][] P5 = strassenMultiplyMatrix(S5, S6, h);
        double[][] P6 = strassenMultiplyMatrix(S7, S8, h);
        double[][] P7 = strassenMultiplyMatrix(S9, S10, h);

        double[][] C11 = addMatrix(minusMatrix(addMatrix(P5, P4, h), P2, h), P6, h);
        double[][] C12 = addMatrix(P1, P2, h);
        double[][] C21 = addMatrix(P3, P4, h);
        double[][] C22 = minusMatrix(minusMatrix(addMatrix(P5, P1, h), P3, h), P7, h);

        //Reassemble the four blocks into the result:
        for (int r = 0; r < h; r++) {
            System.arraycopy(C11[r], 0, C2[r], 0, h);
            System.arraycopy(C12[r], 0, C2[r], h, h);
            System.arraycopy(C21[r], 0, C2[h + r], 0, h);
            System.arraycopy(C22[r], 0, C2[h + r], h, h);
        }//end loop
        return C2;
    }//end method

    /** Copies the {@code size x size} block of {@code src} that starts at (row0, col0). */
    private static double[][] block(double[][] src, int row0, int col0, int size) {
        double[][] out = new double[size][size];
        for (int r = 0; r < size; r++) {
            System.arraycopy(src[row0 + r], col0, out[r], 0, size);
        }
        return out;
    }//end method

    /**
     * Copies the leading {@code n x n} block of {@code src} into a new zero-filled
     * {@code size x size} matrix.
     */
    private static double[][] padTo(double[][] src, int n, int size) {
        double[][] out = new double[size][size];
        for (int r = 0; r < n; r++) {
            System.arraycopy(src[r], 0, out[r], 0, n);
        }
        return out;
    }//end method

    //Add two n x n matrices element-wise.
    private double[][] addMatrix(double[][] A, double[][] B, int n) {
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                C[r][c] = A[r][c] + B[r][c];
            }
        }
        return C;
    }//end method

    //Subtract two n x n matrices element-wise (A - B).
    private double[][] minusMatrix(double[][] A, double[][] B, int n) {
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                C[r][c] = A[r][c] - B[r][c];
            }
        }
        return C;
    }//end method
}//end class
I hope this article helps you learn Java programming.