CS 3343/3341
  Introduction to Algorithms  
Matrix Multiplication
Using Divide and Conquer
with Strassen Multiplication

4-by-4 Matrices, Using Strassen Multiplication
// MStrassen.java: multiply matrices, using
//  Strassen's multiplication
/**
 * Matrix multiplication using Strassen's divide-and-conquer algorithm.
 *
 * <p>{@link #mult} multiplies two square int matrices of any size n >= 1.
 * It splits each operand into four (n/2)-by-(n/2) quadrants and forms
 * seven quadrant products P1..P7, instead of the eight used by the
 * straightforward block method. Sizes that are not a power of two are
 * zero-padded up to the next power of two, and the product is trimmed back.
 *
 * <p>{@link #main} walks through one level of that recursion by hand for a
 * 4-by-4 example, treating each 2-by-2 quadrant as a single "entry".
 *
 * <p>No overflow checking is performed; {@code int} arithmetic wraps silently.
 */
public class MStrassen {

   /**
    * Returns C = A*B for two square matrices of the same size, computed with
    * Strassen's algorithm.
    *
    * <p>For 2-by-2 inputs this reduces to the classic scalar formulas
    * P1 = a00(b01 - b11), ..., C00 = P5 + P4 - P2 + P6, ...,
    * C11 = P5 + P1 - P3 - P7.
    *
    * @param A left operand, n-by-n with n >= 1
    * @param B right operand, n-by-n (same n as A)
    * @return a new n-by-n matrix holding the product
    * @throws IllegalArgumentException if A or B is empty or not square,
    *         or if their sizes differ
    * @throws NullPointerException if A, B, or any of their rows is null
    */
   int[][] mult(int A[][], int B[][]) {
      int n = requireSquare(A, "A");
      int nB = requireSquare(B, "B");
      if (nB != n) {
         throw new IllegalArgumentException("size mismatch: A is " + n + "-by-" + n
               + ", B is " + nB + "-by-" + nB);
      }
      int size = Integer.highestOneBit(n);
      if (size < n) {
         size <<= 1; // round up to the next power of two
      }
      if (size == n) {
         return strassen(A, B);
      }
      // Zero padding leaves the top-left n-by-n block of the product unchanged.
      return resize(strassen(resize(A, size), resize(B, size)), n);
   }

   /**
    * Recursive Strassen step. Both operands must be n-by-n with n a power
    * of two (guaranteed by {@link #mult}).
    */
   private int[][] strassen(int[][] A, int[][] B) {
      int n = A.length;
      if (n == 1) {
         return new int[][] {{A[0][0] * B[0][0]}};
      }
      int h = n / 2;
      int[][] A00 = quadrant(A, 0, 0), A01 = quadrant(A, 0, h);
      int[][] A10 = quadrant(A, h, 0), A11 = quadrant(A, h, h);
      int[][] B00 = quadrant(B, 0, 0), B01 = quadrant(B, 0, h);
      int[][] B10 = quadrant(B, h, 0), B11 = quadrant(B, h, h);

      // The seven quadrant products: same formulas as the 2-by-2 scalar case.
      int[][] P1 = strassen(A00, sub(B01, B11));
      int[][] P2 = strassen(add(A00, A01), B11);
      int[][] P3 = strassen(add(A10, A11), B00);
      int[][] P4 = strassen(A11, sub(B10, B00));
      int[][] P5 = strassen(add(A00, A11), add(B00, B11));
      int[][] P6 = strassen(sub(A01, A11), add(B10, B11));
      int[][] P7 = strassen(sub(A00, A10), add(B00, B01));

      int[][] C = new int[n][n];
      place(C, add(sub(add(P5, P4), P2), P6), 0, 0); // C00 = P5 + P4 - P2 + P6
      place(C, add(P1, P2), 0, h);                   // C01 = P1 + P2
      place(C, add(P3, P4), h, 0);                   // C10 = P3 + P4
      place(C, sub(sub(add(P5, P1), P3), P7), h, h); // C11 = P5 + P1 - P3 - P7
      return C;
   }

   /** Copies the (n/2)-by-(n/2) quadrant of M whose top-left corner is (row, col). */
   private static int[][] quadrant(int[][] M, int row, int col) {
      int h = M.length / 2;
      int[][] Q = new int[h][h];
      for (int i = 0; i < h; i++) {
         System.arraycopy(M[row + i], col, Q[i], 0, h);
      }
      return Q;
   }

   /** Writes square matrix Q into C with its top-left corner at (row, col). */
   private static void place(int[][] C, int[][] Q, int row, int col) {
      for (int i = 0; i < Q.length; i++) {
         System.arraycopy(Q[i], 0, C[row + i], col, Q.length);
      }
   }

   /** Returns a new m-by-m copy of square M, zero-filled or truncated as needed. */
   private static int[][] resize(int[][] M, int m) {
      int[][] R = new int[m][m];
      int k = Math.min(m, M.length);
      for (int i = 0; i < k; i++) {
         System.arraycopy(M[i], 0, R[i], 0, k);
      }
      return R;
   }

   /** Checks that M is a non-empty square matrix and returns its size. */
   private static int requireSquare(int[][] M, String name) {
      if (M.length == 0) {
         throw new IllegalArgumentException(name + " must not be empty");
      }
      for (int[] row : M) {
         if (row.length != M.length) {
            throw new IllegalArgumentException(name + " must be square: it has "
                  + M.length + " rows but a row of length " + row.length);
         }
      }
      return M.length;
   }

