ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.1
Committed: Sun Sep 19 12:55:37 2010 UTC (13 years, 8 months ago) by dl
Branch: MAIN
Log Message:
Add and update FJ and Queue tests

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/licenses/publicdomain
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9 import java.util.concurrent.TimeUnit;
10
11
/**
 * LU matrix decomposition demo, based on those in Cilk and Hood.
 *
 * <p>The matrix is factored in place, without pivoting, by recursively
 * splitting it into quadrants until blocks of {@link #BLOCK_SIZE} rows and
 * columns are reached. After the decomposition the strictly lower triangle
 * holds the multipliers of L (whose diagonal of ones is implicit, see
 * {@link #check}) and the upper triangle, including the diagonal, holds U.
 */
public final class LU {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    /** Base-case block edge length; granularity is hard-wired as a compile-time constant. */
    static final int BLOCK_SIZE = 16;

    /** Set true to check the answer of every run against the original matrix. */
    static final boolean CHECK = false;

    /**
     * Runs the timing loop.
     *
     * <p>Arguments: {@code <threads> <matrix size> [runs]}. A thread count of
     * zero selects the default {@link ForkJoinPool} parallelism. The usage
     * message is printed, and nothing is run, when an argument is not an
     * integer, when the pool rejects the thread count, or when the matrix
     * size is not a power of two of at least {@link #BLOCK_SIZE}.
     *
     * @param args command-line arguments as described above
     * @throws Exception declared for compatibility with existing callers
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two, at least "
            + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // Sizes below BLOCK_SIZE (including zero) would never hit the base
        // case of the recursion, so they are rejected together with
        // non-powers of two and negative values.
        if (!isValidSize(n)) {
            System.out.println(usage);
            return;
        }

        final ForkJoinPool pool;
        try {
            pool = (procs == 0) ? new ForkJoinPool() : new ForkJoinPool(procs);
        } catch (IllegalArgumentException e) {
            // The pool rejects negative or excessively large parallelism.
            System.out.println(usage);
            return;
        }

        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);
                double[][] copy = null;
                if (CHECK) {
                    // Keep the untouched input so the product L*U can be compared to it.
                    copy = new double[n][];
                    for (int i = 0; i < n; ++i) {
                        copy[i] = m[i].clone();
                    }
                }
                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double) time) / NPS;
                System.out.printf("\tTime: %7.3f\n", secs);

                if (CHECK) {
                    check(m, copy, n);
                }
            }
            System.out.println(pool.toString());
        } finally {
            // Release the worker threads even when a run fails.
            pool.shutdown();
        }
    }

    /**
     * Tells whether {@code n} is a matrix size the recursion can handle:
     * a power of two that is at least {@link #BLOCK_SIZE}.
     */
    private static boolean isValidSize(int n) {
        return n >= BLOCK_SIZE && (n & (n - 1)) == 0;
    }

