ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.8
Committed: Thu Jan 15 18:34:19 2015 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.7: +0 -7 lines
Log Message:
delete extraneous blank lines

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.5 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9    
/**
 * LU matrix decomposition demo.
 * Based on those in Cilk and Hood.
 *
 * <p>An n-by-n matrix is factored in place, without pivoting, into a
 * unit lower-triangular factor L and an upper-triangular factor U.
 * After factoring, the strictly-lower part of the matrix holds L (its
 * unit diagonal is implicit) and the diagonal and upper part hold U.
 * The work is split recursively into quadrants, forking independent
 * pieces, until blocks are exactly {@link #BLOCK_SIZE} on a side. For
 * that reason n must be {@code BLOCK_SIZE} times a power of two.
 */
public final class LU {

    /** for time conversion: nanoseconds per second */
    static final long NPS = (1000L * 1000 * 1000);

    // granularity is hard-wired as compile-time constant here
    /**
     * Side length of the square base-case blocks. Every task halves its
     * size until it equals this value exactly, so every size handed to
     * the tasks below must be BLOCK_SIZE times a power of two.
     */
    static final int BLOCK_SIZE = 16;
    static final boolean CHECK = false; // set true to check answer

    /**
     * Benchmark driver: repeatedly factors a random matrix and prints
     * the elapsed time of each run.
     *
     * @param args optional {@code <threads> <matrix size> [runs]}.
     *        A thread count of 0 (the default) uses the pool's default
     *        parallelism. The matrix size defaults to 2048 and must be a
     *        power of two no smaller than {@link #BLOCK_SIZE}. Runs
     *        defaults to 5.
     * @throws Exception if a factoring task fails
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two, at least "
            + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0)
                procs = Integer.parseInt(args[0]);
            if (args.length > 1)
                n = Integer.parseInt(args[1]);
            if (args.length > 2)
                runs = Integer.parseInt(args[2]);
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // The recursive tasks only bottom out when a block is exactly
        // BLOCK_SIZE square. A size that is not a power of two, or one
        // smaller than BLOCK_SIZE (including 0, which passes the bit test
        // below), would recurse without end. ForkJoinPool also rejects a
        // negative parallelism, so reject it here with the usage text.
        if (procs < 0 || n < BLOCK_SIZE || (n & (n - 1)) != 0) {
            System.out.println(usage);
            return;
        }
        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);
                double[][] copy = null;
                if (CHECK) {
                    // keep the unfactored matrix so the result can be verified
                    copy = new double[n][];
                    for (int i = 0; i < n; ++i)
                        copy[i] = m[i].clone();
                }
                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double)time) / NPS;
                System.out.printf("\tTime: %7.3f\n", secs);

                if (CHECK) check(m, copy, n);
            }
            System.out.println(pool.toString());
        } finally {
            // release the worker threads even if a run fails
            pool.shutdown();
        }
    }

    /**
     * Fills the leading n-by-n part of M with uniform random values in
     * [0, 1) and then multiplies each diagonal entry by 10.
     *
     * @param M the matrix to fill; must be at least n by n
     * @param n the number of rows and columns to fill
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                M[i][j] = rng.nextDouble();
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k)
            M[k][k] *= 10.0;
    }

    /**
     * Multiplies the factors stored in lu back together and compares
     * the product with the original matrix. Prints every entry that
     * differs by more than 0.001, then prints the maximum difference.
     *
     * @param lu the factored matrix: L strictly below the diagonal
     *        (unit diagonal implicit), U on and above it
     * @param m the matrix as it was before factoring
     * @param n the matrix size
     */
    static void check(double[][] lu, double[][] m, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                // (L*U)[i][j] = sum over k < i, k <= j of L[i][k]*U[k][j],
                // plus U[i][j] itself when i <= j (L[i][i] == 1)
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) v += lu[i][k] * lu[k][j];
                if (k == i && k <= j) v += lu[k][j];
                double diff = Math.abs(m[i][j] - v);
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + m[i][j] + " vs " + v);
                }
                if (diff > maxDiff) maxDiff = diff;
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    /**
     * A square sub-block view of a shared matrix. It holds the backing
     * array and the row and column offsets of the block's top-left
     * corner. The block's size is carried separately by the tasks that
     * use it.
     */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }
    }

    /**
     * Schur-complement update: computes M -= V * W for size-by-size
     * blocks. size must be BLOCK_SIZE times a power of two.
     */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        /** Base case: a dense BLOCK_SIZE-square multiply-subtract. */
        void schur() {
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double s = M.m[i+M.loRow][j+M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= V.m[i+V.loRow][k+V.loCol] * W.m[k+W.loRow][j+W.loCol];
                    }
                    M.m[i+M.loRow][j+M.loCol] = s;
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                Block V00 = new Block(V.m, V.loRow, V.loCol);
                Block V01 = new Block(V.m, V.loRow, V.loCol+h);
                Block V10 = new Block(V.m, V.loRow+h, V.loCol);
                Block V11 = new Block(V.m, V.loRow+h, V.loCol+h);

                Block W00 = new Block(W.m, W.loRow, W.loCol);
                Block W01 = new Block(W.m, W.loRow, W.loCol+h);
                Block W10 = new Block(W.m, W.loRow+h, W.loCol);
                Block W11 = new Block(W.m, W.loRow+h, W.loCol+h);

                // Each output quadrant gets two updates, and those two must
                // not run concurrently because they write the same cells. So
                // each pair runs sequentially, while the four quadrants
                // proceed in parallel (M00's pair runs in this thread).
                Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                              new Schur(h, V11, W11, M11));
                s3.fork();
                Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                              new Schur(h, V11, W10, M10));
                s2.fork();
                Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                              new Schur(h, V01, W11, M01));
                s1.fork();
                new Schur(h, V00, W00, M00).compute();
                new Schur(h, V01, W10, M00).compute();
                if (s1.tryUnfork()) s1.compute(); else s1.join();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
                if (s3.tryUnfork()) s3.compute(); else s3.join();
            }
        }
    }

    /**
     * Forward substitution: solves L * X = M in place, overwriting M
     * with X. L is treated as unit lower triangular, so only its
     * strictly-lower part is read. size must be BLOCK_SIZE times a
     * power of two.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;
        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        /** Base case: row i -= L[i][k] * row k for every k < i. */
        void lower() {
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i+L.loRow][k+L.loCol];
                    double[] x = M.m[k+M.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE;
                    for (int p = n-1; p >= 0; --p) {
                        y[p+M.loCol] -= a * x[p+M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                // L01 is the zero upper-right quadrant of a lower-triangular
                // L, so it is never needed.
                Block L00 = new Block(L.m, L.loRow, L.loCol);
                Block L10 = new Block(L.m, L.loRow+h, L.loCol);
                Block L11 = new Block(L.m, L.loRow+h, L.loCol+h);

                // The left and right column halves of M are independent:
                // solve each top half, update its bottom half, then solve it.
                Seq3 s1 =
                    seq(new Lower(h, L00, M00),
                        new Schur(h, L10, M00, M10),
                        new Lower(h, L11, M10));
                Seq3 s2 =
                    seq(new Lower(h, L00, M01),
                        new Schur(h, L10, M01, M11),
                        new Lower(h, L11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Backward substitution from the right: solves X * U = M in place,
     * overwriting M with X. Only the diagonal and upper part of U are
     * read. size must be BLOCK_SIZE times a power of two.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;
        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        /**
         * Base case: for each row of M, divide entry k by U[k][k], then
         * eliminate it from the entries to its right.
         */
        void upper() {
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double a = M.m[i+M.loRow][k+M.loCol] / U.m[k+U.loRow][k+U.loCol];
                    M.m[i+M.loRow][k+M.loCol] = a;
                    double[] x = U.m[k+U.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE - k - 1;
                    for (int p = n - 1; p >= 0; --p) {
                        y[p+k+1+M.loCol] -= a * x[p+k+1+U.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                // U10 is the zero lower-left quadrant of an upper-triangular
                // U, so it is never needed.
                Block U00 = new Block(U.m, U.loRow, U.loCol);
                Block U01 = new Block(U.m, U.loRow, U.loCol+h);
                Block U11 = new Block(U.m, U.loRow+h, U.loCol+h);

                // The top and bottom row halves of M are independent:
                // solve each left half, update its right half, then solve it.
                Seq3 s1 =
                    seq(new Upper(h, U00, M00),
                        new Schur(h, M00, U01, M01),
                        new Upper(h, U11, M01));
                Seq3 s2 =
                    seq(new Upper(h, U00, M10),
                        new Schur(h, M10, U01, M11),
                        new Upper(h, U11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Factors M in place, without pivoting. Afterwards the strictly-lower
     * part of M holds the unit lower-triangular factor L, and the diagonal
     * and upper part hold U. size must be BLOCK_SIZE times a power of two.
     */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;
        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        /** Base case: textbook in-place elimination on one block. */
        void lu() {
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                for (int i = k+1; i < BLOCK_SIZE; ++i) {
                    double b = M.m[k+M.loRow][k+M.loCol];
                    double a = M.m[i+M.loRow][k+M.loCol] / b;
                    M.m[i+M.loRow][k+M.loCol] = a;
                    double[] x = M.m[k+M.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE-k-1;
                    for (int p = n-1; p >= 0; --p) {
                        y[k+1+p+M.loCol] -= a * x[k+1+p+M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
            }
            else {
                int h = size / 2;
                Block M00 = new Block(M.m, M.loRow, M.loCol);
                Block M01 = new Block(M.m, M.loRow, M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                // Factor the top-left quadrant into L00 and U00.
                new LowerUpper(h, M00).compute();
                // U01 = L00^-1 * M01 and L10 = M10 * U00^-1 are independent.
                Lower sl = new Lower(h, M00, M01);
                Upper su = new Upper(h, M00, M10);
                su.fork();
                sl.compute();
                if (su.tryUnfork()) su.compute(); else su.join();
                // Update the trailing quadrant with M11 -= L10 * U01, then
                // factor it.
                new Schur(h, M10, M01, M11).compute();
                new LowerUpper(h, M11).compute();
            }
        }
    }

    /** Returns a task that invokes task1 and then task2, in that order. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Runs two tasks sequentially in the current worker. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }
        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Returns a task that invokes task1, task2, then task3, in order. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Runs three tasks sequentially in the current worker. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;
        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }
        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }
}