ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
Revision: 1.5
Committed: Tue Mar 15 19:47:05 2011 UTC (13 years, 2 months ago) by jsr166
Branch: MAIN
CVS Tags: release-1_7_0
Changes since 1.4: +1 -1 lines
Log Message:
Update Creative Commons license URL in legal notices

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9 import java.util.concurrent.TimeUnit;
10
11
/**
 * Divide and Conquer matrix multiply demo.
 *
 * <p>Multiplies two square {@code float} matrices on a {@link ForkJoinPool}
 * by recursively splitting them into quadrants until a quadrant is no larger
 * than {@link #granularity}, at which point the quadrant is multiplied
 * directly. Each run is timed and the elapsed time is printed.
 *
 * <p>Command line: {@code java MatrixMultiply [threads [size [granularity [runs]]]]}.
 * A thread count of 0 selects the pool's default parallelism.
 */
public class MatrixMultiply {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    static final int DEFAULT_GRANULARITY = 32;

    /**
     * The quadrant size at which to stop recursing down
     * and instead directly multiply the matrices.
     * Must be a power of two. Minimum value is 2.
     */
    static int granularity = DEFAULT_GRANULARITY;

    /**
     * Parses the command line, builds the matrices and runs the timed
     * multiplications.
     *
     * @param args optional {@code threads}, {@code size}, {@code granularity}
     *             and {@code runs}, in that order
     * @throws Exception if the inter-run sleep is interrupted or a
     *                   multiplication task fails
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0)
                procs = Integer.parseInt(args[0]);
            if (args.length > 1)
                n = Integer.parseInt(args[1]);
            if (args.length > 2)
                granularity = Integer.parseInt(args[2]);
            if (args.length > 3)
                runs = Integer.parseInt(args[3]); // fourth argument, not the granularity again
        }
        catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // The base case steps two rows/columns at a time, so both the matrix
        // size and the granularity must be powers of two that are at least 2.
        // (x & (x - 1)) == 0 alone also accepts 0 and Integer.MIN_VALUE.
        if (procs < 0 ||
            n < 2 || ((n & (n - 1)) != 0) ||
            granularity < 2 || ((granularity & (granularity - 1)) != 0)) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " granularity: " + granularity +
                               " runs: " + runs);

            float[][] a = new float[n][n];
            float[][] b = new float[n][n];
            float[][] c = new float[n][n];

            for (int i = 0; i < runs; ++i) {
                init(a, b, n);
                // Results are accumulated into c, so it must start at zero
                // for every run, not only the first one.
                clear(c);
                long start = System.nanoTime();
                pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
                long time = System.nanoTime() - start;
                double secs = ((double)time) / NPS;
                Thread.sleep(100);
                System.out.printf("\tTime: %7.3f\n", secs);
                // check(c, n);
            }
            System.out.println(pool.toString());
        } finally {
            // Release the worker threads even if a run fails.
            pool.shutdown();
        }
    }

    /**
     * Fills the leading {@code n x n} part of both input matrices with 1's.
     * To simplify checking, fill with all 1's. Answer should be all n's.
     *
     * @param a first input matrix
     * @param b second input matrix
     * @param n number of rows and columns to fill
     */
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    /** Sets every element of {@code m} to zero. */
    private static void clear(float[][] m) {
        for (float[] row : m) {
            for (int j = 0; j < row.length; ++j) {
                row[j] = 0.0F;
            }
        }
    }

    /**
     * Verifies that every element of the leading {@code n x n} part of
     * {@code c} equals {@code n}, the expected product of two all-ones
     * matrices.
     *
     * @param c result matrix to verify
     * @param n matrix dimension and expected element value
     * @throws Error if any element differs from {@code n}
     */
    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++ ) {
            for (int j = 0; j < n; j++ ) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
                }
            }
        }
    }

    /**
     * Multiply matrices AxB by dividing into quadrants, using algorithm:
     * <pre>
     *      A      x      B
     *
     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
     * |----+----| x |----+----| = |--------+--------| + |---------+-------|
     *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22
     * </pre>
     *
     * The two products that contribute to the same quadrant of C are run one
     * after the other (see {@link Seq2}); the four result quadrants are
     * processed via {@code invokeAll}.
     */
    static class Multiplier extends RecursiveAction {
        final float[][] A; // Matrix A
        final int aRow;    // first row    of current quadrant of A
        final int aCol;    // first column of current quadrant of A

        final float[][] B; // Similarly for B
        final int bRow;
        final int bCol;

        final float[][] C; // Similarly for result matrix C
        final int cRow;
        final int cCol;

        final int size;    // number of rows (and columns) in current quadrant

        Multiplier(float[][] A, int aRow, int aCol,
                   float[][] B, int bRow, int bCol,
                   float[][] C, int cRow, int cCol,
                   int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        /**
         * Multiplies directly once the quadrant is no larger than
         * {@link #granularity}; otherwise splits into eight half-size
         * products, paired by the quadrant of C they accumulate into.
         */
        @Override
        public void compute() {

            if (size <= granularity) {
                multiplyStride2();
            }

            else {
                int h = size / 2;

                invokeAll(new Seq2[] {
                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol,    // B11
                                       C, cRow,   cCol,    // C11
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol,    // B21
                                       C, cRow,   cCol,    // C11
                                       h)),

                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow,   cCol+h,  // C12
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow,   cCol+h,  // C12
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol,    // B11
                                       C, cRow+h, cCol,    // C21
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol,    // B21
                                       C, cRow+h, cCol,    // C21
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow+h, cCol+h,  // C22
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow+h, cCol+h,  // C22
                                       h))
                });
            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns
         * at a time. Adapted from Cilk demos.
         * Note that the results are added into C, not just set into C.
         * This works well here because Java array elements
         * are created with all zero values.
         * Requires {@code size} to be even, since rows and columns are
         * consumed in pairs.
         */
        void multiplyStride2() {
            for (int j = 0; j < size; j+=2) {
                for (int i = 0; i < size; i +=2) {

                    float[] a0 = A[aRow+i];
                    float[] a1 = A[aRow+i+1];

                    // Partial sums for the 2x2 tile of C at (i, j).
                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k+=2) {

                        float[] b0 = B[bRow+k];

                        s00 += a0[aCol+k]   * b0[bCol+j];
                        s10 += a1[aCol+k]   * b0[bCol+j];
                        s01 += a0[aCol+k]   * b0[bCol+j+1];
                        s11 += a1[aCol+k]   * b0[bCol+j+1];

                        float[] b1 = B[bRow+k+1];

                        s00 += a0[aCol+k+1] * b1[bCol+j];
                        s10 += a1[aCol+k+1] * b1[bCol+j];
                        s01 += a0[aCol+k+1] * b1[bCol+j+1];
                        s11 += a1[aCol+k+1] * b1[bCol+j+1];
                    }

                    C[cRow+i]  [cCol+j]   += s00;
                    C[cRow+i]  [cCol+j+1] += s01;
                    C[cRow+i+1][cCol+j]   += s10;
                    C[cRow+i+1][cCol+j+1] += s11;
                }
            }
        }

    }

    /**
     * Creates a task that runs {@code task1} and then {@code task2}.
     *
     * @param task1 task to run first
     * @param task2 task to run after {@code task1} has completed
     * @return the combined sequential task
     */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /**
     * Runs two actions strictly one after the other. Used for the pairs of
     * products that accumulate into the same quadrant of the result matrix.
     */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }
        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }


}