ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
Revision: 1.8
Committed: Thu Jan 15 18:34:19 2015 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.7: +0 -4 lines
Log Message:
delete extraneous blank lines

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.5 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9    
10     /**
11     * Divide and Conquer matrix multiply demo
12 jsr166 1.3 */
public class MatrixMultiply {

    /** Nanoseconds per second, used to convert elapsed nanoTime to seconds. */
    static final long NPS = (1000L * 1000 * 1000);

    static final int DEFAULT_GRANULARITY = 32;

    /**
     * The quadrant size at which to stop recursing down
     * and instead directly multiply the matrices.
     * Must be a power of two. Minimum value is 2.
     */
    static int granularity = DEFAULT_GRANULARITY;

    /**
     * Runs the benchmark: multiplies two n-by-n matrices of ones
     * {@code runs} times on a ForkJoinPool and prints the elapsed
     * time of each run.
     *
     * @param args optional, positional:
     *   {@code [threads [matrixSize [granularity [runs]]]]}.
     *   {@code threads == 0} selects the pool's default parallelism;
     *   {@code matrixSize} and {@code granularity} must be powers of
     *   two that are at least 2. Invalid arguments print the usage
     *   message and return.
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0)
                procs = Integer.parseInt(args[0]);
            if (args.length > 1)
                n = Integer.parseInt(args[1]);
            if (args.length > 2)
                granularity = Integer.parseInt(args[2]);
            if (args.length > 3)
                runs = Integer.parseInt(args[3]); // was args[2], which is the granularity
        }
        catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // The power-of-two bit test alone lets 0, 1 and Integer.MIN_VALUE
        // through; multiplyStride2 touches rows i and i+1, so the matrix
        // needs at least two rows and columns.
        if (procs < 0 ||
            n < 2 || ((n & (n - 1)) != 0) ||
            granularity < 2 || ((granularity & (granularity - 1)) != 0)) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " granularity: " + granularity +
                               " runs: " + runs);

            float[][] a = new float[n][n];
            float[][] b = new float[n][n];
            float[][] c = new float[n][n];

            for (int i = 0; i < runs; ++i) {
                init(a, b, n);
                long start = System.nanoTime();
                pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
                long time = System.nanoTime() - start;
                double secs = ((double)time) / NPS;
                Thread.sleep(100);
                System.out.printf("\tTime: %7.3f\n", secs);
                // NOTE: c is allocated once and the multiply adds into it,
                // so after run k each element holds k*n. Zero c before each
                // run if check(c, n) is ever called here.
            }
            System.out.println(pool.toString());
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Fills the leading n-by-n region of both matrices with 1.0.
     * To simplify checking, fill with all 1's. Answer should be all n's.
     */
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    /**
     * Verifies that every element of the leading n-by-n region of
     * {@code c} equals {@code n}.
     *
     * @throws Error naming the first mismatching element
     */
    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
                }
            }
        }
    }

    /**
     * Multiply matrices AxB by dividing into quadrants, using algorithm:
     * <pre>
     *      A      x      B
     *
     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
     * |----+----| x |----+----| = |--------+--------| + |--------+--------|
     *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22
     * </pre>
     */
    static class Multiplier extends RecursiveAction {
        final float[][] A; // Matrix A
        final int aRow;    // first row of current quadrant of A
        final int aCol;    // first column of current quadrant of A

        final float[][] B; // Similarly for B
        final int bRow;
        final int bCol;

        final float[][] C; // Similarly for result matrix C
        final int cRow;
        final int cCol;

        final int size;    // side length (rows == columns) of current quadrant

        Multiplier(float[][] A, int aRow, int aCol,
                   float[][] B, int bRow, int bCol,
                   float[][] C, int cRow, int cCol,
                   int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        /**
         * Multiplies directly once the quadrant is no larger than
         * {@code granularity}; otherwise splits into four pairs of
         * half-size products. The two products of a pair add into the
         * same quadrant of C, so each pair is wrapped in a Seq2 and run
         * in order, while the four pairs (which target distinct
         * quadrants of C) are handed to invokeAll together.
         */
        @Override
        public void compute() {

            if (size <= granularity) {
                multiplyStride2();
            }

            else {
                int h = size / 2;

                invokeAll(new Seq2[] {
                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol,    // B11
                                       C, cRow,   cCol,    // C11
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol,    // B21
                                       C, cRow,   cCol,    // C11
                                       h)),

                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow,   cCol+h,  // C12
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow,   cCol+h,  // C12
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol,    // B11
                                       C, cRow+h, cCol,    // C21
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol,    // B21
                                       C, cRow+h, cCol,    // C21
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow+h, cCol+h,  // C22
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow+h, cCol+h,  // C22
                                       h))
                });
            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns
         * at a time. Adapted from Cilk demos.
         * Note that the results are added into C, not just set into C.
         * This works well here because Java array elements
         * are created with all zero values.
         * Requires {@code size} to be even, since rows and columns
         * i+1, j+1 and k+1 are accessed alongside i, j and k.
         */
        void multiplyStride2() {
            for (int j = 0; j < size; j+=2) {
                for (int i = 0; i < size; i +=2) {

                    float[] a0 = A[aRow+i];
                    float[] a1 = A[aRow+i+1];

                    // Partial sums for the 2x2 tile of C at (i, j).
                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k+=2) {

                        float[] b0 = B[bRow+k];

                        s00 += a0[aCol+k]   * b0[bCol+j];
                        s10 += a1[aCol+k]   * b0[bCol+j];
                        s01 += a0[aCol+k]   * b0[bCol+j+1];
                        s11 += a1[aCol+k]   * b0[bCol+j+1];

                        float[] b1 = B[bRow+k+1];

                        s00 += a0[aCol+k+1] * b1[bCol+j];
                        s10 += a1[aCol+k+1] * b1[bCol+j];
                        s01 += a0[aCol+k+1] * b1[bCol+j+1];
                        s11 += a1[aCol+k+1] * b1[bCol+j+1];
                    }

                    C[cRow+i]  [cCol+j]   += s00;
                    C[cRow+i]  [cCol+j+1] += s01;
                    C[cRow+i+1][cCol+j]   += s10;
                    C[cRow+i+1][cCol+j+1] += s11;
                }
            }
        }

    }

    /** Returns a task that runs {@code task1} and then {@code task2}. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** A task that invokes two subtasks one after the other. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }
        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }
}