CS3343/3341 Analysis of Algorithms
Matrix-chain Multiplications

Matrix-chain Multiplications: Matrix multiplication is not commutative, but it is associative. For matrices that are not square, the order of association can make a big difference in the amount of work. Dynamic programming finds the cheapest order (see your text, pages 370-378). This example has nothing to do with Strassen's method of matrix multiplication.

First, recall that to multiply two matrices, the number of rows of the second must equal the number of columns of the first, so the product must look like (p x q) * (q x r) = (p x r). Each entry of the result is the inner product of a row of the first with a column of the second, so the two have to "fit together". Here we compute p*r inner products, and each inner product involves q multiplications, so there are p*q*r multiplications altogether.
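As a hedged illustration of this count, the following C sketch multiplies a (p x q) matrix by a (q x r) matrix with the usual triple loop and counts the scalar multiplications as it goes. The function name matmul_count and the flat row-major array layout are assumptions made for this example; they are not part of the course program.

```c
#include <assert.h>

// Multiply a (p x q) by b (q x r) into c (p x r). All three matrices are
// stored row-major in flat arrays (an illustrative layout choice).
// Returns the number of scalar multiplications performed, which is p*q*r.
long matmul_count(int p, int q, int r,
                  const double *a, const double *b, double *c) {
    long mults = 0;
    assert(p > 0 && q > 0 && r > 0);
    for (int i = 0; i < p; i++) {
        for (int j = 0; j < r; j++) {
            double sum = 0.0;          // inner product of row i and column j
            for (int k = 0; k < q; k++) {
                sum += a[i * q + k] * b[k * r + j];
                mults++;               // one scalar multiplication
            }
            c[i * r + j] = sum;
        }
    }
    return mults;
}
```

For instance, with p = 2, q = 3, r = 1 the function performs 2*3*1 = 6 multiplications.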

Following your text's example, consider a chain of three matrix multiplications A1*A2*A3 with sizes (10 x 100) * (100 x 5) * (5 x 50). There are two ways to associate this triple product, each giving the same answer:

   (A1*A2)*A3:  10*100*5 + 10*5*50   =  5000 +  2500 =  7500 multiplications
   A1*(A2*A3):  100*5*50 + 10*100*50 = 25000 + 50000 = 75000 multiplications

This is an extreme example, but the two ways differ in cost by a factor of 10. We can represent the sizes of this triple product by a simple list of 4 numbers: {10,100,5,50}. We know the remaining dimensions since the matrices must fit together.
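The two costs can be checked mechanically from the list of 4 numbers alone. The sketch below is a minimal illustration; the function names cost_left_first and cost_right_first are hypothetical helpers, not part of the course program.

```c
#include <assert.h>

// Cost of (A1*A2)*A3 for the dimension list {p0,p1,p2,p3}:
// first a (p0 x p1)*(p1 x p2) product, then (p0 x p2)*(p2 x p3).
long cost_left_first(const int p[4]) {
    assert(p[0] > 0 && p[1] > 0 && p[2] > 0 && p[3] > 0);
    return (long)p[0] * p[1] * p[2] + (long)p[0] * p[2] * p[3];
}

// Cost of A1*(A2*A3) for the same list:
// first a (p1 x p2)*(p2 x p3) product, then (p0 x p1)*(p1 x p3).
long cost_right_first(const int p[4]) {
    assert(p[0] > 0 && p[1] > 0 && p[2] > 0 && p[3] > 0);
    return (long)p[1] * p[2] * p[3] + (long)p[0] * p[1] * p[3];
}
```

With the list {10,100,5,50}, cost_left_first returns 7500 and cost_right_first returns 75000.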

A chain product of n matrices is described by a list of n+1 numbers. Your text's standard example is given by the list: {30,35,15,5,10,20,25}. A triangular table is then constructed from the bottom up, choosing the best way to do the final multiplication at each stage. Here is an illustration from your text (page 376).


[Figure: the triangular table for this example, from your text (page 376).]

We start with a list of matrices: {A_1,A_2,...,A_n} for some n >= 1. There is also a list of integers {p_0,p_1,...,p_n} such that for each i, A_i has dimensions p_{i-1} x p_i.

Your text denotes by A_{i..j}, for i <= j, the matrix resulting from the chain multiplication A_i*A_{i+1}*...*A_j. This matrix is independent of the association; we are interested only in the number of scalar multiplications needed to compute it. In the diagram above, the table entry m[i,j] holds the minimum number of scalar multiplications needed to compute the matrix chain A_{i..j}. Of course, m[i,i] = 0, since A_{i..i} is the single matrix A_i and needs no multiplications. Then A_{i..i+1} is just the product of the two matrices A_i*A_{i+1}, so m[i,i+1] = p_{i-1}*p_i*p_{i+1}.

If i < j, then however one associates (inserts parentheses), computing A_{i..j} ends with one last product. That is, there is a k with i <= k < j such that A_{i..j} = A_{i..k}*A_{k+1..j}. For this choice of k, the number of multiplications is m[i,k] + m[k+1,j] + p_{i-1}*p_k*p_j. To calculate m[i,j], take the minimum of this expression over all k satisfying i <= k < j.
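Written as a single formula, the recurrence described above reads as follows. This is only a restatement in LaTeX, using p_0,...,p_n for the dimension list as before.

```latex
m[i,j] =
\begin{cases}
0 & \text{if } i = j, \\[4pt]
\min\limits_{i \le k < j}
  \bigl( m[i,k] + m[k+1,j] + p_{i-1}\, p_k\, p_j \bigr) & \text{if } i < j.
\end{cases}
```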

The image above from your text shows this process for m[2,5], the minimum cost of multiplying the 4 matrices A_2 through A_5. This calculation uses table entries calculated earlier, in a bottom-up fashion. (It would also be possible to calculate this recursively, but without memoization this would be inefficient, since the same subproblems would be solved over and over.) Carrying this process through gives the result shown in the image above.
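To make the remark about recursion concrete, here is a minimal sketch of the recursive calculation with memoization. It is an alternative to the bottom-up table, not the method used in the program on this page; the names mcm, memo_reset and MAXN are assumptions introduced for this example.

```c
#include <assert.h>
#include <limits.h>

#define MAXN 15   // illustrative bound on the number of matrices, plus one

static long memo[MAXN][MAXN];   // memo[i][j] < 0 means "not computed yet"

// Mark every entry as not yet computed; call before each new list p.
void memo_reset(void) {
    for (int i = 0; i < MAXN; i++)
        for (int j = 0; j < MAXN; j++)
            memo[i][j] = -1;
}

// Minimum number of scalar multiplications for the chain A_i..A_j,
// where matrix A_t has dimensions p[t-1] x p[t].
long mcm(const int p[], int i, int j) {
    assert(1 <= i && i <= j && j < MAXN);
    if (i == j) return 0;                     // a single matrix costs nothing
    if (memo[i][j] >= 0) return memo[i][j];   // already solved: reuse it
    long best = LONG_MAX;
    for (int k = i; k < j; k++) {             // try every last product
        long q = mcm(p, i, k) + mcm(p, k + 1, j)
               + (long)p[i - 1] * p[k] * p[j];
        if (q < best) best = q;
    }
    memo[i][j] = best;
    return best;
}
```

For the list {30,35,15,5,10,20,25}, after calling memo_reset(), mcm(p, 2, 5) returns 7125 and mcm(p, 1, 6) returns 15125. Because each pair (i,j) is solved only once, the work is comparable to filling the table bottom-up.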


The Matrix-chain Multiplication Program:

Matrix Chain Multiplications
// Matrix chain multiplications.
// Input: sequence of matrix dimens, end in 0.
#include <stdio.h>
#define L 15  // big enough for most examples

int r(int i, int j, int s[L][L]);  // print assigns
void putpre(int x);                // used by r
void putdig(int x);                // used by r
int newtemp(void);                 // used by r
void r2(int i, int j, int s[L][L]);// print par
void r3(int i, int j, int s[L][L]);// expression

int main(void) {
   int p[L];     // input array of dimensions
   int m[L][L];  // array of numbers of mults
   int s[L][L];  // array giving index opt sol
   int n = 0, i, ll, j, k, q;
   int res; // final result from seq of assigns
   for (i = 0; i < L; i++) {
      // read up to a zero (or end of input)
      if (scanf("%i", &p[i]) != 1 || p[i] <= 0)
         break;
   }
   n = i - 1;
   if (n < 1) return 1; // need at least one matrix
   // calculate matrices m and s
   for (i = 1 ; i <= n; i++)
      m[i][i] = 0;
   for (ll = 2; ll <= n; ll++)
      for (i = 1; i <= n - ll + 1; i++) {
         j = i + ll - 1;
         m[i][j] = 444444444; // "infinity"
         for (k = i; k <= j - 1; k++) {
            q = m[i][k] + m[k+1][j] +
               p[i-1]*p[k]*p[j];
            if (q < m[i][j]) {
               m[i][j] = q;
               s[i][j] = k;
            }
         }
      }
   // print p
   printf("The array p:\n\n");
   for (i = 0; i <= n; i++)
      printf("p[%i] =%3i, ", i, p[i]);
   // print m
   printf("\n\nThe array m:\n\n       ");
   for (i = 1; i <= n; i++)
      printf(" i=%2i  ", i);
   printf("\n");
   for (j = n; j >= 1; j--) {
      printf("j=%2i:", j);
      for (i = 1; i <= j; i++)
         printf("%7i", m[i][j]);
      printf("\n");
   }
   // print s
   printf("\nThe array s:\n\n       ");
   for (i = 1; i < n; i++)
      printf(" i=%2i  ", i);
   printf("\n");
   for (j = n; j >= 2; j--) {
      printf("j=%2i:", j);
      for (i = 1; i < j; i++)
         printf("%7i", s[i][j]);
      printf("\n");
   }
   printf("\n");
   res = r(1, n, s);
   printf("Final result is in T");
   putdig(res); printf("\n\n");
   r2(1, n, s); printf("\n\n");
   r3(1, n, s); printf("\n");
}
// r: function that calcs series of assigns
int r(int i, int j, int s[L][L]) {
   int k, arg1, arg2, res;
   if (i == j) return i;
   k = s[i][j];         // top-level split
   arg1 = r(i, k, s);   // temp left half
   arg2 = r(k+1, j, s); // temp right half
   res = -newtemp();    // next temp
   // a negative number indicates a temp
   // next three lines output one equation
   putpre(res); putdig(res); putchar('=');
   putpre(arg1);putdig(arg1);putchar('*');
   putpre(arg2);putdig(arg2);putchar('\n');
   return res;
}

// putpre: neg is a temp; pos is an arg
void putpre(int x) {
   if (x < 0) putchar('T');
   else putchar('A');
}

// putdig: print the number, without sign
void putdig(int x) {
   if (x < 0) x = -x;
   printf("%i", x);
}

// newtemp: return next integer in order
int newtemp(void) {
   static int i = 1; // persists across calls
   return i++;
}

// r2: output paren form with extra parens
void r2(int i, int j, int s[L][L]) {
   int k;
   if (i == j) {
      printf("A%i", i);
      return;
   }
   k = s[i][j];
   printf("(");
   r2(i, k, s);
   printf(")*(");
   r2(k+1, j, s);
   printf(")");
}

// r3: output paren form with fewer parens
void r3(int i, int j, int s[L][L]) {
   int k;
   if (i == j) {
      printf("A%i", i);
      return;
   }
   printf("(");
   k = s[i][j];
   r3(i, k, s);
   printf("*");
   r3(k+1, j, s);
   printf(")");
}


Output: Here are four different outputs for different sequences of dimensions of the matrices:

The array p:

p[0] = 30, p[1] = 35, p[2] = 15, p[3] =  5,
p[4] = 10, p[5] = 20, p[6] = 25, 

The array m:

        i= 1   i= 2   i= 3   i= 4   i= 5   i= 6
j= 6:  15125  10500   5375   3500   5000      0
j= 5:  11875   7125   2500   1000      0
j= 4:   9375   4375    750      0
j= 3:   7875   2625      0
j= 2:  15750      0
j= 1:      0

The array s:

        i= 1   i= 2   i= 3   i= 4   i= 5  
j= 6:      3      3      3      5      5
j= 5:      3      3      3      4
j= 4:      3      3      3
j= 3:      1      2
j= 2:      1

T1=A2*A3
T2=A1*T1
T3=A4*A5
T4=T3*A6
T5=T2*T4
Final result is in T5

((A1)*((A2)*(A3)))*(((A4)*(A5))*(A6))

((A1*(A2*A3))*((A4*A5)*A6))

The array p:

p[0] =  5, p[1] = 10, p[2] =  3, p[3] = 12,
p[4] =  5, p[5] = 50, p[6] =  6, 

The array m:

        i= 1   i= 2   i= 3   i= 4   i= 5   i= 6
j= 6:   2010   1950   1770   1860   1500      0
j= 5:   1655   2430    930   3000      0
j= 4:    405    330    180      0
j= 3:    330    360      0
j= 2:    150      0
j= 1:      0

The array s:

        i= 1   i= 2   i= 3   i= 4   i= 5  
j= 6:      2      2      4      4      5
j= 5:      4      2      4      4
j= 4:      2      2      3
j= 3:      2      2
j= 2:      1

T1=A1*A2
T2=A3*A4
T3=A5*A6
T4=T2*T3
T5=T1*T4
Final result is in T5

((A1)*(A2))*(((A3)*(A4))*((A5)*(A6)))

((A1*A2)*((A3*A4)*(A5*A6)))

The array p:

p[0] = 20, p[1] = 25, p[2] =  5, p[3] = 10, p[4] = 30, p[5] = 15,
p[6] = 20, p[7] = 10, p[8] =  5, p[9] = 40, 

The array m:

        i= 1   i= 2   i= 3   i= 4   i= 5   i= 6   i= 7   i= 8   i= 9  
j= 9:  13500  12125   7500   8250  10750   5500   5000   2000      0
j= 8:   9500   7125   6500   6250   4750   2500   1000      0
j= 7:   9750   7500   6250   9000   7500   3000      0
j= 6:   9750   7750   5250   7500   9000      0
j= 5:   7750   5625   3750   4500      0
j= 4:   7000   5250   1500      0
j= 3:   3500   1250      0
j= 2:   2500      0
j= 1:      0

The array s:

        i= 1   i= 2   i= 3   i= 4   i= 5   i= 6   i= 7   i= 8  
j= 9:      8      8      8      8      8      8      8      8
j= 8:      2      2      3      4      5      6      7
j= 7:      2      2      6      5      5      6
j= 6:      2      2      5      5      5
j= 5:      2      2      4      4
j= 4:      2      2      3
j= 3:      2      2
j= 2:      1

T1=A1*A2
T2=A7*A8
T3=A6*T2
T4=A5*T3
T5=A4*T4
T6=A3*T5
T7=T1*T6
T8=T7*A9
Final result is in T8

(((A1)*(A2))*((A3)*((A4)*((A5)*((A6)*((A7)*(A8)))))))*(A9)

(((A1*A2)*(A3*(A4*(A5*(A6*(A7*A8))))))*A9)

The array p:

p[0] = 20, p[1] = 10, p[2] = 15, p[3] =  5, p[4] = 25, p[5] = 10, p[6] = 30, p[7] =  5,
p[8] = 10, p[9] = 20, p[10] = 25, p[11] = 15, 

The array m:

        i= 1   i= 2   i= 3   i= 4   i= 5   i= 6   i= 7   i= 8   i= 9   i=10   i=11  
j=11:  11875  10125   9750   8625  10000   7625   7625   5375   8750   7500      0
j=10:  11000   8750   8625   6750   9375   6250   7250   3500   5000      0
j= 9:   8000   6000   5750   4250   6250   3500   4000   1000      0
j= 8:   6000   4500   4000   3250   4000   2000   1500      0
j= 7:   5000   4000   3375   3000   2750   1500      0
j= 6:   7500   5000   5000   2750   7500      0
j= 5:   4000   2500   2000   1250      0
j= 4:   4250   2000   1875      0
j= 3:   1750    750      0
j= 2:   3000      0
j= 1:      0

The array s:

        i= 1   i= 2   i= 3   i= 4   i= 5   i= 6   i= 7   i= 8   i= 9   i=10  
j=11:      3      3      3     10      7      7      7     10     10     10
j=10:      3      3      3      9      7      7      7      9      9
j= 9:      3      3      3      8      7      7      7      8
j= 8:      3      3      3      7      7      7      7
j= 7:      1      3      3      5      5      6
j= 6:      3      3      3      5      5
j= 5:      3      3      3      4
j= 4:      3      3      3
j= 3:      1      2
j= 2:      1

T1=A2*A3
T2=A1*T1
T3=A4*A5
T4=A6*A7
T5=T3*T4
T6=T5*A8
T7=T6*A9
T8=T7*A10
T9=T8*A11
T10=T2*T9
Final result is in T10

((A1)*((A2)*(A3)))*(((((((A4)*(A5))*((A6)*(A7)))*(A8))*(A9))*(A10))*(A11))

((A1*(A2*A3))*((((((A4*A5)*(A6*A7))*A8)*A9)*A10)*A11))