ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.7
Committed: Wed Dec 31 16:44:01 2014 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.6: +0 -1 lines
Log Message:
remove unused imports

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9
10
11 /**
12 * LU matrix decomposition demo
13 * Based on those in Cilk and Hood
14 */
15 public final class LU {
16
17 /** for time conversion */
18 static final long NPS = (1000L * 1000 * 1000);
19
20 // granularity is hard-wired as compile-time constant here
21 static final int BLOCK_SIZE = 16;
22 static final boolean CHECK = false; // set true to check answer
23
24 public static void main(String[] args) throws Exception {
25
26 final String usage = "Usage: java LU <threads> <matrix size (must be a power of two)> [runs] \n For example, try java LU 2 512";
27
28 int procs = 0;
29 int n = 2048;
30 int runs = 5;
31 try {
32 if (args.length > 0)
33 procs = Integer.parseInt(args[0]);
34 if (args.length > 1)
35 n = Integer.parseInt(args[1]);
36 if (args.length > 2)
37 runs = Integer.parseInt(args[2]);
38 } catch (Exception e) {
39 System.out.println(usage);
40 return;
41 }
42
43 if ( ((n & (n - 1)) != 0)) {
44 System.out.println(usage);
45 return;
46 }
47 ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
48 new ForkJoinPool(procs);
49 System.out.println("procs: " + pool.getParallelism() +
50 " n: " + n + " runs: " + runs);
51 for (int run = 0; run < runs; ++run) {
52 double[][] m = new double[n][n];
53 randomInit(m, n);
54 double[][] copy = null;
55 if (CHECK) {
56 copy = new double[n][n];
57 for (int i = 0; i < n; ++i) {
58 for (int j = 0; j < n; ++j) {
59 copy[i][j] = m[i][j];
60 }
61 }
62 }
63 Block M = new Block(m, 0, 0);
64 long start = System.nanoTime();
65 pool.invoke(new LowerUpper(n, M));
66 long time = System.nanoTime() - start;
67 double secs = ((double)time) / NPS;
68 System.out.printf("\tTime: %7.3f\n", secs);
69
70 if (CHECK) check(m, copy, n);
71 }
72 System.out.println(pool.toString());
73 pool.shutdown();
74 }
75
76
77 static void randomInit(double[][] M, int n) {
78 java.util.Random rng = new java.util.Random();
79 for (int i = 0; i < n; ++i)
80 for (int j = 0; j < n; ++j)
81 M[i][j] = rng.nextDouble();
82 // for compatibility with hood demo, force larger diagonals
83 for (int k = 0; k < n; ++k)
84 M[k][k] *= 10.0;
85 }
86
87 static void check(double[][] LU, double[][] M, int n) {
88 double maxDiff = 0.0; // track max difference
89 for (int i = 0; i < n; ++i) {
90 for (int j = 0; j < n; ++j) {
91 double v = 0.0;
92 int k;
93 for (k = 0; k < i && k <= j; k++ ) v += LU[i][k] * LU[k][j];
94 if (k == i && k <= j ) v += LU[k][j];
95 double diff = M[i][j] - v;
96 if (diff < 0) diff = -diff;
97 if (diff > 0.001) {
98 System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
99 }
100 if (diff > maxDiff) maxDiff = diff;
101 }
102 }
103
104 System.out.println("Max difference = " + maxDiff);
105 }
106
107
108 // Blocks record underlying matrix, and offsets into current block
109 static final class Block {
110 final double[][] m;
111 final int loRow;
112 final int loCol;
113
114 Block(double[][] mat, int lr, int lc) {
115 m = mat; loRow = lr; loCol = lc;
116 }
117 }
118
119 static final class Schur extends RecursiveAction {
120 final int size;
121 final Block V;
122 final Block W;
123 final Block M;
124
125 Schur(int size, Block V, Block W, Block M) {
126 this.size = size; this.V = V; this.W = W; this.M = M;
127 }
128
129 void schur() { // base case
130 for (int j = 0; j < BLOCK_SIZE; ++j) {
131 for (int i = 0; i < BLOCK_SIZE; ++i) {
132 double s = M.m[i+M.loRow][j+M.loCol];
133 for (int k = 0; k < BLOCK_SIZE; ++k) {
134 s -= V.m[i+V.loRow][k+V.loCol] * W.m[k+W.loRow][j+W.loCol];
135 }
136 M.m[i+M.loRow][j+M.loCol] = s;
137 }
138 }
139 }
140
141 public void compute() {
142 if (size == BLOCK_SIZE) {
143 schur();
144 }
145 else {
146 int h = size / 2;
147
148 Block M00 = new Block(M.m, M.loRow, M.loCol);
149 Block M01 = new Block(M.m, M.loRow, M.loCol+h);
150 Block M10 = new Block(M.m, M.loRow+h, M.loCol);
151 Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
152
153 Block V00 = new Block(V.m, V.loRow, V.loCol);
154 Block V01 = new Block(V.m, V.loRow, V.loCol+h);
155 Block V10 = new Block(V.m, V.loRow+h, V.loCol);
156 Block V11 = new Block(V.m, V.loRow+h, V.loCol+h);
157
158 Block W00 = new Block(W.m, W.loRow, W.loCol);
159 Block W01 = new Block(W.m, W.loRow, W.loCol+h);
160 Block W10 = new Block(W.m, W.loRow+h, W.loCol);
161 Block W11 = new Block(W.m, W.loRow+h, W.loCol+h);
162
163 Seq2 s3 = seq(new Schur(h, V10, W01, M11),
164 new Schur(h, V11, W11, M11));
165 s3.fork();
166 Seq2 s2 = seq(new Schur(h, V10, W00, M10),
167 new Schur(h, V11, W10, M10));
168 s2.fork();
169 Seq2 s1 = seq(new Schur(h, V00, W01, M01),
170 new Schur(h, V01, W11, M01));
171 s1.fork();
172 new Schur(h, V00, W00, M00).compute();
173 new Schur(h, V01, W10, M00).compute();
174 if (s1.tryUnfork()) s1.compute(); else s1.join();
175 if (s2.tryUnfork()) s2.compute(); else s2.join();
176 if (s3.tryUnfork()) s3.compute(); else s3.join();
177 }
178 }
179 }
180
181 static final class Lower extends RecursiveAction {
182 final int size;
183 final Block L;
184 final Block M;
185 Lower(int size, Block L, Block M) {
186 this.size = size; this.L = L; this.M = M;
187 }
188
189 void lower() { // base case
190 for (int i = 1; i < BLOCK_SIZE; ++i) {
191 for (int k = 0; k < i; ++k) {
192 double a = L.m[i+L.loRow][k+L.loCol];
193 double[] x = M.m[k+M.loRow];
194 double[] y = M.m[i+M.loRow];
195 int n = BLOCK_SIZE;
196 for (int p = n-1; p >= 0; --p) {
197 y[p+M.loCol] -= a * x[p+M.loCol];
198 }
199 }
200 }
201 }
202
203 public void compute() {
204 if (size == BLOCK_SIZE) {
205 lower();
206 }
207 else {
208 int h = size / 2;
209
210 Block M00 = new Block(M.m, M.loRow, M.loCol);
211 Block M01 = new Block(M.m, M.loRow, M.loCol+h);
212 Block M10 = new Block(M.m, M.loRow+h, M.loCol);
213 Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
214
215 Block L00 = new Block(L.m, L.loRow, L.loCol);
216 Block L01 = new Block(L.m, L.loRow, L.loCol+h);
217 Block L10 = new Block(L.m, L.loRow+h, L.loCol);
218 Block L11 = new Block(L.m, L.loRow+h, L.loCol+h);
219
220
221 Seq3 s1 =
222 seq(new Lower(h, L00, M00),
223 new Schur(h, L10, M00, M10),
224 new Lower(h, L11, M10));
225 Seq3 s2 =
226 seq(new Lower(h, L00, M01),
227 new Schur(h, L10, M01, M11),
228 new Lower(h, L11, M11));
229 s2.fork();
230 s1.compute();
231 if (s2.tryUnfork()) s2.compute(); else s2.join();
232 }
233 }
234 }
235
236
237 static final class Upper extends RecursiveAction {
238 final int size;
239 final Block U;
240 final Block M;
241 Upper(int size, Block U, Block M) {
242 this.size = size; this.U = U; this.M = M;
243 }
244
245 void upper() { // base case
246 for (int i = 0; i < BLOCK_SIZE; ++i) {
247 for (int k = 0; k < BLOCK_SIZE; ++k) {
248 double a = M.m[i+M.loRow][k+M.loCol] / U.m[k+U.loRow][k+U.loCol];
249 M.m[i+M.loRow][k+M.loCol] = a;
250 double[] x = U.m[k+U.loRow];
251 double[] y = M.m[i+M.loRow];
252 int n = BLOCK_SIZE - k - 1;
253 for (int p = n - 1; p >= 0; --p) {
254 y[p+k+1+M.loCol] -= a * x[p+k+1+U.loCol];
255 }
256 }
257 }
258 }
259
260
261 public void compute() {
262 if (size == BLOCK_SIZE) {
263 upper();
264 }
265 else {
266 int h = size / 2;
267
268 Block M00 = new Block(M.m, M.loRow, M.loCol);
269 Block M01 = new Block(M.m, M.loRow, M.loCol+h);
270 Block M10 = new Block(M.m, M.loRow+h, M.loCol);
271 Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
272
273 Block U00 = new Block(U.m, U.loRow, U.loCol);
274 Block U01 = new Block(U.m, U.loRow, U.loCol+h);
275 Block U10 = new Block(U.m, U.loRow+h, U.loCol);
276 Block U11 = new Block(U.m, U.loRow+h, U.loCol+h);
277
278 Seq3 s1 =
279 seq(new Upper(h, U00, M00),
280 new Schur(h, M00, U01, M01),
281 new Upper(h, U11, M01));
282 Seq3 s2 =
283 seq(new Upper(h, U00, M10),
284 new Schur(h, M10, U01, M11),
285 new Upper(h, U11, M11));
286 s2.fork();
287 s1.compute();
288 if (s2.tryUnfork()) s2.compute(); else s2.join();
289 }
290 }
291 }
292
293
294 static final class LowerUpper extends RecursiveAction {
295 final int size;
296 final Block M;
297 LowerUpper(int size, Block M) {
298 this.size = size; this.M = M;
299 }
300
301 void lu() { // base case
302 for (int k = 0; k < BLOCK_SIZE; ++k) {
303 for (int i = k+1; i < BLOCK_SIZE; ++i) {
304 double b = M.m[k+M.loRow][k+M.loCol];
305 double a = M.m[i+M.loRow][k+M.loCol] / b;
306 M.m[i+M.loRow][k+M.loCol] = a;
307 double[] x = M.m[k+M.loRow];
308 double[] y = M.m[i+M.loRow];
309 int n = BLOCK_SIZE-k-1;
310 for (int p = n-1; p >= 0; --p) {
311 y[k+1+p+M.loCol] -= a * x[k+1+p+M.loCol];
312 }
313 }
314 }
315 }
316
317 public void compute() {
318 if (size == BLOCK_SIZE) {
319 lu();
320 }
321 else {
322 int h = size / 2;
323 Block M00 = new Block(M.m, M.loRow, M.loCol);
324 Block M01 = new Block(M.m, M.loRow, M.loCol+h);
325 Block M10 = new Block(M.m, M.loRow+h, M.loCol);
326 Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
327
328 new LowerUpper(h, M00).compute();
329 Lower sl = new Lower(h, M00, M01);
330 Upper su = new Upper(h, M00, M10);
331 su.fork();
332 sl.compute();
333 if (su.tryUnfork()) su.compute(); else su.join();
334 new Schur(h, M10, M01, M11).compute();
335 new LowerUpper(h, M11).compute();
336 }
337 }
338 }
339
340 static Seq2 seq(RecursiveAction task1,
341 RecursiveAction task2) {
342 return new Seq2(task1, task2);
343 }
344
345 static final class Seq2 extends RecursiveAction {
346 final RecursiveAction fst;
347 final RecursiveAction snd;
348 public Seq2(RecursiveAction task1, RecursiveAction task2) {
349 fst = task1;
350 snd = task2;
351 }
352 public void compute() {
353 fst.invoke();
354 snd.invoke();
355 }
356 }
357
358 static Seq3 seq(RecursiveAction task1,
359 RecursiveAction task2,
360 RecursiveAction task3) {
361 return new Seq3(task1, task2, task3);
362 }
363
364 static final class Seq3 extends RecursiveAction {
365 final RecursiveAction fst;
366 final RecursiveAction snd;
367 final RecursiveAction thr;
368 public Seq3(RecursiveAction task1,
369 RecursiveAction task2,
370 RecursiveAction task3) {
371 fst = task1;
372 snd = task2;
373 thr = task3;
374 }
375 public void compute() {
376 fst.invoke();
377 snd.invoke();
378 thr.invoke();
379 }
380 }
381 }