ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.9
Committed: Sat Sep 12 18:59:08 2015 UTC (8 years, 8 months ago) by dl
Branch: MAIN
CVS Tags: HEAD
Changes since 1.8: +4 -3 lines
Log Message:
Use commonPool

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.5 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9    
/**
 * LU matrix decomposition demo.
 * Based on those in Cilk and Hood.
 *
 * <p>The matrix is factored in place, without pivoting, by recursively
 * splitting it into quadrants until a quadrant is exactly
 * {@link #BLOCK_SIZE} wide. After the decomposition the strictly lower
 * triangle holds the multipliers of L (whose diagonal is an implicit 1) and
 * the upper triangle, including the diagonal, holds U; this is the layout
 * that {@link #check} reconstructs.
 *
 * <p>Because every level halves the size and the recursion only stops when
 * the size equals {@code BLOCK_SIZE}, the matrix size must be a power of two
 * that is at least {@code BLOCK_SIZE}.
 */
public final class LU {

    /** for time conversion */
    static final long NPS = (1000L * 1000 * 1000);

    // granularity is hard-wired as compile-time constant here
    static final int BLOCK_SIZE = 16;
    static final boolean CHECK = false; // set true to check answer

    /**
     * Command-line entry point: {@code java LU [threads] [matrix size] [runs]}.
     *
     * <p>A thread count of 0 (the default) selects the common pool. Prints
     * the usage text and returns when an argument is not an integer, when the
     * thread count is negative, or when the matrix size is not a power of two
     * that is at least {@link #BLOCK_SIZE}.
     *
     * @param args optional thread count, matrix size and number of runs
     * @throws Exception kept from the original signature; a failure inside
     *         the decomposition task propagates out of {@code pool.invoke}
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java LU <threads> <matrix size (must be a power of two, at least " + BLOCK_SIZE + ")> [runs] \n For example, try java LU 2 512";

        int procs = 0;
        int n = 2048;
        int runs = 32;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                runs = Integer.parseInt(args[2]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // A negative thread count would make the ForkJoinPool constructor
        // throw, and an unsupported size would never hit the base case of
        // the recursion, so both are rejected up front.
        if (procs < 0 || !isValidSize(n)) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? ForkJoinPool.commonPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " runs: " + runs);
            for (int run = 0; run < runs; ++run) {
                double[][] m = new double[n][n];
                randomInit(m, n);
                double[][] copy = null;
                if (CHECK) {
                    // keep the untouched input so the product can be compared
                    copy = new double[n][];
                    for (int i = 0; i < n; ++i) {
                        copy[i] = m[i].clone();
                    }
                }
                Block M = new Block(m, 0, 0);
                long start = System.nanoTime();
                pool.invoke(new LowerUpper(n, M));
                long time = System.nanoTime() - start;
                double secs = ((double)time) / NPS;
                System.out.printf("Time: %7.3f ", secs);
                if ((run & 3) == 3) {
                    System.out.println(); // four timings per output line
                }

                if (CHECK) {
                    check(m, copy, n);
                }
            }
            // terminate a partially filled row of timings before the summary
            if (runs > 0 && (runs & 3) != 0) {
                System.out.println();
            }
            System.out.println(pool.toString());
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Returns true when {@code n} is a size the recursion can handle: a
     * power of two no smaller than {@link #BLOCK_SIZE}.
     */
    private static boolean isValidSize(int n) {
        return n >= BLOCK_SIZE && (n & (n - 1)) == 0;
    }

