ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
Revision: 1.7
Committed: Wed Dec 31 16:44:01 2014 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.6: +0 -1 lines
Log Message:
remove unused imports

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.5 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     //import jsr166y.*;
8     import java.util.concurrent.*;
9    
10    
11     /**
12     * Divide and Conquer matrix multiply demo
13 jsr166 1.3 */
14 dl 1.1 public class MatrixMultiply {
15    
16     /** for time conversion */
17     static final long NPS = (1000L * 1000 * 1000);
18    
19     static final int DEFAULT_GRANULARITY = 32;
20    
21 jsr166 1.3 /**
22     * The quadrant size at which to stop recursing down
23 dl 1.1 * and instead directly multiply the matrices.
24     * Must be a power of two. Minimum value is 2.
25 jsr166 1.3 */
26 dl 1.1 static int granularity = DEFAULT_GRANULARITY;
27    
28     public static void main(String[] args) throws Exception {
29    
30     final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";
31    
32     int procs = 0;
33     int n = 2048;
34     int runs = 5;
35     try {
36     if (args.length > 0)
37     procs = Integer.parseInt(args[0]);
38     if (args.length > 1)
39     n = Integer.parseInt(args[1]);
40 jsr166 1.2 if (args.length > 2)
41 dl 1.1 granularity = Integer.parseInt(args[2]);
42 jsr166 1.2 if (args.length > 3)
43 dl 1.1 runs = Integer.parseInt(args[2]);
44     }
45 jsr166 1.2
46 dl 1.1 catch (Exception e) {
47     System.out.println(usage);
48     return;
49     }
50 jsr166 1.2
51     if ( ((n & (n - 1)) != 0) ||
52 dl 1.1 ((granularity & (granularity - 1)) != 0) ||
53     granularity < 2) {
54     System.out.println(usage);
55     return;
56     }
57 jsr166 1.2
58 jsr166 1.4 ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
59 dl 1.1 new ForkJoinPool(procs);
60 jsr166 1.2 System.out.println("procs: " + pool.getParallelism() +
61 dl 1.1 " n: " + n + " granularity: " + granularity +
62     " runs: " + runs);
63 jsr166 1.2
64 dl 1.1 float[][] a = new float[n][n];
65     float[][] b = new float[n][n];
66     float[][] c = new float[n][n];
67 jsr166 1.2
68 dl 1.1 for (int i = 0; i < runs; ++i) {
69     init(a, b, n);
70     long start = System.nanoTime();
71     pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
72     long time = System.nanoTime() - start;
73     double secs = ((double)time) / NPS;
74     Thread.sleep(100);
75     System.out.printf("\tTime: %7.3f\n", secs);
76     // check(c, n);
77     }
78     System.out.println(pool.toString());
79     pool.shutdown();
80     }
81    
82    
83     // To simplify checking, fill with all 1's. Answer should be all n's.
84     static void init(float[][] a, float[][] b, int n) {
85     for (int i = 0; i < n; ++i) {
86     for (int j = 0; j < n; ++j) {
87     a[i][j] = 1.0F;
88     b[i][j] = 1.0F;
89     }
90     }
91     }
92    
93     static void check(float[][] c, int n) {
94     for (int i = 0; i < n; i++ ) {
95     for (int j = 0; j < n; j++ ) {
96     if (c[i][j] != n) {
97     throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
98     }
99     }
100     }
101     }
102    
103 jsr166 1.2 /**
104 dl 1.1 * Multiply matrices AxB by dividing into quadrants, using algorithm:
105     * <pre>
106 jsr166 1.2 * A x B
107 dl 1.1 *
108 jsr166 1.2 * A11 | A12 B11 | B12 A11*B11 | A11*B12 A12*B21 | A12*B22
109 dl 1.1 * |----+----| x |----+----| = |--------+--------| + |---------+-------|
110 jsr166 1.2 * A21 | A22 B21 | B21 A21*B11 | A21*B21 A22*B21 | A22*B22
111 dl 1.1 * </pre>
112     */
113     static class Multiplier extends RecursiveAction {
114     final float[][] A; // Matrix A
115     final int aRow; // first row of current quadrant of A
116     final int aCol; // first column of current quadrant of A
117    
118     final float[][] B; // Similarly for B
119     final int bRow;
120     final int bCol;
121    
122     final float[][] C; // Similarly for result matrix C
123     final int cRow;
124     final int cCol;
125    
126     final int size; // number of elements in current quadrant
127 jsr166 1.2
128 dl 1.1 Multiplier(float[][] A, int aRow, int aCol,
129     float[][] B, int bRow, int bCol,
130     float[][] C, int cRow, int cCol,
131     int size) {
132     this.A = A; this.aRow = aRow; this.aCol = aCol;
133     this.B = B; this.bRow = bRow; this.bCol = bCol;
134     this.C = C; this.cRow = cRow; this.cCol = cCol;
135     this.size = size;
136     }
137    
138     public void compute() {
139    
140     if (size <= granularity) {
141     multiplyStride2();
142     }
143    
144     else {
145     int h = size / 2;
146    
147     invokeAll(new Seq2[] {
148     seq(new Multiplier(A, aRow, aCol, // A11
149     B, bRow, bCol, // B11
150     C, cRow, cCol, // C11
151     h),
152     new Multiplier(A, aRow, aCol+h, // A12
153     B, bRow+h, bCol, // B21
154     C, cRow, cCol, // C11
155     h)),
156 jsr166 1.2
157 dl 1.1 seq(new Multiplier(A, aRow, aCol, // A11
158     B, bRow, bCol+h, // B12
159     C, cRow, cCol+h, // C12
160     h),
161     new Multiplier(A, aRow, aCol+h, // A12
162     B, bRow+h, bCol+h, // B22
163     C, cRow, cCol+h, // C12
164     h)),
165 jsr166 1.2
166 dl 1.1 seq(new Multiplier(A, aRow+h, aCol, // A21
167     B, bRow, bCol, // B11
168     C, cRow+h, cCol, // C21
169     h),
170     new Multiplier(A, aRow+h, aCol+h, // A22
171     B, bRow+h, bCol, // B21
172     C, cRow+h, cCol, // C21
173     h)),
174 jsr166 1.2
175 dl 1.1 seq(new Multiplier(A, aRow+h, aCol, // A21
176     B, bRow, bCol+h, // B12
177     C, cRow+h, cCol+h, // C22
178     h),
179     new Multiplier(A, aRow+h, aCol+h, // A22
180     B, bRow+h, bCol+h, // B22
181     C, cRow+h, cCol+h, // C22
182     h))
183     });
184     }
185     }
186    
187 jsr166 1.2 /**
188 dl 1.1 * Version of matrix multiplication that steps 2 rows and columns
189     * at a time. Adapted from Cilk demos.
190     * Note that the results are added into C, not just set into C.
191     * This works well here because Java array elements
192     * are created with all zero values.
193 jsr166 1.3 */
194 dl 1.1 void multiplyStride2() {
195     for (int j = 0; j < size; j+=2) {
196     for (int i = 0; i < size; i +=2) {
197    
198     float[] a0 = A[aRow+i];
199     float[] a1 = A[aRow+i+1];
200 jsr166 1.2
201     float s00 = 0.0F;
202     float s01 = 0.0F;
203     float s10 = 0.0F;
204     float s11 = 0.0F;
205 dl 1.1
206     for (int k = 0; k < size; k+=2) {
207    
208     float[] b0 = B[bRow+k];
209    
210     s00 += a0[aCol+k] * b0[bCol+j];
211     s10 += a1[aCol+k] * b0[bCol+j];
212     s01 += a0[aCol+k] * b0[bCol+j+1];
213     s11 += a1[aCol+k] * b0[bCol+j+1];
214    
215     float[] b1 = B[bRow+k+1];
216    
217     s00 += a0[aCol+k+1] * b1[bCol+j];
218     s10 += a1[aCol+k+1] * b1[bCol+j];
219     s01 += a0[aCol+k+1] * b1[bCol+j+1];
220     s11 += a1[aCol+k+1] * b1[bCol+j+1];
221     }
222    
223     C[cRow+i] [cCol+j] += s00;
224     C[cRow+i] [cCol+j+1] += s01;
225     C[cRow+i+1][cCol+j] += s10;
226     C[cRow+i+1][cCol+j+1] += s11;
227     }
228     }
229     }
230    
231     }
232    
233 jsr166 1.2 static Seq2 seq(RecursiveAction task1,
234     RecursiveAction task2) {
235     return new Seq2(task1, task2);
236 dl 1.1 }
237    
238     static final class Seq2 extends RecursiveAction {
239     final RecursiveAction fst;
240     final RecursiveAction snd;
241     public Seq2(RecursiveAction task1, RecursiveAction task2) {
242     fst = task1;
243     snd = task2;
244     }
245     public void compute() {
246     fst.invoke();
247     snd.invoke();
248     }
249     }
250    
251    
252     }