ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.4
Committed: Mon Nov 29 20:58:07 2010 UTC (13 years, 5 months ago) by jsr166
Branch: MAIN
Changes since 1.3: +1 -1 lines
Log Message:
consistent ternary operator style

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4     * http://creativecommons.org/licenses/publicdomain
5     */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9     import java.util.concurrent.TimeUnit;
10    
11    
/**
 * LU matrix decomposition demo.
 * Based on those in Cilk and Hood.
 *
 * <p>The matrix is factored in place and without pivoting: after the
 * decomposition the multipliers of the unit-lower-triangular factor are
 * stored below the diagonal and the upper-triangular factor is stored on
 * and above it (this is the layout {@link #check} reconstructs from).
 * Work is split recursively into quadrants until a quadrant is exactly
 * {@link #BLOCK_SIZE} wide, so the matrix size must be a power of two
 * that is at least {@code BLOCK_SIZE}.
 */
public final class LU {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    // granularity is hard-wired as compile-time constant here
    static final int BLOCK_SIZE = 16;
    static final boolean CHECK = false; // set true to check answer

    /**
     * Tells whether {@code n} is a matrix size the recursive tasks can handle.
     *
     * <p>Every task halves its size until it equals {@link #BLOCK_SIZE};
     * a size of zero, a negative size, or a power of two smaller than
     * {@code BLOCK_SIZE} would never reach that base case.
     *
     * @param n requested matrix dimension
     * @return {@code true} if {@code n} is a power of two and at least
     *         {@code BLOCK_SIZE}
     */
    static boolean isValidSize(int n) {
        return n >= BLOCK_SIZE && (n & (n - 1)) == 0;
    }

