ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.2
Committed: Mon Sep 20 20:42:37 2010 UTC (13 years, 7 months ago) by jsr166
Branch: MAIN
Changes since 1.1: +19 -19 lines
Log Message:
whitespace

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4     * http://creativecommons.org/licenses/publicdomain
5     */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9     import java.util.concurrent.TimeUnit;
10    
11    
/**
 * LU matrix decomposition demo, based on the versions in Cilk and Hood.
 *
 * <p>The matrix is factored in place, without pivoting, by fork/join tasks
 * that recursively split it into quadrants until a quadrant is exactly
 * {@code BLOCK_SIZE} wide. When the factorization finishes, the entries
 * below the diagonal hold the elimination multipliers (L, whose unit
 * diagonal is implicit) and the entries on and above the diagonal hold U.
 */
public final class LU {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    /** Granularity of the base-case kernels; hard-wired as a compile-time constant. */
    static final int BLOCK_SIZE = 16;

    /** Set true to verify each factorization against a copy of the input. */
    static final boolean CHECK = false;

    /**
     * Command-line entry point: {@code java LU <threads> <matrix size> [runs]}.
     *
     * <p>A thread count of 0 selects the default {@link ForkJoinPool}
     * parallelism. The matrix size must be a power of two and at least
     * {@code BLOCK_SIZE}; the recursive tasks halve the size until it equals
     * {@code BLOCK_SIZE}, so any smaller size would never reach a base case.
     * Invalid arguments print the usage message and return.
     *
     * @param args optional thread count, matrix size and number of runs
     * @throws Exception if a factorization task fails
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two, at least " + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // Reject sizes the recursion cannot handle (zero, negative, smaller
        // than one block, or not a power of two) and negative thread counts.
        boolean powerOfTwo = (n & (n - 1)) == 0;
        if (procs < 0 || n < BLOCK_SIZE || !powerOfTwo) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0)
            ? new ForkJoinPool()
            : new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);

                // Keep a pristine copy only when the answer will be checked.
                double[][] copy = null;
                if (CHECK) {
                    copy = new double[n][];
                    for (int i = 0; i < n; ++i) {
                        copy[i] = m[i].clone();
                    }
                }

                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double) time) / NPS;
                System.out.printf("\tTime: %7.3f\n", secs);

                if (CHECK) {
                    check(m, copy, n);
                }
            }
            System.out.println(pool.toString());
        } finally {
            // Release the workers even if a run threw.
            pool.shutdown();
        }
    }

    /**
     * Fills the leading {@code n x n} part of {@code M} with random values
     * and scales the diagonal by 10.
     *
     * @param M matrix to fill
     * @param n number of rows and columns to initialize
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                M[i][j] = rng.nextDouble();
            }
        }
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k) {
            M[k][k] *= 10.0;
        }
    }

    /**
     * Multiplies the packed factors back together and compares the product
     * with the original matrix, printing every entry that differs by more
     * than 0.001 and finally the largest difference seen.
     *
     * @param LU packed factors: multipliers below the diagonal (unit diagonal
     *           implied), U on and above the diagonal
     * @param M  the matrix that was factored
     * @param n  number of rows and columns
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) {
                    v += LU[i][k] * LU[k][j];
                }
                // Contribution of L's implicit unit diagonal entry.
                if (k == i && k <= j) {
                    v += LU[k][j];
                }
                double diff = Math.abs(M[i][j] - v);
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    /** Records the underlying matrix and the row/column offsets of the current block. */
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }

        /** Returns a block over the same matrix, shifted by the given row and column counts. */
        Block offset(int rows, int cols) {
            return new Block(m, loRow + rows, loCol + cols);
        }
    }

    /** Task computing {@code M -= V * W} over {@code size x size} blocks. */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        /** Base case: subtracts the product of one V block and one W block from the M block. */
        void schur() {
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double[] mRow = M.m[i + M.loRow];
                    double[] vRow = V.m[i + V.loRow];
                    double s = mRow[j + M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= vRow[k + V.loCol] * W.m[k + W.loRow][j + W.loCol];
                    }
                    mRow[j + M.loCol] = s;
                }
            }
        }

        /**
         * Splits each operand into quadrants. Every quadrant of M receives
         * two updates that must run one after the other; the four quadrants
         * are independent, so three pairs are forked and the fourth runs here.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.offset(0, h);
            Block M10 = M.offset(h, 0);
            Block M11 = M.offset(h, h);

            Block V00 = V;
            Block V01 = V.offset(0, h);
            Block V10 = V.offset(h, 0);
            Block V11 = V.offset(h, h);

            Block W00 = W;
            Block W01 = W.offset(0, h);
            Block W10 = W.offset(h, 0);
            Block W11 = W.offset(h, h);

            Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                          new Schur(h, V11, W11, M11));
            s3.fork();
            Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                          new Schur(h, V11, W10, M10));
            s2.fork();
            Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                          new Schur(h, V01, W11, M01));
            s1.fork();
            new Schur(h, V00, W00, M00).compute();
            new Schur(h, V01, W10, M00).compute();
            // Reclaim forked tasks most-recent-first: run each here if it can
            // still be unscheduled, otherwise wait for it to complete.
            if (s1.tryUnfork()) { s1.compute(); } else { s1.join(); }
            if (s2.tryUnfork()) { s2.compute(); } else { s2.join(); }
            if (s3.tryUnfork()) { s3.compute(); } else { s3.join(); }
        }
    }

    /**
     * Task that eliminates block M against the multipliers stored below the
     * diagonal of block L (forward substitution with an implicit unit
     * diagonal), overwriting M.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        /** Base case: for each row i, subtracts L[i][k] times row k of M for every k below i. */
        void lower() {
            int firstCol = M.loCol;
            int endCol = firstCol + BLOCK_SIZE;
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                double[] y = M.m[i + M.loRow];
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i + L.loRow][k + L.loCol];
                    double[] x = M.m[k + M.loRow];
                    for (int c = firstCol; c < endCol; ++c) {
                        y[c] -= a * x[c];
                    }
                }
            }
        }

        /**
         * Splits into a left and a right column half of M. Each half is a
         * three-step sequence (top solve, update of the bottom, bottom solve);
         * the halves are independent, so one is forked.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.offset(0, h);
            Block M10 = M.offset(h, 0);
            Block M11 = M.offset(h, h);

            Block L00 = L;
            Block L10 = L.offset(h, 0);
            Block L11 = L.offset(h, h);

            Seq3 s1 =
                seq(new Lower(h, L00, M00),
                    new Schur(h, L10, M00, M10),
                    new Lower(h, L11, M10));
            Seq3 s2 =
                seq(new Lower(h, L00, M01),
                    new Schur(h, L10, M01, M11),
                    new Lower(h, L11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) { s2.compute(); } else { s2.join(); }
        }
    }

    /**
     * Task that eliminates block M against the upper-triangular part of
     * block U, overwriting M with the resulting multipliers.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        /**
         * Base case: for each row of M, divides entry k by U's k-th diagonal
         * entry and subtracts that multiple of U's row k from the entries to
         * the right of column k.
         */
        void upper() {
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                double[] y = M.m[i + M.loRow];
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double[] x = U.m[k + U.loRow];
                    double a = y[k + M.loCol] / x[k + U.loCol];
                    y[k + M.loCol] = a;
                    for (int p = k + 1; p < BLOCK_SIZE; ++p) {
                        y[p + M.loCol] -= a * x[p + U.loCol];
                    }
                }
            }
        }

        /**
         * Splits into a top and a bottom row half of M. Each half is a
         * three-step sequence (left solve, update of the right, right solve);
         * the halves are independent, so one is forked.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
                return;
            }
            int h = size / 2;

            Block M00 = M;
            Block M01 = M.offset(0, h);
            Block M10 = M.offset(h, 0);
            Block M11 = M.offset(h, h);

            Block U00 = U;
            Block U01 = U.offset(0, h);
            Block U11 = U.offset(h, h);

            Seq3 s1 =
                seq(new Upper(h, U00, M00),
                    new Schur(h, M00, U01, M01),
                    new Upper(h, U11, M01));
            Seq3 s2 =
                seq(new Upper(h, U00, M10),
                    new Schur(h, M10, U01, M11),
                    new Upper(h, U11, M11));
            s2.fork();
            s1.compute();
            if (s2.tryUnfork()) { s2.compute(); } else { s2.join(); }
        }
    }

    /** Top-level task: factors the {@code size x size} block M in place. */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        /** Base case: in-place elimination of one block, without pivoting. */
        void lu() {
            int lc = M.loCol;
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                double[] x = M.m[k + M.loRow];
                for (int i = k + 1; i < BLOCK_SIZE; ++i) {
                    double[] y = M.m[i + M.loRow];
                    double a = y[k + lc] / x[k + lc];
                    y[k + lc] = a; // store the multiplier where the zero would go
                    for (int p = k + 1; p < BLOCK_SIZE; ++p) {
                        y[p + lc] -= a * x[p + lc];
                    }
                }
            }
        }

        /**
         * Factors the top-left quadrant, eliminates the top-right and
         * bottom-left quadrants against it in parallel, updates the
         * bottom-right quadrant with their product, then factors that.
         */
        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
                return;
            }
            int h = size / 2;
            Block M00 = M;
            Block M01 = M.offset(0, h);
            Block M10 = M.offset(h, 0);
            Block M11 = M.offset(h, h);

            new LowerUpper(h, M00).compute();
            Lower sl = new Lower(h, M00, M01);
            Upper su = new Upper(h, M00, M10);
            su.fork();
            sl.compute();
            if (su.tryUnfork()) { su.compute(); } else { su.join(); }
            new Schur(h, M10, M01, M11).compute();
            new LowerUpper(h, M11).compute();
        }
    }

    /** Returns a task that runs the two given tasks one after the other. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Runs two tasks in order within a single task. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Returns a task that runs the three given tasks one after another. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Runs three tasks in order within a single task. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }
}
387 jsr166 1.2