    /**
     * Fills the leading {@code n x n} part of {@code M} with random values
     * from {@code java.util.Random.nextDouble()} and then scales each
     * diagonal element by 10.
     */
    static void randomInit(double[][] M, int n) {
        java.util.Random rng = new java.util.Random();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                M[i][j] = rng.nextDouble();
            }
        }
        // for compatibility with hood demo, force larger diagonals
        for (int k = 0; k < n; ++k) {
            M[k][k] *= 10.0;
        }
    }

    /**
     * Multiplies the packed factors in {@code LU} back together and compares
     * the product with {@code M}, reporting every element that differs by
     * more than 0.001 and finally the largest difference seen.
     *
     * @param LU the decomposed matrix (L below the diagonal, U on and above)
     * @param M the original matrix
     * @param n the matrix size
     */
    static void check(double[][] LU, double[][] M, int n) {
        double maxDiff = 0.0; // track max difference
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double v = 0.0;
                int k;
                for (k = 0; k < i && k <= j; k++) {
                    v += LU[i][k] * LU[k][j];
                }
                // the diagonal of L is an implicit 1, so U[i][j] is added as is
                if (k == i && k <= j) {
                    v += LU[k][j];
                }
                double diff = M[i][j] - v;
                if (diff < 0) {
                    diff = -diff;
                }
                if (diff > 0.001) {
                    System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
                }
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
        }

        System.out.println("Max difference = " + maxDiff);
    }

    // Blocks record underlying matrix, and offsets into current block
    static final class Block {
        final double[][] m;
        final int loRow;
        final int loCol;

        Block(double[][] mat, int lr, int lc) {
            m = mat; loRow = lr; loCol = lc;
        }
    }

    /**
     * Subtracts the product {@code V * W} from {@code M}, all three being
     * {@code size x size} blocks.
     */
    static final class Schur extends RecursiveAction {
        final int size;
        final Block V;
        final Block W;
        final Block M;

        Schur(int size, Block V, Block W, Block M) {
            this.size = size; this.V = V; this.W = W; this.M = M;
        }

        void schur() { // base case
            for (int j = 0; j < BLOCK_SIZE; ++j) {
                for (int i = 0; i < BLOCK_SIZE; ++i) {
                    double s = M.m[i+M.loRow][j+M.loCol];
                    for (int k = 0; k < BLOCK_SIZE; ++k) {
                        s -= V.m[i+V.loRow][k+V.loCol] * W.m[k+W.loRow][j+W.loCol];
                    }
                    M.m[i+M.loRow][j+M.loCol] = s;
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                schur();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow,   M.loCol);
                Block M01 = new Block(M.m, M.loRow,   M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                Block V00 = new Block(V.m, V.loRow,   V.loCol);
                Block V01 = new Block(V.m, V.loRow,   V.loCol+h);
                Block V10 = new Block(V.m, V.loRow+h, V.loCol);
                Block V11 = new Block(V.m, V.loRow+h, V.loCol+h);

                Block W00 = new Block(W.m, W.loRow,   W.loCol);
                Block W01 = new Block(W.m, W.loRow,   W.loCol+h);
                Block W10 = new Block(W.m, W.loRow+h, W.loCol);
                Block W11 = new Block(W.m, W.loRow+h, W.loCol+h);

                // Each quadrant of M receives two updates that must run one
                // after the other; M01, M10 and M11 are forked as sequences
                // while M00 is updated directly by this task.
                Seq2 s3 = seq(new Schur(h, V10, W01, M11),
                              new Schur(h, V11, W11, M11));
                s3.fork();
                Seq2 s2 = seq(new Schur(h, V10, W00, M10),
                              new Schur(h, V11, W10, M10));
                s2.fork();
                Seq2 s1 = seq(new Schur(h, V00, W01, M01),
                              new Schur(h, V01, W11, M01));
                s1.fork();
                new Schur(h, V00, W00, M00).compute();
                new Schur(h, V01, W10, M00).compute();
                // run a forked task directly if it can still be unscheduled,
                // otherwise wait for it
                if (s1.tryUnfork()) s1.compute(); else s1.join();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
                if (s3.tryUnfork()) s3.compute(); else s3.join();
            }
        }
    }

    /**
     * Eliminates within {@code M} using the multipliers stored below the
     * diagonal of {@code L}: in the base case row {@code i} of {@code M} has
     * {@code L[i][k]} times row {@code k} subtracted for every {@code k < i}.
     */
    static final class Lower extends RecursiveAction {
        final int size;
        final Block L;
        final Block M;

        Lower(int size, Block L, Block M) {
            this.size = size; this.L = L; this.M = M;
        }

        void lower() { // base case
            for (int i = 1; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < i; ++k) {
                    double a = L.m[i+L.loRow][k+L.loCol];
                    double[] x = M.m[k+M.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE;
                    for (int p = n-1; p >= 0; --p) {
                        y[p+M.loCol] -= a * x[p+M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lower();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow,   M.loCol);
                Block M01 = new Block(M.m, M.loRow,   M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                Block L00 = new Block(L.m, L.loRow,   L.loCol);
                Block L01 = new Block(L.m, L.loRow,   L.loCol+h);
                Block L10 = new Block(L.m, L.loRow+h, L.loCol);
                Block L11 = new Block(L.m, L.loRow+h, L.loCol+h);

                // s1 works on the left column of quadrants (M00, M10) and s2
                // on the right column (M01, M11)
                Seq3 s1 =
                    seq(new Lower(h, L00, M00),
                        new Schur(h, L10, M00, M10),
                        new Lower(h, L11, M10));
                Seq3 s2 =
                    seq(new Lower(h, L00, M01),
                        new Schur(h, L10, M01, M11),
                        new Lower(h, L11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Replaces {@code M} by its multipliers with respect to the upper
     * triangle of {@code U}: in the base case each {@code M[i][k]} is divided
     * by {@code U[k][k]} and the result times row {@code k} of {@code U} is
     * subtracted from the remainder of row {@code i}.
     */
    static final class Upper extends RecursiveAction {
        final int size;
        final Block U;
        final Block M;

        Upper(int size, Block U, Block M) {
            this.size = size; this.U = U; this.M = M;
        }

        void upper() { // base case
            for (int i = 0; i < BLOCK_SIZE; ++i) {
                for (int k = 0; k < BLOCK_SIZE; ++k) {
                    double a = M.m[i+M.loRow][k+M.loCol] / U.m[k+U.loRow][k+U.loCol];
                    M.m[i+M.loRow][k+M.loCol] = a;
                    double[] x = U.m[k+U.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE - k - 1; // columns to the right of k
                    for (int p = n - 1; p >= 0; --p) {
                        y[p+k+1+M.loCol] -= a * x[p+k+1+U.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                upper();
            }
            else {
                int h = size / 2;

                Block M00 = new Block(M.m, M.loRow,   M.loCol);
                Block M01 = new Block(M.m, M.loRow,   M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                Block U00 = new Block(U.m, U.loRow,   U.loCol);
                Block U01 = new Block(U.m, U.loRow,   U.loCol+h);
                Block U10 = new Block(U.m, U.loRow+h, U.loCol);
                Block U11 = new Block(U.m, U.loRow+h, U.loCol+h);

                // s1 works on the top row of quadrants (M00, M01) and s2 on
                // the bottom row (M10, M11)
                Seq3 s1 =
                    seq(new Upper(h, U00, M00),
                        new Schur(h, M00, U01, M01),
                        new Upper(h, U11, M01));
                Seq3 s2 =
                    seq(new Upper(h, U00, M10),
                        new Schur(h, M10, U01, M11),
                        new Upper(h, U11, M11));
                s2.fork();
                s1.compute();
                if (s2.tryUnfork()) s2.compute(); else s2.join();
            }
        }
    }

    /**
     * Top-level task: decomposes the {@code size x size} block {@code M} in
     * place.
     */
    static final class LowerUpper extends RecursiveAction {
        final int size;
        final Block M;

        LowerUpper(int size, Block M) {
            this.size = size; this.M = M;
        }

        void lu() { // base case
            for (int k = 0; k < BLOCK_SIZE; ++k) {
                for (int i = k+1; i < BLOCK_SIZE; ++i) {
                    double b = M.m[k+M.loRow][k+M.loCol];
                    double a = M.m[i+M.loRow][k+M.loCol] / b;
                    M.m[i+M.loRow][k+M.loCol] = a;
                    double[] x = M.m[k+M.loRow];
                    double[] y = M.m[i+M.loRow];
                    int n = BLOCK_SIZE-k-1; // columns to the right of k
                    for (int p = n-1; p >= 0; --p) {
                        y[k+1+p+M.loCol] -= a * x[k+1+p+M.loCol];
                    }
                }
            }
        }

        @Override
        public void compute() {
            if (size == BLOCK_SIZE) {
                lu();
            }
            else {
                int h = size / 2;
                Block M00 = new Block(M.m, M.loRow,   M.loCol);
                Block M01 = new Block(M.m, M.loRow,   M.loCol+h);
                Block M10 = new Block(M.m, M.loRow+h, M.loCol);
                Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);

                // Order matters: M00 is decomposed first, then M01 and M10
                // are processed against it (in parallel with each other),
                // then M11 is updated with their product and decomposed.
                new LowerUpper(h, M00).compute();
                Lower sl = new Lower(h, M00, M01);
                Upper su = new Upper(h, M00, M10);
                su.fork();
                sl.compute();
                if (su.tryUnfork()) su.compute(); else su.join();
                new Schur(h, M10, M01, M11).compute();
                new LowerUpper(h, M11).compute();
            }
        }
    }

    /** Wraps two actions in a task that runs them one after the other. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Invokes {@code fst} and then {@code snd}. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }

    /** Wraps three actions in a task that runs them one after the other. */
    static Seq3 seq(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
        return new Seq3(task1, task2, task3);
    }

    /** Invokes {@code fst}, then {@code snd}, then {@code thr}. */
    static final class Seq3 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        final RecursiveAction thr;

        public Seq3(RecursiveAction task1,
                    RecursiveAction task2,
                    RecursiveAction task3) {
            fst = task1;
            snd = task2;
            thr = task3;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
            thr.invoke();
        }
    }
}