ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
(Generate patch)

Comparing jsr166/src/test/loops/MatrixMultiply.java (file contents):
Revision 1.1 by dl, Sun Sep 19 12:55:37 2010 UTC vs.
Revision 1.9 by dl, Sat Sep 12 19:39:26 2015 UTC

# Line 1 | Line 1
1   /*
2   * Written by Doug Lea with assistance from members of JCP JSR-166
3   * Expert Group and released to the public domain, as explained at
4 < * http://creativecommons.org/licenses/publicdomain
4 > * http://creativecommons.org/publicdomain/zero/1.0/
5   */
6  
7   //import jsr166y.*;
8   import java.util.concurrent.*;
9 import java.util.concurrent.TimeUnit;
10
9  
10   /**
11   * Divide and Conquer matrix multiply demo
12 < **/
15 <
12 > */
13   public class MatrixMultiply {
14  
15      /** for time conversion */
16      static final long NPS = (1000L * 1000 * 1000);
17  
18 <    static final int DEFAULT_GRANULARITY = 32;
18 >    static final int DEFAULT_GRANULARITY = 16; // 32;
19  
20 <    /** The quadrant size at which to stop recursing down
20 >    /**
21 >     * The quadrant size at which to stop recursing down
22       * and instead directly multiply the matrices.
23       * Must be a power of two. Minimum value is 2.
24 <     **/
24 >     */
25      static int granularity = DEFAULT_GRANULARITY;
26  
27      public static void main(String[] args) throws Exception {
# Line 32 | Line 30 | public class MatrixMultiply {
30  
31          int procs = 0;
32          int n = 2048;
33 <        int runs = 5;
33 >        int runs = 32;
34          try {
35              if (args.length > 0)
36                  procs = Integer.parseInt(args[0]);
37              if (args.length > 1)
38                  n = Integer.parseInt(args[1]);
39 <            if (args.length > 2)
39 >            if (args.length > 2)
40                  granularity = Integer.parseInt(args[2]);
41 <            if (args.length > 3)
41 >            if (args.length > 3)
42                  runs = Integer.parseInt(args[2]);
43          }
44 <    
44 >
45          catch (Exception e) {
46              System.out.println(usage);
47              return;
48          }
49 <    
50 <        if ( ((n & (n - 1)) != 0) ||
49 >
50 >        if ( ((n & (n - 1)) != 0) ||
51               ((granularity & (granularity - 1)) != 0) ||
52               granularity < 2) {
53              System.out.println(usage);
54              return;
55          }
56 <    
57 <        ForkJoinPool pool = procs == 0? new ForkJoinPool() :
56 >
57 >        ForkJoinPool pool = (procs == 0) ? ForkJoinPool.commonPool() :
58              new ForkJoinPool(procs);
59 <        System.out.println("procs: " + pool.getParallelism() +
59 >        System.out.println("procs: " + pool.getParallelism() +
60                             " n: " + n + " granularity: " + granularity +
61                             " runs: " + runs);
62 <    
62 >
63          float[][] a = new float[n][n];
64          float[][] b = new float[n][n];
65          float[][] c = new float[n][n];
66 <    
66 >
67          for (int i = 0; i < runs; ++i) {
68              init(a, b, n);
69              long start = System.nanoTime();
70 <            pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
70 >            new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n).invoke();
71              long time = System.nanoTime() - start;
72              double secs = ((double)time) / NPS;
73              Thread.sleep(100);
74 <            System.out.printf("\tTime: %7.3f\n", secs);
74 >            System.out.printf("Time: %7.3f ", secs);
75 >            if ((i & 3) == 3) System.out.println();
76              // check(c, n);
77          }
78          System.out.println(pool.toString());
79 <        pool.shutdown();
79 >        if (pool != ForkJoinPool.commonPool())
80 >            pool.shutdown();
81 >        Thread.sleep(100);
82      }
83  
83
84      // To simplify checking, fill with all 1's. Answer should be all n's.
85      static void init(float[][] a, float[][] b, int n) {
86          for (int i = 0; i < n; ++i) {
# Line 101 | Line 101 | public class MatrixMultiply {
101          }
102      }
103  
104 <    /**
104 >    /**
105       * Multiply matrices AxB by dividing into quadrants, using algorithm:
106       * <pre>
107 <     *      A      x      B                            
107 >     *      A      x      B
108       *
109 <     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
109 >     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
110       * |----+----| x |----+----| = |--------+--------| + |---------+-------|
111 <     *  A21 | A22     B21 | B21     A21*B11 | A21*B21     A22*B21 | A22*B22
111 >     *  A21 | A22     B21 | B21     A21*B11 | A21*B21     A22*B21 | A22*B22
112       * </pre>
113       */
114
115
114      static class Multiplier extends RecursiveAction {
115          final float[][] A;   // Matrix A
116          final int aRow;      // first row    of current quadrant of A
# Line 127 | Line 125 | public class MatrixMultiply {
125          final int cCol;
126  
127          final int size;      // number of elements in current quadrant
128 <    
128 >
129          Multiplier(float[][] A, int aRow, int aCol,
130                     float[][] B, int bRow, int bCol,
131                     float[][] C, int cRow, int cCol,
# Line 156 | Line 154 | public class MatrixMultiply {
154                                         B, bRow+h, bCol,    // B21
155                                         C, cRow,   cCol,    // C11
156                                         h)),
157 <            
157 >
158                      seq(new Multiplier(A, aRow,   aCol,    // A11
159                                         B, bRow,   bCol+h,  // B12
160                                         C, cRow,   cCol+h,  // C12
# Line 165 | Line 163 | public class MatrixMultiply {
163                                         B, bRow+h, bCol+h,  // B22
164                                         C, cRow,   cCol+h,  // C12
165                                         h)),
166 <          
166 >
167                      seq(new Multiplier(A, aRow+h, aCol,    // A21
168                                         B, bRow,   bCol,    // B11
169                                         C, cRow+h, cCol,    // C21
# Line 174 | Line 172 | public class MatrixMultiply {
172                                         B, bRow+h, bCol,    // B21
173                                         C, cRow+h, cCol,    // C21
174                                         h)),
175 <          
175 >
176                      seq(new Multiplier(A, aRow+h, aCol,    // A21
177                                         B, bRow,   bCol+h,  // B12
178                                         C, cRow+h, cCol+h,  // C22
# Line 187 | Line 185 | public class MatrixMultiply {
185              }
186          }
187  
188 <        /**
188 >        /**
189           * Version of matrix multiplication that steps 2 rows and columns
190           * at a time. Adapted from Cilk demos.
191           * Note that the results are added into C, not just set into C.
192           * This works well here because Java array elements
193           * are created with all zero values.
194 <         **/
197 <
194 >         */
195          void multiplyStride2() {
196              for (int j = 0; j < size; j+=2) {
197                  for (int i = 0; i < size; i +=2) {
198  
199                      float[] a0 = A[aRow+i];
200                      float[] a1 = A[aRow+i+1];
201 <        
202 <                    float s00 = 0.0F;
203 <                    float s01 = 0.0F;
204 <                    float s10 = 0.0F;
205 <                    float s11 = 0.0F;
201 >
202 >                    float s00 = 0.0F;
203 >                    float s01 = 0.0F;
204 >                    float s10 = 0.0F;
205 >                    float s11 = 0.0F;
206  
207                      for (int k = 0; k < size; k+=2) {
208  
# Line 234 | Line 231 | public class MatrixMultiply {
231  
232      }
233  
234 <    static Seq2 seq(RecursiveAction task1,
235 <                    RecursiveAction task2) {
236 <        return new Seq2(task1, task2);
234 >    static Seq2 seq(RecursiveAction task1,
235 >                    RecursiveAction task2) {
236 >        return new Seq2(task1, task2);
237      }
238  
239      static final class Seq2 extends RecursiveAction {
# Line 251 | Line 248 | public class MatrixMultiply {
248              snd.invoke();
249          }
250      }
254
255
251   }

Diff Legend

Removed lines
+ Added lines
< Changed lines
> Changed lines