/*
 * CS 3343/3341 Introduction to Algorithms
 * Matrix Multiplication Using Divide and Conquer with Strassen Multiplication
 * Example: 4-by-4 matrices, using Strassen multiplication
 */
// MStrassen.java: multiply matrices, using
// Strassen's multiplication
public class MStrassen {

    /**
     * Computes C = A*B for square matrices using Strassen's divide-and-conquer method.
     *
     * <p>Both operands must be n-by-n with the same n, and n must be a power of two
     * (1, 2, 4, 8, ...). Each operand is split into four n/2-by-n/2 quadrants, seven
     * quadrant products P1..P7 are formed recursively, and the four quadrants of C are
     * recombined from them with additions and subtractions. For n = 2 this reduces to
     * the classic seven scalar multiplications.
     *
     * <p>Entries are plain {@code int}s; overflow wraps silently.
     *
     * @param A left operand, n-by-n
     * @param B right operand, n-by-n
     * @return a newly allocated n-by-n matrix holding A*B
     * @throws NullPointerException if A or B is null
     * @throws IllegalArgumentException if an operand is not square, its size is not a
     *     power of two, or the two sizes differ
     */
    int[][] mult(int[][] A, int[][] B) {
        requireSquarePowerOfTwo(A, "A");
        requireSquarePowerOfTwo(B, "B");
        if (A.length != B.length) {
            throw new IllegalArgumentException("size mismatch: A is " + A.length + "x" + A.length
                    + ", B is " + B.length + "x" + B.length);
        }
        return strassen(A, B);
    }

    /**
     * Recursive core of {@link #mult}.
     * Assumes both operands were already validated as equal power-of-two squares.
     */
    private int[][] strassen(int[][] A, int[][] B) {
        int n = A.length;
        if (n == 1) {
            return new int[][] {{A[0][0] * B[0][0]}};
        }
        int h = n / 2;

        // Quadrants: X00 is top-left, X01 top-right, X10 bottom-left, X11 bottom-right.
        int[][] A00 = quadrant(A, 0, 0, h);
        int[][] A01 = quadrant(A, 0, h, h);
        int[][] A10 = quadrant(A, h, 0, h);
        int[][] A11 = quadrant(A, h, h, h);
        int[][] B00 = quadrant(B, 0, 0, h);
        int[][] B01 = quadrant(B, 0, h, h);
        int[][] B10 = quadrant(B, h, 0, h);
        int[][] B11 = quadrant(B, h, h, h);

        // Seven recursive products instead of the eight a plain block split needs.
        int[][] P1 = strassen(A00, sub(B01, B11));
        int[][] P2 = strassen(add(A00, A01), B11);
        int[][] P3 = strassen(add(A10, A11), B00);
        int[][] P4 = strassen(A11, sub(B10, B00));
        int[][] P5 = strassen(add(A00, A11), add(B00, B11));
        int[][] P6 = strassen(sub(A01, A11), add(B10, B11));
        int[][] P7 = strassen(sub(A00, A10), add(B00, B01));

        int[][] C00 = add(sub(add(P5, P4), P2), P6);
        int[][] C01 = add(P1, P2);
        int[][] C10 = add(P3, P4);
        int[][] C11 = sub(sub(add(P5, P1), P3), P7);
        return assemble(C00, C01, C10, C11);
    }

    /** Copies the h-by-h block of M whose top-left corner is at (row, col). */
    private static int[][] quadrant(int[][] M, int row, int col, int h) {
        int[][] Q = new int[h][h];
        for (int i = 0; i < h; i++) {
            System.arraycopy(M[row + i], col, Q[i], 0, h);
        }
        return Q;
    }

    /** Places four h-by-h blocks into one new 2h-by-2h matrix. */
    private static int[][] assemble(int[][] C00, int[][] C01, int[][] C10, int[][] C11) {
        int h = C00.length;
        int[][] C = new int[2 * h][2 * h];
        for (int i = 0; i < h; i++) {
            System.arraycopy(C00[i], 0, C[i], 0, h);
            System.arraycopy(C01[i], 0, C[i], h, h);
            System.arraycopy(C10[i], 0, C[i + h], 0, h);
            System.arraycopy(C11[i], 0, C[i + h], h, h);
        }
        return C;
    }

    /**
     * Checks that M is a non-empty square matrix whose size is a power of two.
     * Throws the exceptions documented on {@link #mult} otherwise.
     */
    private static void requireSquarePowerOfTwo(int[][] M, String name) {
        if (M == null) {
            throw new NullPointerException(name + " must not be null");
        }
        int n = M.length;
        if (n == 0 || (n & (n - 1)) != 0) {
            throw new IllegalArgumentException(
                    name + " has " + n + " rows; Strassen needs a power-of-two size");
        }
        for (int i = 0; i < n; i++) {
            if (M[i] == null || M[i].length != n) {
                throw new IllegalArgumentException(
                        name + " is not square: row " + i + " does not have " + n + " entries");
            }
        }
    }

