CS 3343/3341 Introduction to Algorithms |
Matrix Multiplication Using Divide and Conquer with Strassen Multiplication |
Multiplying 4-by-4 Matrices Using Strassen Multiplication | |
---|---|
// MStrassen.java: multiply matrices, using
// Strassen's multiplication

/**
 * Demonstrates Strassen's matrix multiplication on a 4-by-4 product.
 *
 * <p>{@link #mult} applies Strassen's seven-product formulas to one 2-by-2 product of scalars.
 * {@link #main} applies the very same formulas one level up: each 4-by-4 operand is treated as a
 * 2-by-2 matrix of 2-by-2 blocks. The block products are computed with {@link #mult} and the block
 * sums/differences with {@link #add} and {@link #sub}.
 */
public class MStrassen {

  /**
   * Returns C = A * B for 2-by-2 matrices using Strassen's seven multiplications (instead of the
   * eight a naive 2-by-2 product needs).
   *
   * <p>Only the top-left 2-by-2 entries of each argument are read; the result is always a new
   * 2-by-2 array. Arithmetic is plain {@code int} arithmetic.
   *
   * @param A left operand (must have at least 2 rows and 2 columns)
   * @param B right operand (must have at least 2 rows and 2 columns)
   * @return a new 2-by-2 array holding the product
   */
  int[][] mult(int[][] A, int[][] B) {
    int[][] C = new int[2][2];
    int P1 = A[0][0] * (B[0][1] - B[1][1]);
    int P2 = (A[0][0] + A[0][1]) * B[1][1];
    int P3 = (A[1][0] + A[1][1]) * B[0][0];
    int P4 = A[1][1] * (B[1][0] - B[0][0]);
    int P5 = (A[0][0] + A[1][1]) * (B[0][0] + B[1][1]);
    int P6 = (A[0][1] - A[1][1]) * (B[1][0] + B[1][1]);
    int P7 = (A[0][0] - A[1][0]) * (B[0][0] + B[0][1]);
    C[0][0] = P5 + P4 - P2 + P6;
    C[0][1] = P1 + P2;
    C[1][0] = P3 + P4;
    C[1][1] = P5 + P1 - P3 - P7;
    return C;
  }

  /**
   * Returns C = A + B, the entry-wise sum of two matrices.
   *
   * <p>The result has the shape of {@code A} (row count {@code A.length}, column count
   * {@code A[0].length}); {@code B} must have at least that many rows and columns.
   *
   * @param A left operand
   * @param B right operand
   * @return a new array with {@code C[i][j] = A[i][j] + B[i][j]}
   */
  int[][] add(int[][] A, int[][] B) {
    int rows = A.length;
    int cols = rows == 0 ? 0 : A[0].length;
    int[][] C = new int[rows][cols];
    for (int i = 0; i < rows; i++) { // rows
      for (int j = 0; j < cols; j++) { // cols
        C[i][j] = A[i][j] + B[i][j];
      }
    }
    return C;
  }

  /**
   * Returns C = A - B, the entry-wise difference of two matrices.
   *
   * <p>The result has the shape of {@code A} (row count {@code A.length}, column count
   * {@code A[0].length}); {@code B} must have at least that many rows and columns.
   *
   * @param A minuend
   * @param B subtrahend
   * @return a new array with {@code C[i][j] = A[i][j] - B[i][j]}
   */
  int[][] sub(int[][] A, int[][] B) {
    int rows = A.length;
    int cols = rows == 0 ? 0 : A[0].length;
    int[][] C = new int[rows][cols];
    for (int i = 0; i < rows; i++) { // rows
      for (int j = 0; j < cols; j++) { // cols
        C[i][j] = A[i][j] - B[i][j];
      }
    }
    return C;
  }

  /**
   * Prints a titled matrix to standard output, one row per line, followed by a blank line.
   *
   * <p>Each row is framed as {@code | v1  v2 ... |}, with every value padded by one space on each
   * side. All entries of every row are printed.
   *
   * @param A the matrix to print
   * @param s the title printed on the line above the matrix
   */
  void printMat(int[][] A, String s) {
    // append(s) (not new StringBuilder(s)) so a null title prints "null" instead of throwing.
    StringBuilder out = new StringBuilder().append(s).append('\n');
    for (int[] row : A) {
      out.append('|');
      for (int v : row) {
        out.append(' ').append(v).append(' ');
      }
      out.append("|\n");
    }
    out.append('\n');
    System.out.print(out);
  }

  /**
   * Multiplies two 4-by-4 matrices with one level of block Strassen and prints the blocks.
   *
   * <p>Each 4-by-4 operand is given as four 2-by-2 blocks: Xij is the block in block-row i and
   * block-column j. The seven block products P1..P7 and the result blocks C00..C11 follow the
   * same formulas as {@link #mult}. Those formulas never swap the order of the two factors in a
   * product, so they stay valid even though block (matrix) multiplication is not commutative.
   *
   * @param argv ignored
   */
  public static void main(String[] argv) {
    MStrassen mats = new MStrassen();
    int[][] A00 = {{2, 3}, {4, 0}};
    int[][] A01 = {{1, 6}, {0, 2}};
    int[][] A10 = {{4, 2}, {0, 3}};
    int[][] A11 = {{0, 1}, {5, 2}};
    int[][] B00 = {{3, 0}, {1, 2}};
    int[][] B01 = {{4, 3}, {0, 2}};
    int[][] B10 = {{0, 3}, {5, 1}};
    int[][] B11 = {{1, 4}, {3, 2}};

    // The seven Strassen block products.
    int[][] P1 = mats.mult(A00, mats.sub(B01, B11));
    int[][] P2 = mats.mult(mats.add(A00, A01), B11);
    int[][] P3 = mats.mult(mats.add(A10, A11), B00);
    int[][] P4 = mats.mult(A11, mats.sub(B10, B00));
    int[][] P5 = mats.mult(mats.add(A00, A11), mats.add(B00, B11));
    int[][] P6 = mats.mult(mats.sub(A01, A11), mats.add(B10, B11));
    int[][] P7 = mats.mult(mats.sub(A00, A10), mats.add(B00, B01));

    // Recombine the products into the four result blocks.
    int[][] C00 = mats.add(P5, mats.add(mats.sub(P4, P2), P6));
    int[][] C01 = mats.add(P1, P2);
    int[][] C10 = mats.add(P3, P4);
    int[][] C11 = mats.add(P5, mats.sub(mats.sub(P1, P3), P7));

    // Left block-column (blocks ending in 0) first, then the right block-column.
    mats.printMat(A00, "A00");
    mats.printMat(A10, "A10");
    mats.printMat(B00, "B00");
    mats.printMat(B10, "B10");
    mats.printMat(C00, "C00");
    mats.printMat(C10, "C10");
    mats.printMat(A01, "A01");
    mats.printMat(A11, "A11");
    mats.printMat(B01, "B01");
    mats.printMat(B11, "B11");
    mats.printMat(C01, "C01");
    mats.printMat(C11, "C11");
  }
}
|
|