    /**
     * Fills the leading {@code n} by {@code n} part of {@code M} with random
     * values in [0, 1) and then multiplies every diagonal entry by ten.
     *
     * @param M matrix to fill; must have at least {@code n} rows and columns
     * @param n number of rows and columns to initialize
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i) {
            double[] row = M[i];
            for (int j = 0; j < n; ++j) {
                row[j] = rng.nextDouble();
            }
        }
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k) {
            M[k][k] *= 10.0;
        }
    }

    /**
     * Multiplies the factors stored in {@code LU} back together and compares
     * the product with {@code M}, printing every entry that differs by more
     * than 0.001 and finally the largest absolute difference.
     *
     * @param LU decomposed matrix; L below the diagonal (unit diagonal implied), U on and above it
     * @param M  the matrix as it was before decomposition
     * @param n  number of rows and columns to compare
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) {
                    v += LU[i][k] * LU[k][j];
                }
                // The diagonal entry of L is an implicit one.
                if (k == i && k <= j) {
                    v += LU[k][j];
                }
                double diff = Math.abs(M[i][j] - v);
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    /** Records the underlying matrix and the row/column offsets of the current block. */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat;
            loRow = lr;
            loCol = lc;
        }

        /** Returns a block over the same matrix whose origin is moved by the given offsets. */
        Block shifted(int rowDelta, int colDelta) {
            return new Block(m, loRow + rowDelta, loCol + colDelta);
        }
    }

    /** Task computing the update {@code M -= V * W} on square blocks of edge {@code size}. */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size;
            this.V = V;
            this.W = W;
            this.M = M;
        }

        /** Base case: subtracts the product of the V and W blocks from the M block. */
        void schur() {
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                int mCol = j + M.loCol;
                int wCol = j + W.loCol;
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double[] mRow = M.m[i + M.loRow];
                    double[] vRow = V.m[i + V.loRow];
                    double s = mRow[mCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= vRow[k + V.loCol] * W.m[k + W.loRow][wCol];
                    }
                    mRow[mCol] = s;
                }
            }
        }

        /**
         * Splits all three blocks into quadrants. Each quadrant of M receives
         * two updates that run one after the other; the updates for three of
         * the quadrants are forked while the fourth is computed directly.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            Block V00 = V;
            Block V01 = V.shifted(0, h);
            Block V10 = V.shifted(h, 0);
            Block V11 = V.shifted(h, h);

            Block W00 = W;
            Block W01 = W.shifted(0, h);
            Block W10 = W.shifted(h, 0);
            Block W11 = W.shifted(h, h);

            Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                          new Schur(h, V11, W11, M11));
            s3.fork();
            Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                          new Schur(h, V11, W10, M10));
            s2.fork();
            Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                          new Schur(h, V01, W11, M01));
            s1.fork();
            new Schur(h, V00, W00, M00).compute();
            new Schur(h, V01, W10, M00).compute();
            // Run a forked task here if it can still be unscheduled, otherwise wait for it.
            if (s1.tryUnfork()) s1.compute(); else s1.join();
            if (s2.tryUnfork()) s2.compute(); else s2.join();
            if (s3.tryUnfork()) s3.compute(); else s3.join();
        }
    }

    /**
     * Task that eliminates block M against the multipliers stored below the
     * diagonal of block L (row i of M has L[i][k] times row k subtracted for
     * every k less than i).
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size;
            this.L = L;
            this.M = M;
        }

        /** Base case on a single BLOCK_SIZE block. */
        void lower() {
            int firstCol = M.loCol;
            int endCol = firstCol + BLOCK_SIZE;
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                double[] y = M.m[i + M.loRow];
                double[] lRow = L.m[i + L.loRow];
                for (int k = 0; k < i; ++k) {
                    double a = lRow[k + L.loCol];
                    double[] x = M.m[k + M.loRow];
                    for (int c = firstCol; c < endCol; ++c) {
                        y[c] -= a * x[c];
                    }
                }
            }
        }

        /**
         * Splits L and M into quadrants and handles the left and right column
         * halves of M as two independent three-step sequences, forking one
         * and computing the other directly.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            Block L00 = L;
            Block L10 = L.shifted(h, 0);
            Block L11 = L.shifted(h, h);

            Seq3 s1 =
                seq(new Lower(h, L00, M00),
                    new Schur(h, L10, M00, M10),
                    new Lower(h, L11, M10));
            Seq3 s2 =
                seq(new Lower(h, L00, M01),
                    new Schur(h, L10, M01, M11),
                    new Lower(h, L11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) s2.compute(); else s2.join();
        }
    }

    /**
     * Task that eliminates block M against the rows of block U: each entry
     * M[i][k] is divided by U[k][k] and the resulting multiple of row k of U
     * is subtracted from the rest of row i of M.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size;
            this.U = U;
            this.M = M;
        }

        /** Base case on a single BLOCK_SIZE block. */
        void upper() {
            int mCol = M.loCol;
            int uCol = U.loCol;
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                double[] y = M.m[i + M.loRow];
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double[] x = U.m[k + U.loRow];
                    double a = y[k + mCol] / x[k + uCol];
                    y[k + mCol] = a;
                    // Only the columns to the right of k are affected.
                    for (int p = k + 1; p < BLOCK_SIZE; ++p) {
                        y[p + mCol] -= a * x[p + uCol];
                    }
                }
            }
        }

        /**
         * Splits U and M into quadrants and handles the top and bottom row
         * halves of M as two independent three-step sequences, forking one
         * and computing the other directly.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            Block U00 = U;
            Block U01 = U.shifted(0, h);
            Block U11 = U.shifted(h, h);

            Seq3 s1 =
                seq(new Upper(h, U00, M00),
                    new Schur(h, M00, U01, M01),
                    new Upper(h, U11, M01));
            Seq3 s2 =
                seq(new Upper(h, U00, M10),
                    new Schur(h, M10, U01, M11),
                    new Upper(h, U11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) s2.compute(); else s2.join();
        }
    }

    /** Task performing the in-place LU decomposition of a square block of edge {@code size}. */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size;
            this.M = M;
        }

        /** Base case: elimination without pivoting on a single BLOCK_SIZE block. */
        void lu() {
            int col = M.loCol;
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                double[] x = M.m[k + M.loRow];
                // Row k is not written while the rows below it are processed.
                double b = x[k + col];
                for (int i = k + 1; i < BLOCK_SIZE; ++i) {
                    double[] y = M.m[i + M.loRow];
                    double a = y[k + col] / b;
                    y[k + col] = a;
                    for (int p = k + 1; p < BLOCK_SIZE; ++p) {
                        y[p + col] -= a * x[p + col];
                    }
                }
            }
        }

        /**
         * Decomposes the top-left quadrant, then updates the top-right and
         * bottom-left quadrants in parallel, applies the Schur update to the
         * bottom-right quadrant and finally decomposes that quadrant.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
                return;
            }
            int h = size / 2;
            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            new LowerUpper(h, M00).compute();
            Lower sl = new Lower(h, M00, M01);
            Upper su = new Upper(h, M00, M10);
            su.fork();
            sl.compute();
            if (su.tryUnfork()) su.compute(); else su.join();
            new Schur(h, M10, M01, M11).compute();
            new LowerUpper(h, M11).compute();
        }
    }

    /** Creates a task that runs the two given tasks one after the other. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Runs two tasks in order, invoking the second only after the first has returned. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Creates a task that runs the three given tasks one after the other. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Runs three tasks in order, invoking each only after the previous one has returned. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }

}
387