   /** Checks that A and B have the same number of rows and matching row lengths. */
   private static void requireSameShape(int[][] A, int[][] B) {
      if (A.length != B.length) {
         throw new IllegalArgumentException(
               "row count mismatch: " + A.length + " vs " + B.length);
      }
      for (int i = 0; i < A.length; i++) {
         if (A[i].length != B[i].length) {
            throw new IllegalArgumentException("column count mismatch in row " + i
                  + ": " + A[i].length + " vs " + B[i].length);
         }
      }
   }

   /**
    * Returns C = A + B, the element-wise sum of two matrices of the same shape.
    *
    * @throws IllegalArgumentException if the shapes differ
    */
   int[][] add(int A[][], int B[][]) {
      requireSameShape(A, B);
      int[][] C = new int[A.length][];
      for (int i = 0; i < A.length; i++) { //rows
         C[i] = new int[A[i].length];
         for (int j = 0; j < A[i].length; j++) { //cols
            C[i][j] = A[i][j] + B[i][j];
         }
      }
      return C;
   }

   /**
    * Returns C = A - B, the element-wise difference of two matrices of the
    * same shape.
    *
    * @throws IllegalArgumentException if the shapes differ
    */
   int[][] sub(int A[][], int B[][]) {
      requireSameShape(A, B);
      int[][] C = new int[A.length][];
      for (int i = 0; i < A.length; i++) { //rows
         C[i] = new int[A[i].length];
         for (int j = 0; j < A[i].length; j++) { //cols
            C[i][j] = A[i][j] - B[i][j];
         }
      }
      return C;
   }

   /**
    * Prints label s on its own line, then each row of A as "| a b ... |",
    * followed by a blank line.
    */
   void printMat(int A[][], String s){
      StringBuilder out = new StringBuilder().append(s).append('\n');
      for (int[] row : A) {
         out.append('|');
         for (int v : row) {
            out.append(' ').append(v).append(' ');
         }
         out.append("|\n");
      }
      out.append('\n');
      System.out.print(out);
   }

   /**
    * Demonstrates one level of Strassen on a 4-by-4 product by treating each
    * 2-by-2 quadrant Aij / Bij as a single entry. The block formulas below
    * mirror the scalar ones. Each 2-by-2 block product is delegated to mult.
    */
   public static void main(String[] argv) {
      MStrassen mats = new MStrassen();
      int [][] A00 = {{2, 3}, {4, 0}};
      int [][] A01 = {{1, 6}, {0, 2}};
      int [][] A10 = {{4, 2}, {0, 3}};
      int [][] A11 = {{0, 1}, {5, 2}};
      int [][] B00 = {{3, 0}, {1, 2}};
      int [][] B01 = {{4, 3}, {0, 2}};
      int [][] B10 = {{0, 3}, {5, 1}};
      int [][] B11 = {{1, 4}, {3, 2}};

      // Seven block products instead of the eight a plain block multiply needs.
      int [][] P1 = mats.mult(A00, mats.sub(B01, B11));
      int [][] P2 = mats.mult(mats.add(A00, A01), B11);
      int [][] P3 = mats.mult(mats.add(A10, A11), B00);
      int [][] P4 = mats.mult(A11, mats.sub(B10, B00));
      int [][] P5 = mats.mult(mats.add(A00, A11), mats.add(B00, B11));
      int [][] P6 = mats.mult(mats.sub(A01, A11), mats.add(B10, B11));
      int [][] P7 = mats.mult(mats.sub(A00, A10), mats.add(B00, B01));

      // Recombine the products into the four quadrants of C = A*B.
      int [][] C00 = mats.add(P5, mats.add(mats.sub(P4, P2), P6));
      int [][] C01 = mats.add(P1, P2);
      int [][] C10 = mats.add(P3, P4);
      int [][] C11 = mats.add(P5, mats.sub(mats.sub(P1, P3), P7));

      // Print the left block column first, then the right block column.
      mats.printMat(A00, "A00");
      mats.printMat(A10, "A10");

      mats.printMat(B00, "B00");
      mats.printMat(B10, "B10");

      mats.printMat(C00, "C00");
      mats.printMat(C10, "C10");

      mats.printMat(A01, "A01");
      mats.printMat(A11, "A11");

      mats.printMat(B01, "B01");
      mats.printMat(B11, "B11");

      mats.printMat(C01, "C01");
      mats.printMat(C11, "C11");
   }
}

    | 2  3  1  6 |
    | 4  0  0  2 |     | A00  A01 |
A = | 4  2  0  1 |  =  | A10  A11 |
    | 0  3  5  2 |

    | 3  0  4  3 |
    | 1  2  0  2 |     | B00  B01 |
B = | 0  3  1  4 |  =  | B10  B11 |
    | 5  1  3  2 |

            | 39  15  27  28 |
            | 22   2  22  16 |     | C00  C01 |
C = A * B = | 19   5  19  18 |  =  | C10  C11 |
            | 13  23  11  30 |
    
A00:
| 2  3 |
| 4  0 |

A10:
| 4  2 |
| 0  3 |

B00:
| 3  0 |
| 1  2 |

B10:
| 0  3 |
| 5  1 |

C00:
| 39  15 |
| 22  2 |

C10:
| 19  5 |
| 13  23 |

A01:
| 1  6 |
| 0  2 |

A11:
| 0  1 |
| 5  2 |

B01:
| 4  3 |
| 0  2 |

B11:
| 1  4 |
| 3  2 |

C01:
| 27  28 |
| 22  16 |

C11:
| 19  18 |
| 11  30 |


Revision date: 2012-09-27. (Please use ISO 8601, the International Standard Date and Time Notation.)