ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
Revision: 1.5
Committed: Tue Mar 15 19:47:05 2011 UTC (13 years, 2 months ago) by jsr166
Branch: MAIN
CVS Tags: release-1_7_0
Changes since 1.4: +1 -1 lines
Log Message:
Update Creative Commons license URL in legal notices

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.5 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9     import java.util.concurrent.TimeUnit;
10    
11    
12     /**
13     * Divide and Conquer matrix multiply demo
14 jsr166 1.3 */
public class MatrixMultiply {

    /** Nanoseconds per second, used to convert elapsed nanoTime to seconds. */
    static final long NPS = (1000L * 1000 * 1000);

    /** Granularity used when none is given on the command line. */
    static final int DEFAULT_GRANULARITY = 32;

    /**
     * The quadrant size at which to stop recursing down
     * and instead directly multiply the matrices.
     * Must be a power of two. Minimum value is 2.
     */
    static int granularity = DEFAULT_GRANULARITY;

    /**
     * Runs the benchmark.
     *
     * <p>Arguments (all optional, positional):
     * {@code <threads> <matrix size> [<granularity>] [<runs>]}.
     * A thread count of 0 selects the default ForkJoinPool parallelism.
     * Matrix size and granularity must be powers of two and at least 2,
     * because the leaf multiply steps two rows and columns at a time.
     * Prints the usage message and returns on any invalid argument.
     *
     * @param args command line arguments as described above
     * @throws Exception if the timing sleep is interrupted or a task fails
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] [<runs>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                granularity = Integer.parseInt(args[2]);
            }
            if (args.length > 3) {
                // Fourth argument is the run count (was mistakenly read from args[2]).
                runs = Integer.parseInt(args[3]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // Both sizes must be powers of two >= 2: the leaf kernel touches
        // rows/columns i and i+1, so a size of 1 would index out of bounds.
        boolean badSize = n < 2 || (n & (n - 1)) != 0;
        boolean badGranularity =
            granularity < 2 || (granularity & (granularity - 1)) != 0;
        if (procs < 0 || badSize || badGranularity) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " granularity: " + granularity +
                               " runs: " + runs);

            float[][] a = new float[n][n];
            float[][] b = new float[n][n];
            float[][] c = new float[n][n];

            for (int i = 0; i < runs; ++i) {
                init(a, b, n);
                // The multiply adds into c, so it must start from zero on
                // every run, not only the first; done outside the timed region.
                clear(c);
                long start = System.nanoTime();
                pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
                long time = System.nanoTime() - start;
                double secs = ((double)time) / NPS;
                Thread.sleep(100);
                System.out.printf("\tTime: %7.3f\n", secs);
                // To verify the product, call check(c, n) here.
            }
            System.out.println(pool.toString());
        } finally {
            // Release worker threads even if a run fails.
            pool.shutdown();
        }
    }

    /**
     * Fills the leading n-by-n part of both inputs with 1's.
     * To simplify checking: the product should then be all n's.
     *
     * @param a first input matrix
     * @param b second input matrix
     * @param n number of rows and columns to fill
     */
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    /** Resets every element of the given matrix to zero. */
    private static void clear(float[][] m) {
        for (float[] row : m) {
            java.util.Arrays.fill(row, 0.0F);
        }
    }

