ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.8
Committed: Thu Jan 15 18:34:19 2015 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.7: +0 -7 lines
Log Message:
delete extraneous blank lines

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9
/**
 * LU matrix decomposition demo.
 * Based on those in Cilk and Hood.
 *
 * <p>The matrix is factored in place, without pivoting, by recursively
 * splitting it into quadrants until a quadrant is exactly
 * {@link #BLOCK_SIZE} rows by {@link #BLOCK_SIZE} columns. After the
 * decomposition the multipliers of the lower factor are stored strictly
 * below the diagonal (its diagonal is an implicit 1) and the upper factor
 * is stored on and above the diagonal.
 */
public final class LU {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    // granularity is hard-wired as compile-time constant here
    static final int BLOCK_SIZE = 16;
    static final boolean CHECK = false; // set true to check answer

    /**
     * Runs the timing demo.
     *
     * @param args optional {@code <threads> <matrix size> [runs]}; a thread
     *        count of 0 uses the default {@link ForkJoinPool} constructor.
     *        The matrix size must be a power of two no smaller than
     *        {@link #BLOCK_SIZE}; anything else prints the usage message.
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two, at least "
            + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // The recursion halves the size until it equals BLOCK_SIZE, so a
        // size below BLOCK_SIZE (including 0 and negative values, which can
        // slip through the power-of-two bit test) would never reach the
        // base case and would recurse until the stack overflows.
        if (procs < 0 || n < BLOCK_SIZE || (n & (n - 1)) != 0) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);
                double[][] copy = null;
                if (CHECK) {
                    // keep the untouched input so the result can be verified
                    copy = new double[n][n];
                    for (int i = 0; i < n; ++i) {
                        for (int j = 0; j < n; ++j) {
                            copy[i][j] = m[i][j];
                        }
                    }
                }
                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double) time) / NPS;
                System.out.printf("\tTime: %7.3f\n", secs);

                if (CHECK) {
                    check(m, copy, n);
                }
            }
            System.out.println(pool.toString());
        } finally {
            // release the worker threads even if a run fails
            pool.shutdown();
        }
    }

    /**
     * Fills the leading {@code n x n} part of {@code M} with random values
     * from {@link java.util.Random#nextDouble()} and then multiplies each
     * diagonal entry by 10.
     *
     * @param M matrix to fill; must have at least {@code n} rows of at
     *        least {@code n} columns
     * @param n number of rows and columns to initialize
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                M[i][j] = rng.nextDouble();
            }
        }
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k) {
            M[k][k] *= 10.0;
        }
    }

    /**
     * Multiplies the two factors packed in {@code LU} back together and
     * compares the product, element by element, with {@code M}. Every
     * element whose absolute difference exceeds 0.001 is reported, and the
     * largest difference seen is printed at the end.
     *
     * @param LU packed factors: multipliers strictly below the diagonal
     *        (the lower factor's diagonal is taken as 1), upper factor on
     *        and above the diagonal
     * @param M the original matrix
     * @param n number of rows and columns to compare
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) {
                    v += LU[i][k] * LU[k][j];
                }
                // contribution of the implicit unit diagonal of the lower factor
                if (k == i && k <= j) {
                    v += LU[k][j];
                }
                double diff = M[i][j] - v;
                if (diff < 0) {
                    diff = -diff;
                }
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    /** Blocks record underlying matrix, and offsets into current block. */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }
    }

    /**
     * Subtracts the product of blocks {@code V} and {@code W} from block
     * {@code M}, in place ({@code M -= V * W}), for square blocks of
     * {@code size} rows and columns.
     */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        /** Base case: one BLOCK_SIZE x BLOCK_SIZE multiply-and-subtract. */
        void schur() {
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double s = M.m[i + M.loRow][j + M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= V.m[i + V.loRow][k + V.loCol] * W.m[k + W.loRow][j + W.loCol];
                    }
                    M.m[i + M.loRow][j + M.loCol] = s;
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol + h);
                Block M10 = new Block(M.m, M.loRow + h, M.loCol);
                Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

                Block V00 = new Block(V.m, V.loRow, V.loCol);
                Block V01 = new Block(V.m, V.loRow, V.loCol + h);
                Block V10 = new Block(V.m, V.loRow + h, V.loCol);
                Block V11 = new Block(V.m, V.loRow + h, V.loCol + h);

                Block W00 = new Block(W.m, W.loRow, W.loCol);
                Block W01 = new Block(W.m, W.loRow, W.loCol + h);
                Block W10 = new Block(W.m, W.loRow + h, W.loCol);
                Block W11 = new Block(W.m, W.loRow + h, W.loCol + h);

                // Each quadrant of M receives two updates that must run one
                // after the other; the four quadrants are independent, so
                // three of them are forked and the fourth is done here.
                Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                              new Schur(h, V11, W11, M11));
                s3.fork();
                Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                              new Schur(h, V11, W10, M10));
                s2.fork();
                Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                              new Schur(h, V01, W11, M01));
                s1.fork();
                new Schur(h, V00, W00, M00).compute();
                new Schur(h, V01, W10, M00).compute();
                // run a forked task inline if it can still be unscheduled, else wait for it
                if (s1.tryUnfork()) s1.compute(); else s1.join();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
                if (s3.tryUnfork()) s3.compute(); else s3.join();
            }
        }
    }

    /**
     * Updates block {@code M} in place using the entries strictly below the
     * diagonal of block {@code L}: each row of {@code M} has multiples of
     * the rows above it subtracted. The diagonal of {@code L} is never read.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        /** Base case on one BLOCK_SIZE x BLOCK_SIZE block. */
        void lower() {
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i + L.loRow][k + L.loCol];
                    double[] x = M.m[k + M.loRow];
                    double[] y = M.m[i + M.loRow];
                    int n = BLOCK_SIZE;
                    // row i -= a * row k, over the block's columns
                    for (int p = n - 1; p >= 0; --p) {
                        y[p + M.loCol] -= a * x[p + M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol + h);
                Block M10 = new Block(M.m, M.loRow + h, M.loCol);
                Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

                Block L00 = new Block(L.m, L.loRow, L.loCol);
                Block L01 = new Block(L.m, L.loRow, L.loCol + h);
                Block L10 = new Block(L.m, L.loRow + h, L.loCol);
                Block L11 = new Block(L.m, L.loRow + h, L.loCol + h);

                // left and right column halves of M are independent chains
                Seq3 s1 =
                    seq(new Lower(h, L00, M00),
                        new Schur(h, L10, M00, M10),
                        new Lower(h, L11, M10));
                Seq3 s2 =
                    seq(new Lower(h, L00, M01),
                        new Schur(h, L10, M01, M11),
                        new Lower(h, L11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Updates block {@code M} in place using the diagonal and the entries
     * above the diagonal of block {@code U}: within each row of {@code M},
     * every entry is divided by the matching diagonal entry of {@code U}
     * and the scaled row of {@code U} is subtracted from the remaining
     * columns of that row.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        /** Base case on one BLOCK_SIZE x BLOCK_SIZE block. */
        void upper() {
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double a = M.m[i + M.loRow][k + M.loCol] / U.m[k + U.loRow][k + U.loCol];
                    M.m[i + M.loRow][k + M.loCol] = a;
                    double[] x = U.m[k + U.loRow];
                    double[] y = M.m[i + M.loRow];
                    // number of columns to the right of column k in the block
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[p + k + 1 + M.loCol] -= a * x[p + k + 1 + U.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol + h);
                Block M10 = new Block(M.m, M.loRow + h, M.loCol);
                Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

                Block U00 = new Block(U.m, U.loRow, U.loCol);
                Block U01 = new Block(U.m, U.loRow, U.loCol + h);
                Block U10 = new Block(U.m, U.loRow + h, U.loCol);
                Block U11 = new Block(U.m, U.loRow + h, U.loCol + h);

                // top and bottom row halves of M are independent chains
                Seq3 s1 =
                    seq(new Upper(h, U00, M00),
                        new Schur(h, M00, U01, M01),
                        new Upper(h, U11, M01));
                Seq3 s2 =
                    seq(new Upper(h, U00, M10),
                        new Schur(h, M10, U01, M11),
                        new Upper(h, U11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Factors the {@code size x size} block {@code M} in place, without
     * pivoting. {@code size} must be {@link #BLOCK_SIZE} times a power of
     * two; any other value never reaches the base case.
     */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        /** Base case: eliminates below the diagonal of one BLOCK_SIZE x BLOCK_SIZE block. */
        void lu() {
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                for (int i = k + 1; i < BLOCK_SIZE; ++i) {
                    double b = M.m[k + M.loRow][k + M.loCol];
                    double a = M.m[i + M.loRow][k + M.loCol] / b;
                    // the multiplier is kept in the eliminated position
                    M.m[i + M.loRow][k + M.loCol] = a;
                    double[] x = M.m[k + M.loRow];
                    double[] y = M.m[i + M.loRow];
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[k + 1 + p + M.loCol] -= a * x[k + 1 + p + M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
            }
            else {
                int h = size / 2;
                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol + h);
                Block M10 = new Block(M.m, M.loRow + h, M.loCol);
                Block M11 = new Block(M.m, M.loRow + h, M.loCol + h);

                // factor the top-left quadrant first; the two off-diagonal
                // quadrants both depend on it but not on each other
                new LowerUpper(h, M00).compute();
                Lower sl = new Lower(h, M00, M01);
                Upper su = new Upper(h, M00, M10);
                su.fork();
                sl.compute();
                if (su.tryUnfork()) su.compute(); else su.join();
                // update the bottom-right quadrant, then factor it
                new Schur(h, M10, M01, M11).compute();
                new LowerUpper(h, M11).compute();
            }
        }
    }

    /** Returns a task that runs {@code task1} and then {@code task2}. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Runs two tasks one after the other. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Returns a task that runs {@code task1}, {@code task2} and {@code task3} in that order. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Runs three tasks one after the other. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }
}