    /**
     * Element-wise sum C = A + B.
     *
     * @param A first addend
     * @param B second addend, same shape as A
     * @return a newly allocated matrix of the same shape
     * @throws IllegalArgumentException if the shapes differ
     */
    int[][] add(int[][] A, int[][] B) {
        requireSameShape(A, B);
        int[][] C = new int[A.length][];
        for (int i = 0; i < A.length; i++) { // rows
            C[i] = new int[A[i].length];
            for (int j = 0; j < A[i].length; j++) { // cols
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    /**
     * Element-wise difference C = A - B.
     *
     * @param A minuend
     * @param B subtrahend, same shape as A
     * @return a newly allocated matrix of the same shape
     * @throws IllegalArgumentException if the shapes differ
     */
    int[][] sub(int[][] A, int[][] B) {
        requireSameShape(A, B);
        int[][] C = new int[A.length][];
        for (int i = 0; i < A.length; i++) { // rows
            C[i] = new int[A[i].length];
            for (int j = 0; j < A[i].length; j++) { // cols
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }

    /** Throws IllegalArgumentException unless A and B have identical row counts and row lengths. */
    private static void requireSameShape(int[][] A, int[][] B) {
        if (A.length != B.length) {
            throw new IllegalArgumentException(
                    "row count mismatch: " + A.length + " vs " + B.length);
        }
        for (int i = 0; i < A.length; i++) {
            if (A[i] == null || B[i] == null || A[i].length != B[i].length) {
                throw new IllegalArgumentException("row " + i + " lengths differ");
            }
        }
    }

    /**
     * Prints a caption line, then each row of A framed by '|' with every entry padded by
     * one space on both sides, then a blank line.
     *
     * @param A matrix to print (any shape)
     * @param s caption printed above the matrix
     */
    void printMat(int[][] A, String s) {
        StringBuilder out = new StringBuilder().append(s).append('\n');
        for (int[] row : A) {
            out.append('|');
            for (int v : row) {
                out.append(' ').append(v).append(' ');
            }
            out.append("|\n");
        }
        out.append('\n');
        System.out.print(out);
    }

    /**
     * Demonstrates one level of Strassen by hand on a 4-by-4 product.
     * Each operand is given as four 2-by-2 quadrants, the seven products P1..P7 are formed
     * with {@link #mult} on quadrants, and the operand and result quadrants are printed.
     */
    public static void main(String[] argv) {
        MStrassen mats = new MStrassen();
        int[][] A00 = {{2, 3}, {4, 0}};
        int[][] A01 = {{1, 6}, {0, 2}};
        int[][] A10 = {{4, 2}, {0, 3}};
        int[][] A11 = {{0, 1}, {5, 2}};
        int[][] B00 = {{3, 0}, {1, 2}};
        int[][] B01 = {{4, 3}, {0, 2}};
        int[][] B10 = {{0, 3}, {5, 1}};
        int[][] B11 = {{1, 4}, {3, 2}};

        int[][] P1 = mats.mult(A00, mats.sub(B01, B11));
        int[][] P2 = mats.mult(mats.add(A00, A01), B11);
        int[][] P3 = mats.mult(mats.add(A10, A11), B00);
        int[][] P4 = mats.mult(A11, mats.sub(B10, B00));
        int[][] P5 = mats.mult(mats.add(A00, A11), mats.add(B00, B11));
        int[][] P6 = mats.mult(mats.sub(A01, A11), mats.add(B10, B11));
        int[][] P7 = mats.mult(mats.sub(A00, A10), mats.add(B00, B01));

        int[][] C00 = mats.add(P5, mats.add(mats.sub(P4, P2), P6));
        int[][] C01 = mats.add(P1, P2);
        int[][] C10 = mats.add(P3, P4);
        int[][] C11 = mats.add(P5, mats.sub(mats.sub(P1, P3), P7));

        // Left block column first, then right block column.
        mats.printMat(A00, "A00");
        mats.printMat(A10, "A10");
        mats.printMat(B00, "B00");
        mats.printMat(B10, "B10");
        mats.printMat(C00, "C00");
        mats.printMat(C10, "C10");
        mats.printMat(A01, "A01");
        mats.printMat(A11, "A11");
        mats.printMat(B01, "B01");
        mats.printMat(B11, "B11");
        mats.printMat(C01, "C01");
        mats.printMat(C11, "C11");
    }
}