ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.3
Committed: Sat Oct 16 16:22:57 2010 UTC (13 years, 7 months ago) by jsr166
Branch: MAIN
Changes since 1.2: +1 -2 lines
Log Message:
coding style

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/licenses/publicdomain
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9 import java.util.concurrent.TimeUnit;
10
11
/**
 * LU matrix decomposition demo, based on those in Cilk and Hood.
 *
 * <p>The matrix is factored in place, without pivoting, by recursively
 * splitting it into quadrants until blocks of {@link #BLOCK_SIZE} rows and
 * columns are reached. After a run, the entries strictly below the diagonal
 * hold the multipliers of L (whose unit diagonal is implicit) and the
 * remaining entries hold U.
 */
public final class LU {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    /** Granularity, hard-wired as a compile-time constant. */
    static final int BLOCK_SIZE = 16;

    /** Set true to check the answer after each run. */
    static final boolean CHECK = false;

    /**
     * Runs the demo.
     *
     * @param args optional {@code <threads> <matrix size> [runs]}; a thread
     *     count of 0 selects the default pool parallelism, and the matrix
     *     size must be a power of two no smaller than {@link #BLOCK_SIZE}
     * @throws Exception if the decomposition fails
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two >= "
            + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // The recursion halves the size until it equals BLOCK_SIZE, so a size
        // below BLOCK_SIZE (including 0 and negatives) would never reach the
        // base case and must be rejected along with non powers of two.
        if (procs < 0 || n < BLOCK_SIZE || (n & (n - 1)) != 0) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() : new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);
                double[][] copy = null;
                if (CHECK) {
                    // Keep the untouched input so the factors can be verified.
                    copy = new double[n][];
                    for (int i = 0; i < n; ++i) {
                        copy[i] = m[i].clone();
                    }
                }
                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double) time) / NPS;
                System.out.printf("\tTime: %7.3f\n", secs);

                if (CHECK) {
                    check(m, copy, n);
                }
            }
            System.out.println(pool.toString());
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Fills the leading {@code n} x {@code n} part of {@code M} with random
     * values and then scales the diagonal by 10.
     *
     * @param M the matrix to fill
     * @param n the number of rows and columns to fill
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                M[i][j] = rng.nextDouble();
            }
        }
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k) {
            M[k][k] *= 10.0;
        }
    }

    /**
     * Multiplies the packed factors back together and compares the product
     * with the original matrix, printing every entry that differs by more
     * than 0.001 and finally the largest difference seen.
     *
     * @param LU the packed factors (L strictly below the diagonal with an
     *     implicit unit diagonal, U on and above the diagonal)
     * @param M the original matrix
     * @param n the number of rows and columns
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) {
                    v += LU[i][k] * LU[k][j];
                }
                // Contribution of the implicit 1 on the diagonal of L.
                if (k == i && k <= j) {
                    v += LU[k][j];
                }
                double diff = M[i][j] - v;
                if (diff < 0) {
                    diff = -diff;
                }
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    /** Records the underlying matrix and the offsets of the current block in it. */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }
    }

    /** Computes {@code M -= V * W} on square blocks of the given size. */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        /** Base case: updates one BLOCK_SIZE x BLOCK_SIZE block directly. */
        void schur() {
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double s = M.m[i + M.loRow][j + M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= V.m[i + V.loRow][k + V.loCol] * W.m[k + W.loRow][j + W.loCol];
                    }
                    M.m[i + M.loRow][j + M.loCol] = s;
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
                return;
            }

            int h = size / 2;

            Block M00 = new Block(M.m, M.loRow, M.loCol);
            Block M01 = new Block(M.m, M.loRow, M.loCol + h);
            Block M10 = new Block(M.m, M.loRow + h, M.loCol);
            Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

            Block V00 = new Block(V.m, V.loRow, V.loCol);
            Block V01 = new Block(V.m, V.loRow, V.loCol + h);
            Block V10 = new Block(V.m, V.loRow + h, V.loCol);
            Block V11 = new Block(V.m, V.loRow + h, V.loCol + h);

            Block W00 = new Block(W.m, W.loRow, W.loCol);
            Block W01 = new Block(W.m, W.loRow, W.loCol + h);
            Block W10 = new Block(W.m, W.loRow + h, W.loCol);
            Block W11 = new Block(W.m, W.loRow + h, W.loCol + h);

            // Each quadrant of M receives two updates; the pair is sequenced
            // because both write the same quadrant, while the four quadrants
            // are processed as separate tasks.
            Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                          new Schur(h, V11, W11, M11));
            s3.fork();
            Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                          new Schur(h, V11, W10, M10));
            s2.fork();
            Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                          new Schur(h, V01, W11, M01));
            s1.fork();
            new Schur(h, V00, W00, M00).compute();
            new Schur(h, V01, W10, M00).compute();
            // Run a forked task directly if it can be unscheduled, else wait for it.
            if (s1.tryUnfork()) s1.compute(); else s1.join();
            if (s2.tryUnfork()) s2.compute(); else s2.join();
            if (s3.tryUnfork()) s3.compute(); else s3.join();
        }
    }

    /**
     * Solves {@code L * X = M} in place, overwriting {@code M} with
     * {@code X}. Only the entries strictly below the diagonal of {@code L}
     * are read; its diagonal is taken to be 1.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        /** Base case: forward substitution on one BLOCK_SIZE x BLOCK_SIZE block. */
        void lower() {
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i + L.loRow][k + L.loCol];
                    double[] x = M.m[k + M.loRow];
                    double[] y = M.m[i + M.loRow];
                    int n = BLOCK_SIZE;
                    for (int p = n - 1; p >= 0; --p) {
                        y[p + M.loCol] -= a * x[p + M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
                return;
            }

            int h = size / 2;

            Block M00 = new Block(M.m, M.loRow, M.loCol);
            Block M01 = new Block(M.m, M.loRow, M.loCol + h);
            Block M10 = new Block(M.m, M.loRow + h, M.loCol);
            Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

            // The upper-right quadrant of L is never read, so no block is made for it.
            Block L00 = new Block(L.m, L.loRow, L.loCol);
            Block L10 = new Block(L.m, L.loRow + h, L.loCol);
            Block L11 = new Block(L.m, L.loRow + h, L.loCol + h);

            // Left and right column halves of M are independent of each other;
            // within each half the three steps must run in order.
            Seq3 s1 =
                seq(new Lower(h, L00, M00),
                    new Schur(h, L10, M00, M10),
                    new Lower(h, L11, M10));
            Seq3 s2 =
                seq(new Lower(h, L00, M01),
                    new Schur(h, L10, M01, M11),
                    new Lower(h, L11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) s2.compute(); else s2.join();
        }
    }

    /**
     * Solves {@code X * U = M} in place, overwriting {@code M} with
     * {@code X}. Only the entries on and above the diagonal of {@code U}
     * are read.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        /** Base case: row-wise substitution on one BLOCK_SIZE x BLOCK_SIZE block. */
        void upper() {
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double a = M.m[i + M.loRow][k + M.loCol] / U.m[k + U.loRow][k + U.loCol];
                    M.m[i + M.loRow][k + M.loCol] = a;
                    double[] x = U.m[k + U.loRow];
                    double[] y = M.m[i + M.loRow];
                    // Update only the columns to the right of k.
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[p + k + 1 + M.loCol] -= a * x[p + k + 1 + U.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
                return;
            }

            int h = size / 2;

            Block M00 = new Block(M.m, M.loRow, M.loCol);
            Block M01 = new Block(M.m, M.loRow, M.loCol + h);
            Block M10 = new Block(M.m, M.loRow + h, M.loCol);
            Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

            // The lower-left quadrant of U is never read, so no block is made for it.
            Block U00 = new Block(U.m, U.loRow, U.loCol);
            Block U01 = new Block(U.m, U.loRow, U.loCol + h);
            Block U11 = new Block(U.m, U.loRow + h, U.loCol + h);

            // Top and bottom row halves of M are independent of each other;
            // within each half the three steps must run in order.
            Seq3 s1 =
                seq(new Upper(h, U00, M00),
                    new Schur(h, M00, U01, M01),
                    new Upper(h, U11, M01));
            Seq3 s2 =
                seq(new Upper(h, U00, M10),
                    new Schur(h, M10, U01, M11),
                    new Upper(h, U11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) s2.compute(); else s2.join();
        }
    }

    /** Factors the square block {@code M} of the given size in place. */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        /** Base case: in-place elimination on one BLOCK_SIZE x BLOCK_SIZE block. */
        void lu() {
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                for (int i = k + 1; i < BLOCK_SIZE; ++i) {
                    double b = M.m[k + M.loRow][k + M.loCol];
                    double a = M.m[i + M.loRow][k + M.loCol] / b;
                    // Store the multiplier where the eliminated entry was.
                    M.m[i + M.loRow][k + M.loCol] = a;
                    double[] x = M.m[k + M.loRow];
                    double[] y = M.m[i + M.loRow];
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[k + 1 + p + M.loCol] -= a * x[k + 1 + p + M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
                return;
            }

            int h = size / 2;
            Block M00 = new Block(M.m, M.loRow, M.loCol);
            Block M01 = new Block(M.m, M.loRow, M.loCol + h);
            Block M10 = new Block(M.m, M.loRow + h, M.loCol);
            Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

            // Factor the top-left quadrant, solve for the two off-diagonal
            // quadrants in parallel, update the bottom-right quadrant with
            // their product, then factor what remains.
            new LowerUpper(h, M00).compute();
            Lower sl = new Lower(h, M00, M01);
            Upper su = new Upper(h, M00, M10);
            su.fork();
            sl.compute();
            if (su.tryUnfork()) su.compute(); else su.join();
            new Schur(h, M10, M01, M11).compute();
            new LowerUpper(h, M11).compute();
        }
    }

    /** Returns a task that runs {@code task1} and then {@code task2}. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Runs two actions one after the other. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Returns a task that runs {@code task1}, {@code task2} and {@code task3} in order. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Runs three actions one after the other. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }

}
386