ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/LU.java
Revision: 1.7
Committed: Wed Dec 31 16:44:01 2014 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.6: +0 -1 lines
Log Message:
remove unused imports

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.5 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9    
10    
11     /**
12     * LU matrix decomposition demo
13     * Based on those in Cilk and Hood
14 jsr166 1.3 */
15 dl 1.1 public final class LU {
16    
17     /** for time conversion */
18     static final long NPS = (1000L * 1000 * 1000);
19    
20     // granularity is hard-wired as compile-time constant here
21 jsr166 1.2 static final int BLOCK_SIZE = 16;
22 dl 1.1 static final boolean CHECK = false; // set true to check answer
23    
24     public static void main(String[] args) throws Exception {
25    
26     final String usage = "Usage: java LU <threads> <matrix size (must be a power of two)> [runs] \n For example, try java LU 2 512";
27    
28     int procs = 0;
29     int n = 2048;
30     int runs = 5;
31     try {
32     if (args.length > 0)
33     procs = Integer.parseInt(args[0]);
34     if (args.length > 1)
35     n = Integer.parseInt(args[1]);
36 jsr166 1.2 if (args.length > 2)
37 dl 1.1 runs = Integer.parseInt(args[2]);
38     } catch (Exception e) {
39     System.out.println(usage);
40     return;
41     }
42    
43     if ( ((n & (n - 1)) != 0)) {
44     System.out.println(usage);
45     return;
46     }
47 jsr166 1.4 ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
48 dl 1.1 new ForkJoinPool(procs);
49 jsr166 1.2 System.out.println("procs: " + pool.getParallelism() +
50 dl 1.1 " n: " + n + " runs: " + runs);
51     for (int run = 0; run < runs; ++run) {
52     double[][] m = new double[n][n];
53     randomInit(m, n);
54     double[][] copy = null;
55     if (CHECK) {
56     copy = new double[n][n];
57     for (int i = 0; i < n; ++i) {
58     for (int j = 0; j < n; ++j) {
59     copy[i][j] = m[i][j];
60     }
61     }
62     }
63     Block M = new Block(m, 0, 0);
64     long start = System.nanoTime();
65     pool.invoke(new LowerUpper(n, M));
66     long time = System.nanoTime() - start;
67     double secs = ((double)time) / NPS;
68     System.out.printf("\tTime: %7.3f\n", secs);
69    
70     if (CHECK) check(m, copy, n);
71     }
72     System.out.println(pool.toString());
73     pool.shutdown();
74     }
75    
76    
77     static void randomInit(double[][] M, int n) {
78     java.util.Random rng = new java.util.Random();
79     for (int i = 0; i < n; ++i)
80     for (int j = 0; j < n; ++j)
81     M[i][j] = rng.nextDouble();
82     // for compatibility with hood demo, force larger diagonals
83     for (int k = 0; k < n; ++k)
84     M[k][k] *= 10.0;
85     }
86    
87     static void check(double[][] LU, double[][] M, int n) {
88     double maxDiff = 0.0; // track max difference
89     for (int i = 0; i < n; ++i) {
90     for (int j = 0; j < n; ++j) {
91     double v = 0.0;
92     int k;
93     for (k = 0; k < i && k <= j; k++ ) v += LU[i][k] * LU[k][j];
94     if (k == i && k <= j ) v += LU[k][j];
95     double diff = M[i][j] - v;
96     if (diff < 0) diff = -diff;
97     if (diff > 0.001) {
98     System.out.println("large diff at[" + i + "," + j + "]: " + M[i][j] + " vs " + v);
99     }
100     if (diff > maxDiff) maxDiff = diff;
101     }
102     }
103    
104     System.out.println("Max difference = " + maxDiff);
105     }
106    
107    
108     // Blocks record underlying matrix, and offsets into current block
109     static final class Block {
110     final double[][] m;
111     final int loRow;
112     final int loCol;
113    
114     Block(double[][] mat, int lr, int lc) {
115     m = mat; loRow = lr; loCol = lc;
116     }
117     }
118    
119     static final class Schur extends RecursiveAction {
120     final int size;
121     final Block V;
122     final Block W;
123     final Block M;
124    
125     Schur(int size, Block V, Block W, Block M) {
126     this.size = size; this.V = V; this.W = W; this.M = M;
127     }
128    
129     void schur() { // base case
130     for (int j = 0; j < BLOCK_SIZE; ++j) {
131     for (int i = 0; i < BLOCK_SIZE; ++i) {
132     double s = M.m[i+M.loRow][j+M.loCol];
133     for (int k = 0; k < BLOCK_SIZE; ++k) {
134     s -= V.m[i+V.loRow][k+V.loCol] * W.m[k+W.loRow][j+W.loCol];
135     }
136     M.m[i+M.loRow][j+M.loCol] = s;
137     }
138     }
139     }
140    
141     public void compute() {
142     if (size == BLOCK_SIZE) {
143     schur();
144     }
145     else {
146     int h = size / 2;
147    
148     Block M00 = new Block(M.m, M.loRow, M.loCol);
149     Block M01 = new Block(M.m, M.loRow, M.loCol+h);
150     Block M10 = new Block(M.m, M.loRow+h, M.loCol);
151     Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
152    
153     Block V00 = new Block(V.m, V.loRow, V.loCol);
154     Block V01 = new Block(V.m, V.loRow, V.loCol+h);
155     Block V10 = new Block(V.m, V.loRow+h, V.loCol);
156     Block V11 = new Block(V.m, V.loRow+h, V.loCol+h);
157    
158     Block W00 = new Block(W.m, W.loRow, W.loCol);
159     Block W01 = new Block(W.m, W.loRow, W.loCol+h);
160     Block W10 = new Block(W.m, W.loRow+h, W.loCol);
161     Block W11 = new Block(W.m, W.loRow+h, W.loCol+h);
162 jsr166 1.2
163 dl 1.1 Seq2 s3 = seq(new Schur(h, V10, W01, M11),
164     new Schur(h, V11, W11, M11));
165     s3.fork();
166     Seq2 s2 = seq(new Schur(h, V10, W00, M10),
167     new Schur(h, V11, W10, M10));
168     s2.fork();
169     Seq2 s1 = seq(new Schur(h, V00, W01, M01),
170     new Schur(h, V01, W11, M01));
171     s1.fork();
172     new Schur(h, V00, W00, M00).compute();
173     new Schur(h, V01, W10, M00).compute();
174     if (s1.tryUnfork()) s1.compute(); else s1.join();
175     if (s2.tryUnfork()) s2.compute(); else s2.join();
176     if (s3.tryUnfork()) s3.compute(); else s3.join();
177     }
178     }
179     }
180    
181     static final class Lower extends RecursiveAction {
182     final int size;
183     final Block L;
184     final Block M;
185     Lower(int size, Block L, Block M) {
186     this.size = size; this.L = L; this.M = M;
187     }
188    
189     void lower() { // base case
190     for (int i = 1; i < BLOCK_SIZE; ++i) {
191     for (int k = 0; k < i; ++k) {
192     double a = L.m[i+L.loRow][k+L.loCol];
193     double[] x = M.m[k+M.loRow];
194     double[] y = M.m[i+M.loRow];
195     int n = BLOCK_SIZE;
196     for (int p = n-1; p >= 0; --p) {
197     y[p+M.loCol] -= a * x[p+M.loCol];
198     }
199     }
200     }
201     }
202    
203     public void compute() {
204     if (size == BLOCK_SIZE) {
205     lower();
206     }
207     else {
208     int h = size / 2;
209    
210     Block M00 = new Block(M.m, M.loRow, M.loCol);
211     Block M01 = new Block(M.m, M.loRow, M.loCol+h);
212     Block M10 = new Block(M.m, M.loRow+h, M.loCol);
213     Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
214    
215     Block L00 = new Block(L.m, L.loRow, L.loCol);
216     Block L01 = new Block(L.m, L.loRow, L.loCol+h);
217     Block L10 = new Block(L.m, L.loRow+h, L.loCol);
218     Block L11 = new Block(L.m, L.loRow+h, L.loCol+h);
219    
220 jsr166 1.2
221     Seq3 s1 =
222 dl 1.1 seq(new Lower(h, L00, M00),
223     new Schur(h, L10, M00, M10),
224     new Lower(h, L11, M10));
225 jsr166 1.2 Seq3 s2 =
226 dl 1.1 seq(new Lower(h, L00, M01),
227     new Schur(h, L10, M01, M11),
228     new Lower(h, L11, M11));
229     s2.fork();
230     s1.compute();
231     if (s2.tryUnfork()) s2.compute(); else s2.join();
232     }
233     }
234     }
235    
236    
237     static final class Upper extends RecursiveAction {
238     final int size;
239     final Block U;
240     final Block M;
241     Upper(int size, Block U, Block M) {
242     this.size = size; this.U = U; this.M = M;
243     }
244    
245     void upper() { // base case
246     for (int i = 0; i < BLOCK_SIZE; ++i) {
247     for (int k = 0; k < BLOCK_SIZE; ++k) {
248     double a = M.m[i+M.loRow][k+M.loCol] / U.m[k+U.loRow][k+U.loCol];
249     M.m[i+M.loRow][k+M.loCol] = a;
250     double[] x = U.m[k+U.loRow];
251     double[] y = M.m[i+M.loRow];
252     int n = BLOCK_SIZE - k - 1;
253     for (int p = n - 1; p >= 0; --p) {
254     y[p+k+1+M.loCol] -= a * x[p+k+1+U.loCol];
255     }
256     }
257     }
258     }
259    
260    
261     public void compute() {
262     if (size == BLOCK_SIZE) {
263     upper();
264     }
265     else {
266     int h = size / 2;
267    
268     Block M00 = new Block(M.m, M.loRow, M.loCol);
269     Block M01 = new Block(M.m, M.loRow, M.loCol+h);
270     Block M10 = new Block(M.m, M.loRow+h, M.loCol);
271     Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
272    
273     Block U00 = new Block(U.m, U.loRow, U.loCol);
274     Block U01 = new Block(U.m, U.loRow, U.loCol+h);
275     Block U10 = new Block(U.m, U.loRow+h, U.loCol);
276     Block U11 = new Block(U.m, U.loRow+h, U.loCol+h);
277    
278 jsr166 1.2 Seq3 s1 =
279 dl 1.1 seq(new Upper(h, U00, M00),
280     new Schur(h, M00, U01, M01),
281     new Upper(h, U11, M01));
282 jsr166 1.2 Seq3 s2 =
283 dl 1.1 seq(new Upper(h, U00, M10),
284     new Schur(h, M10, U01, M11),
285     new Upper(h, U11, M11));
286     s2.fork();
287     s1.compute();
288     if (s2.tryUnfork()) s2.compute(); else s2.join();
289     }
290     }
291     }
292 jsr166 1.2
293 dl 1.1
294     static final class LowerUpper extends RecursiveAction {
295     final int size;
296     final Block M;
297     LowerUpper(int size, Block M) {
298     this.size = size; this.M = M;
299     }
300    
301     void lu() { // base case
302     for (int k = 0; k < BLOCK_SIZE; ++k) {
303     for (int i = k+1; i < BLOCK_SIZE; ++i) {
304     double b = M.m[k+M.loRow][k+M.loCol];
305     double a = M.m[i+M.loRow][k+M.loCol] / b;
306     M.m[i+M.loRow][k+M.loCol] = a;
307     double[] x = M.m[k+M.loRow];
308     double[] y = M.m[i+M.loRow];
309     int n = BLOCK_SIZE-k-1;
310     for (int p = n-1; p >= 0; --p) {
311     y[k+1+p+M.loCol] -= a * x[k+1+p+M.loCol];
312     }
313     }
314     }
315     }
316    
317     public void compute() {
318     if (size == BLOCK_SIZE) {
319     lu();
320     }
321     else {
322     int h = size / 2;
323     Block M00 = new Block(M.m, M.loRow, M.loCol);
324     Block M01 = new Block(M.m, M.loRow, M.loCol+h);
325     Block M10 = new Block(M.m, M.loRow+h, M.loCol);
326     Block M11 = new Block(M.m, M.loRow+h, M.loCol+h);
327    
328     new LowerUpper(h, M00).compute();
329     Lower sl = new Lower(h, M00, M01);
330     Upper su = new Upper(h, M00, M10);
331     su.fork();
332     sl.compute();
333     if (su.tryUnfork()) su.compute(); else su.join();
334     new Schur(h, M10, M01, M11).compute();
335     new LowerUpper(h, M11).compute();
336     }
337     }
338     }
339    
340 jsr166 1.2 static Seq2 seq(RecursiveAction task1,
341     RecursiveAction task2) {
342     return new Seq2(task1, task2);
343 dl 1.1 }
344    
345     static final class Seq2 extends RecursiveAction {
346     final RecursiveAction fst;
347     final RecursiveAction snd;
348     public Seq2(RecursiveAction task1, RecursiveAction task2) {
349     fst = task1;
350     snd = task2;
351     }
352     public void compute() {
353     fst.invoke();
354     snd.invoke();
355     }
356     }
357    
358 jsr166 1.2 static Seq3 seq(RecursiveAction task1,
359 dl 1.1 RecursiveAction task2,
360 jsr166 1.2 RecursiveAction task3) {
361     return new Seq3(task1, task2, task3);
362 dl 1.1 }
363    
364     static final class Seq3 extends RecursiveAction {
365     final RecursiveAction fst;
366     final RecursiveAction snd;
367     final RecursiveAction thr;
368 jsr166 1.2 public Seq3(RecursiveAction task1,
369 dl 1.1 RecursiveAction task2,
370     RecursiveAction task3) {
371     fst = task1;
372     snd = task2;
373     thr = task3;
374     }
375     public void compute() {
376     fst.invoke();
377     snd.invoke();
378     thr.invoke();
379     }
380     }
381     }