    /**
     * Runs the demo: {@code java LU <threads> <matrix size> [runs]}.
     *
     * <p>A thread count of {@code 0} (the default) uses the default
     * {@link ForkJoinPool} constructor. Unparsable arguments, a negative
     * thread count, or an unsupported matrix size print the usage message
     * and return without doing any work.
     *
     * @param args optional thread count, matrix size and number of runs
     * @throws Exception declared for compatibility with existing callers
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two, at least "
            + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // Reject sizes that would never hit the BLOCK_SIZE base case (the
        // recursion would otherwise halve down to 0 and overflow the stack).
        if (procs < 0 || !isValidSize(n)) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);
                double[][] copy = null;
                if (CHECK) {
                    copy = new double[n][];
                    for (int i = 0; i < n; ++i) {
                        copy[i] = m[i].clone();
                    }
                }
                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double) time) / NPS;
                System.out.printf("\tTime: %7.3f\n", secs);

                if (CHECK) {
                    check(m, copy, n);
                }
            }
            System.out.println(pool.toString());
        } finally {
            // Release the worker threads even if a run fails.
            pool.shutdown();
        }
    }

    /**
     * Fills the leading {@code n x n} part of {@code M} with random values
     * from {@link java.util.Random#nextDouble()} and scales the diagonal by 10.
     *
     * @param M matrix to fill; must have at least {@code n} rows and columns
     * @param n number of rows and columns to initialize
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                M[i][j] = rng.nextDouble();
            }
        }
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k) {
            M[k][k] *= 10.0;
        }
    }

    /**
     * Multiplies the packed factors back together and compares the product
     * with the original matrix, printing every entry that differs by more
     * than 0.001 and finally the maximum absolute difference.
     *
     * @param LU packed decomposition (multipliers below the diagonal,
     *           upper factor on and above it)
     * @param M  the original matrix
     * @param n  matrix dimension
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) {
                    v += LU[i][k] * LU[k][j];
                }
                // The lower factor has an implicit 1 on its diagonal.
                if (k == i && k <= j) {
                    v += LU[k][j];
                }
                double diff = Math.abs(M[i][j] - v);
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }


    /** Blocks record underlying matrix, and offsets into current block. */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }

        /**
         * Returns a block over the same matrix whose origin is moved by
         * the given number of rows and columns.
         */
        Block shifted(int rows, int cols) {
            return new Block(m, loRow + rows, loCol + cols);
        }
    }

    /**
     * Subtracts the product of blocks {@code V} and {@code W} from block
     * {@code M}, in place ({@code M -= V * W}).
     */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        void schur() { // base case
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double[] vRow = V.m[i + V.loRow];
                    double s = M.m[i + M.loRow][j + M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= vRow[k + V.loCol] * W.m[k + W.loRow][j + W.loCol];
                    }
                    M.m[i + M.loRow][j + M.loCol] = s;
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            Block V00 = V;
            Block V01 = V.shifted(0, h);
            Block V10 = V.shifted(h, 0);
            Block V11 = V.shifted(h, h);

            Block W00 = W;
            Block W01 = W.shifted(0, h);
            Block W10 = W.shifted(h, 0);
            Block W11 = W.shifted(h, h);

            // Each quadrant of M receives two updates that must run one
            // after the other; the four quadrants are independent.
            Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                          new Schur(h, V11, W11, M11));
            s3.fork();
            Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                          new Schur(h, V11, W10, M10));
            s2.fork();
            Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                          new Schur(h, V01, W11, M01));
            s1.fork();
            new Schur(h, V00, W00, M00).compute();
            new Schur(h, V01, W10, M00).compute();
            // Run a forked task directly if it can be unscheduled, else wait for it.
            if (s1.tryUnfork()) { s1.compute(); } else { s1.join(); }
            if (s2.tryUnfork()) { s2.compute(); } else { s2.join(); }
            if (s3.tryUnfork()) { s3.compute(); } else { s3.join(); }
        }
    }

    /**
     * Forward substitution: overwrites block {@code M} with the solution
     * {@code X} of {@code L' * X = M}, where {@code L'} is the strictly
     * lower part of block {@code L} with an implicit unit diagonal.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        void lower() { // base case
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                double[] y = M.m[i + M.loRow];
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i + L.loRow][k + L.loCol];
                    double[] x = M.m[k + M.loRow];
                    for (int p = BLOCK_SIZE - 1; p >= 0; --p) {
                        y[p + M.loCol] -= a * x[p + M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            Block L00 = L;
            Block L10 = L.shifted(h, 0);
            Block L11 = L.shifted(h, h);

            // The left and right column halves of M are solved independently.
            Seq3 s1 =
                seq(new Lower(h, L00, M00),
                    new Schur(h, L10, M00, M10),
                    new Lower(h, L11, M10));
            Seq3 s2 =
                seq(new Lower(h, L00, M01),
                    new Schur(h, L10, M01, M11),
                    new Lower(h, L11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) { s2.compute(); } else { s2.join(); }
        }
    }


    /**
     * Overwrites block {@code M} with the solution {@code X} of
     * {@code X * U' = M}, where {@code U'} is the upper-triangular part
     * (diagonal included) of block {@code U}.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        void upper() { // base case
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                double[] y = M.m[i + M.loRow];
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double[] x = U.m[k + U.loRow];
                    double a = y[k + M.loCol] / x[k + U.loCol];
                    y[k + M.loCol] = a;
                    // Eliminate the contribution of column k from the columns to its right.
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[p + k + 1 + M.loCol] -= a * x[p + k + 1 + U.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            Block U00 = U;
            Block U01 = U.shifted(0, h);
            Block U11 = U.shifted(h, h);

            // The top and bottom row halves of M are solved independently.
            Seq3 s1 =
                seq(new Upper(h, U00, M00),
                    new Schur(h, M00, U01, M01),
                    new Upper(h, U11, M01));
            Seq3 s2 =
                seq(new Upper(h, U00, M10),
                    new Schur(h, M10, U01, M11),
                    new Upper(h, U11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) { s2.compute(); } else { s2.join(); }
        }
    }


    /**
     * Factors block {@code M} in place, leaving the multipliers below the
     * diagonal and the upper factor on and above it. No pivoting is done.
     */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        void lu() { // base case
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                // Row k is not modified while rows below it are eliminated.
                double[] x = M.m[k + M.loRow];
                double b = x[k + M.loCol];
                for (int i = k + 1; i < BLOCK_SIZE; ++i) {
                    double[] y = M.m[i + M.loRow];
                    double a = y[k + M.loCol] / b;
                    y[k + M.loCol] = a;
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[k + 1 + p + M.loCol] -= a * x[k + 1 + p + M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
                return;
            }
            int h = size / 2;
            Block M00 = M;
            Block M01 = M.shifted(0, h);
            Block M10 = M.shifted(h, 0);
            Block M11 = M.shifted(h, h);

            // Factor the top-left quadrant, solve for the two off-diagonal
            // quadrants in parallel, update the bottom-right quadrant, then
            // factor what is left of it.
            new LowerUpper(h, M00).compute();
            Lower sl = new Lower(h, M00, M01);
            Upper su = new Upper(h, M00, M10);
            su.fork();
            sl.compute();
            if (su.tryUnfork()) { su.compute(); } else { su.join(); }
            new Schur(h, M10, M01, M11).compute();
            new LowerUpper(h, M11).compute();
        }
    }

    /** Returns a task that runs {@code task1} and then {@code task2}. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Runs two actions strictly one after the other. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Returns a task that runs {@code task1}, {@code task2} and {@code task3} in order. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Runs three actions strictly one after the other. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }

}
386 jsr166 1.2