ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.9
Committed: Sat Sep 12 18:59:08 2015 UTC (8 years, 7 months ago) by dl
Branch: MAIN
CVS Tags: HEAD
Changes since 1.8: +4 -3 lines
Log Message:
Use commonPool

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9
/**
 * LU matrix decomposition demo, based on those in Cilk and Hood.
 *
 * <p>The matrix is factored in place, without pivoting, by recursively
 * splitting it into quadrants until a quadrant is {@link #BLOCK_SIZE} on a
 * side. After the run, the strictly-lower triangle holds the multipliers of
 * a unit-lower-triangular factor and the upper triangle (diagonal included)
 * holds the upper factor; {@link #check} multiplies them back together under
 * exactly that reading.
 *
 * <p>Because every task halves its size until it equals {@code BLOCK_SIZE},
 * the matrix size must be a power of two that is at least {@code BLOCK_SIZE}.
 */
public final class LU {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    /** Side length of a leaf block; granularity is hard-wired as a compile-time constant. */
    static final int BLOCK_SIZE = 16;

    /** Set true to check the answer after every run. */
    static final boolean CHECK = false;

    /** Message printed when the command-line arguments are unusable. */
    private static final String USAGE =
        "Usage: java LU <threads> <matrix size (must be a power of two, at least "
        + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

    /**
     * Runs the demo.
     *
     * @param args optional {@code <threads> <matrix size> [runs]}; a thread
     *        count of 0 (the default) selects the common pool
     * @throws Exception declared for compatibility with existing callers
     */
    public static void main(String[] args) throws Exception {
        int procs = 0;
        int n = 2048;
        int runs = 32;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(USAGE);
            return;
        }

        // A size below BLOCK_SIZE would never hit the base case of the
        // recursive tasks (size keeps halving down to 0), so reject it here.
        if (!isValidSize(n)) {
            System.out.println(USAGE);
            return;
        }

        final ForkJoinPool pool;
        try {
            pool = (procs == 0) ? ForkJoinPool.commonPool() : new ForkJoinPool(procs);
        } catch (IllegalArgumentException e) {
            // The ForkJoinPool constructor rejects an unusable thread count.
            System.out.println(USAGE);
            return;
        }