    /**
     * Verifies that every element of the leading n-by-n part of c equals n,
     * which is the expected product of the all-ones matrices set up by
     * {@link #init}.
     *
     * @param c result matrix to verify
     * @param n matrix dimension and expected element value
     * @throws Error if any element differs from n
     */
    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
                }
            }
        }
    }

    /**
     * Multiply matrices AxB by dividing into quadrants, using algorithm:
     * <pre>
     *          A x B
     *
     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
     * |----+----| x |----+----| = |--------+--------| + |---------+-------|
     *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22
     * </pre>
     * The two products that contribute to the same quadrant of C are run
     * one after the other (see {@link Seq2}) so they never update the same
     * elements concurrently; the four quadrants of C proceed in parallel.
     */
    static class Multiplier extends RecursiveAction {
        final float[][] A; // Matrix A
        final int aRow;    // first row    of current quadrant of A
        final int aCol;    // first column of current quadrant of A

        final float[][] B; // Similarly for B
        final int bRow;
        final int bCol;

        final float[][] C; // Similarly for result matrix C
        final int cRow;
        final int cCol;

        final int size;    // side length (rows == columns) of current quadrant

        Multiplier(float[][] A, int aRow, int aCol,
                   float[][] B, int bRow, int bCol,
                   float[][] C, int cRow, int cCol,
                   int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        /**
         * Multiplies directly once the quadrant is no larger than
         * {@link MatrixMultiply#granularity}; otherwise splits into four
         * sequenced pairs of half-size multiplies, one pair per quadrant of C.
         */
        @Override
        public void compute() {
            if (size <= granularity) {
                multiplyStride2();
                return;
            }

            int h = size / 2;

            invokeAll(new Seq2[] {
                seq(new Multiplier(A, aRow,   aCol,    // A11
                                   B, bRow,   bCol,    // B11
                                   C, cRow,   cCol,    // C11
                                   h),
                    new Multiplier(A, aRow,   aCol+h,  // A12
                                   B, bRow+h, bCol,    // B21
                                   C, cRow,   cCol,    // C11
                                   h)),

                seq(new Multiplier(A, aRow,   aCol,    // A11
                                   B, bRow,   bCol+h,  // B12
                                   C, cRow,   cCol+h,  // C12
                                   h),
                    new Multiplier(A, aRow,   aCol+h,  // A12
                                   B, bRow+h, bCol+h,  // B22
                                   C, cRow,   cCol+h,  // C12
                                   h)),

                seq(new Multiplier(A, aRow+h, aCol,    // A21
                                   B, bRow,   bCol,    // B11
                                   C, cRow+h, cCol,    // C21
                                   h),
                    new Multiplier(A, aRow+h, aCol+h,  // A22
                                   B, bRow+h, bCol,    // B21
                                   C, cRow+h, cCol,    // C21
                                   h)),

                seq(new Multiplier(A, aRow+h, aCol,    // A21
                                   B, bRow,   bCol+h,  // B12
                                   C, cRow+h, cCol+h,  // C22
                                   h),
                    new Multiplier(A, aRow+h, aCol+h,  // A22
                                   B, bRow+h, bCol+h,  // B22
                                   C, cRow+h, cCol+h,  // C22
                                   h))
            });
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns
         * at a time. Adapted from Cilk demos.
         * Note that the results are added into C, not just set into C,
         * so C must hold zeros (or a partial sum) before the multiply starts.
         * Requires {@code size} to be even.
         */
        void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow+i];
                    float[] a1 = A[aRow+i+1];

                    // Accumulators for the 2x2 tile of C at (i, j).
                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow+k];

                        s00 += a0[aCol+k]   * b0[bCol+j];
                        s10 += a1[aCol+k]   * b0[bCol+j];
                        s01 += a0[aCol+k]   * b0[bCol+j+1];
                        s11 += a1[aCol+k]   * b0[bCol+j+1];

                        float[] b1 = B[bRow+k+1];

                        s00 += a0[aCol+k+1] * b1[bCol+j];
                        s10 += a1[aCol+k+1] * b1[bCol+j];
                        s01 += a0[aCol+k+1] * b1[bCol+j+1];
                        s11 += a1[aCol+k+1] * b1[bCol+j+1];
                    }

                    C[cRow+i]  [cCol+j]   += s00;
                    C[cRow+i]  [cCol+j+1] += s01;
                    C[cRow+i+1][cCol+j]   += s10;
                    C[cRow+i+1][cCol+j+1] += s11;
                }
            }
        }

    }

    /**
     * Creates an action that runs the two given actions one after the other.
     *
     * @param task1 action to run first
     * @param task2 action to run after task1 completes
     * @return the sequencing action
     */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Action that invokes two actions sequentially, in order. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }


}