        System.out.println("procs: " + pool.getParallelism() +
                           " n: " + n + " runs: " + runs);
        for (int run = 0; run < runs; ++run) {
            double[][] m = new double[n][n];
            randomInit(m, n);
            double[][] copy = CHECK ? copyRows(m) : null;

            Block M = new Block(m, 0, 0);
            long start = System.nanoTime();
            pool.invoke(new LowerUpper(n, M));
            long time = System.nanoTime() - start;
            double secs = ((double) time) / NPS;
            System.out.printf("Time: %7.3f ", secs);
            if ((run & 3) == 3) {
                System.out.println(); // four timings per output row
            }

            if (CHECK) {
                check(m, copy, n);
            }
        }
        // Finish a partially filled row so the pool summary starts on its own line.
        if (runs > 0 && (runs & 3) != 0) {
            System.out.println();
        }
        System.out.println(pool.toString());
        pool.shutdown();
    }

    /**
     * Tells whether {@code n} is a matrix size the recursive tasks can handle:
     * a power of two no smaller than {@link #BLOCK_SIZE}.
     */
    static boolean isValidSize(int n) {
        return n >= BLOCK_SIZE && (n & (n - 1)) == 0;
    }

    /** Returns a new array whose rows are clones of the rows of {@code src}. */
    private static double[][] copyRows(double[][] src) {
        double[][] dst = new double[src.length][];
        for (int i = 0; i < src.length; ++i) {
            dst[i] = src[i].clone();
        }
        return dst;
    }

    /**
     * Fills the leading {@code n x n} part of {@code M} with random values
     * from {@code Random.nextDouble()}, then multiplies each diagonal entry
     * by 10.
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                M[i][j] = rng.nextDouble();
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k)
            M[k][k] *= 10.0;
    }

    /**
     * Multiplies the two factors stored in {@code LU} back together and
     * compares the product with {@code M}, printing every entry that differs
     * by more than 0.001 and finally the largest difference seen.
     *
     * @param LU the factored matrix (unit-lower multipliers below the
     *        diagonal, upper factor on and above it)
     * @param M the matrix as it was before factoring
     * @param n number of rows and columns to check
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                // Terms with an explicit lower-factor entry: k < i and k <= j.
                int limit = Math.min(i, j + 1);
                for (int k = 0; k < limit; ++k) {
                    v += LU[i][k] * LU[k][j];
                }
                // The implicit unit diagonal of the lower factor contributes
                // only on or above the diagonal.
                if (i <= j) {
                    v += LU[i][j];
                }
                double diff = Math.abs(M[i][j] - v);
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    /** Records an underlying matrix, and the offsets of the current block within it. */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }

        /** Returns a block over the same matrix, shifted by the given row and column offsets. */
        Block at(int rowOffset, int colOffset) {
            return new Block(m, loRow + rowOffset, loCol + colOffset);
        }
    }

    /** Computes {@code M -= V * W} on square blocks of side {@code size}. */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        void schur() { // base case
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double s = M.m[i+M.loRow][j+M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= V.m[i+V.loRow][k+V.loCol] * W.m[k+W.loRow][j+W.loCol];
                    }
                    M.m[i+M.loRow][j+M.loCol] = s;
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
            }
            else {
                int h = size / 2;

                Block M00 = M.at(0, 0), M01 = M.at(0, h);
                Block M10 = M.at(h, 0), M11 = M.at(h, h);

                Block V00 = V.at(0, 0), V01 = V.at(0, h);
                Block V10 = V.at(h, 0), V11 = V.at(h, h);

                Block W00 = W.at(0, 0), W01 = W.at(0, h);
                Block W10 = W.at(h, 0), W11 = W.at(h, h);

                // Each quadrant of M receives two updates that must run in
                // sequence; the four quadrants are independent of each other.
                Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                              new Schur(h, V11, W11, M11));
                s3.fork();
                Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                              new Schur(h, V11, W10, M10));
                s2.fork();
                Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                              new Schur(h, V01, W11, M01));
                s1.fork();
                new Schur(h, V00, W00, M00).compute();
                new Schur(h, V01, W10, M00).compute();
                // Reclaim forked work in reverse fork order; run it here if
                // it can still be unforked, otherwise wait for it.
                if (s1.tryUnfork()) s1.compute(); else s1.join();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
                if (s3.tryUnfork()) s3.compute(); else s3.join();
            }
        }
    }

    /**
     * Updates block {@code M} in place using the strictly-lower entries of
     * block {@code L}: each row of {@code M} has multiples of the rows above
     * it subtracted.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        void lower() { // base case
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i+L.loRow][k+L.loCol];
                    double[] x = M.m[k+M.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE;
                    for (int p = n-1; p >= 0; --p) {
                        y[p+M.loCol] -= a * x[p+M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
            }
            else {
                int h = size / 2;

                Block M00 = M.at(0, 0), M01 = M.at(0, h);
                Block M10 = M.at(h, 0), M11 = M.at(h, h);

                // The upper-right quadrant of L is not used by this step.
                Block L00 = L.at(0, 0);
                Block L10 = L.at(h, 0), L11 = L.at(h, h);

                // The left and right column halves of M are independent.
                Seq3 s1 =
                    seq(new Lower(h, L00, M00),
                        new Schur(h, L10, M00, M10),
                        new Lower(h, L11, M10));
                Seq3 s2 =
                    seq(new Lower(h, L00, M01),
                        new Schur(h, L10, M01, M11),
                        new Lower(h, L11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Updates block {@code M} in place using the diagonal and upper entries
     * of block {@code U}: each entry is divided by the matching diagonal
     * entry of {@code U}, and the rest of its row is reduced accordingly.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        void upper() { // base case
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double a = M.m[i+M.loRow][k+M.loCol] / U.m[k+U.loRow][k+U.loCol];
                    M.m[i+M.loRow][k+M.loCol] = a;
                    double[] x = U.m[k+U.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[p+k+1+M.loCol] -= a * x[p+k+1+U.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
            }
            else {
                int h = size / 2;

                Block M00 = M.at(0, 0), M01 = M.at(0, h);
                Block M10 = M.at(h, 0), M11 = M.at(h, h);

                // The lower-left quadrant of U is not used by this step.
                Block U00 = U.at(0, 0), U01 = U.at(0, h);
                Block U11 = U.at(h, h);

                // The top and bottom row halves of M are independent.
                Seq3 s1 =
                    seq(new Upper(h, U00, M00),
                        new Schur(h, M00, U01, M01),
                        new Upper(h, U11, M01));
                Seq3 s2 =
                    seq(new Upper(h, U00, M10),
                        new Schur(h, M10, U01, M11),
                        new Upper(h, U11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Factors block {@code M} in place. {@code size} must be
     * {@link #BLOCK_SIZE} times a power of two, since it is halved until it
     * equals {@code BLOCK_SIZE}.
     */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        void lu() { // base case
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                for (int i = k+1; i < BLOCK_SIZE; ++i) {
                    double b = M.m[k+M.loRow][k+M.loCol];
                    double a = M.m[i+M.loRow][k+M.loCol] / b;
                    M.m[i+M.loRow][k+M.loCol] = a;
                    double[] x = M.m[k+M.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE-k-1;
                    for (int p = n-1; p >= 0; --p) {
                        y[k+1+p+M.loCol] -= a * x[k+1+p+M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
            }
            else {
                int h = size / 2;
                Block M00 = M.at(0, 0), M01 = M.at(0, h);
                Block M10 = M.at(h, 0), M11 = M.at(h, h);

                // Factor the top-left quadrant, use it to update the two
                // off-diagonal quadrants (independently of each other), fold
                // their product into the bottom-right quadrant, then factor it.
                new LowerUpper(h, M00).compute();
                Lower sl = new Lower(h, M00, M01);
                Upper su = new Upper(h, M00, M10);
                su.fork();
                sl.compute();
                if (su.tryUnfork()) su.compute(); else su.join();
                new Schur(h, M10, M01, M11).compute();
                new LowerUpper(h, M11).compute();
            }
        }
    }

    /** Returns an action that runs {@code task1} and then {@code task2}. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** An action that invokes two actions one after the other. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Returns an action that runs {@code task1}, {@code task2} and {@code task3} in that order. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** An action that invokes three actions one after the other. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }
}
375 }