ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
Revision: 1.7
Committed: Wed Dec 31 16:44:01 2014 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.6: +0 -1 lines
Log Message:
remove unused imports

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9
10
/**
 * Divide-and-conquer matrix multiply demo.
 *
 * <p>Multiplies two {@code n x n} float matrices on a {@link ForkJoinPool}
 * by recursively splitting them into quadrants, and prints the elapsed
 * time of each run. The inputs are filled with 1.0, so every entry of the
 * product should equal {@code n} (see {@link #check}).
 */
public class MatrixMultiply {

    /** Nanoseconds per second, for time conversion. */
    static final long NPS = (1000L * 1000 * 1000);

    static final int DEFAULT_GRANULARITY = 32;

    /**
     * The quadrant size at which to stop recursing down
     * and instead directly multiply the matrices.
     * Must be a power of two. Minimum value is 2.
     * Assigned by {@code main} before any tasks are submitted.
     */
    static int granularity = DEFAULT_GRANULARITY;

    /**
     * Runs the benchmark.
     *
     * @param args {@code [<threads> [<matrix size> [<granularity> [<runs>]]]]};
     *        a thread count of 0 selects the default pool parallelism
     * @throws Exception if interrupted while pausing between runs, or if a
     *         multiplication task fails
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] [<runs>] \n Size and granularity must be powers of two, at least 2.\n For example, try java MatrixMultiply 2 512 16";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                granularity = Integer.parseInt(args[2]);
            }
            if (args.length > 3) {
                runs = Integer.parseInt(args[3]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // n and granularity must be powers of two >= 2: the leaf kernel
        // steps two rows/columns at a time, so a 1x1 quadrant would index
        // past the end of the arrays. A negative thread count is rejected
        // by ForkJoinPool, so report it as a usage error instead.
        if (procs < 0
            || n < 2 || Integer.bitCount(n) != 1
            || granularity < 2 || Integer.bitCount(granularity) != 1) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " granularity: " + granularity +
                               " runs: " + runs);

            float[][] a = new float[n][n];
            float[][] b = new float[n][n];
            float[][] c = new float[n][n];

            for (int i = 0; i < runs; ++i) {
                init(a, b, n);
                // Multiplier adds into C, so C must start at zero on every
                // run, not just the first.
                clear(c, n);
                long start = System.nanoTime();
                pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
                long time = System.nanoTime() - start;
                double secs = ((double)time) / NPS;
                Thread.sleep(100);
                System.out.printf("\tTime: %7.3f\n", secs);
                // check(c, n);   // optional O(n^2) verification, off by default
            }
            System.out.println(pool.toString());
        } finally {
            pool.shutdown();
        }
    }


    /**
     * Fills the leading {@code n x n} entries of {@code a} and {@code b}
     * with 1.0. This simplifies checking: every entry of the product should
     * then be {@code n}.
     */
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    /** Sets the leading {@code n x n} entries of {@code c} to zero. */
    static void clear(float[][] c, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                c[i][j] = 0.0F;
            }
        }
    }

    /**
     * Verifies that every entry of the leading {@code n x n} part of
     * {@code c} equals {@code n}, as expected after multiplying the all-ones
     * matrices produced by {@link #init}.
     *
     * @throws Error at the first mismatching entry
     */
    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++ ) {
            for (int j = 0; j < n; j++ ) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
                }
            }
        }
    }

    /**
     * Multiply matrices AxB by dividing into quadrants, using algorithm:
     * <pre>
     *      A      x      B
     *
     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
     * |----+----| x |----+----| = |--------+--------| + |---------+---------|
     *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22
     * </pre>
     * The two products that contribute to the same C quadrant are run
     * sequentially (via {@link Seq2}) so they never update C concurrently;
     * the four C quadrants are computed in parallel.
     */
    static class Multiplier extends RecursiveAction {
        final float[][] A;   // Matrix A
        final int aRow;      // first row    of current quadrant of A
        final int aCol;      // first column of current quadrant of A

        final float[][] B;   // Similarly for B
        final int bRow;
        final int bCol;

        final float[][] C;   // Similarly for result matrix C
        final int cRow;
        final int cCol;

        final int size;      // side length of the current (square) quadrant

        Multiplier(float[][] A, int aRow, int aCol,
                   float[][] B, int bRow, int bCol,
                   float[][] C, int cRow, int cCol,
                   int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        @Override
        public void compute() {

            if (size <= granularity) {
                multiplyStride2();
            }

            else {
                int h = size / 2;

                invokeAll(
                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol,    // B11
                                       C, cRow,   cCol,    // C11
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol,    // B21
                                       C, cRow,   cCol,    // C11
                                       h)),

                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow,   cCol+h,  // C12
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow,   cCol+h,  // C12
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol,    // B11
                                       C, cRow+h, cCol,    // C21
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol,    // B21
                                       C, cRow+h, cCol,    // C21
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow+h, cCol+h,  // C22
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow+h, cCol+h,  // C22
                                       h)));
            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns
         * at a time. Adapted from Cilk demos.
         * Note that the results are added into C, not just set into C,
         * so the target quadrant of C must start out zeroed; {@code main}
         * clears C before every run. Requires {@code size} to be even.
         */
        void multiplyStride2() {
            for (int j = 0; j < size; j+=2) {
                for (int i = 0; i < size; i +=2) {

                    float[] a0 = A[aRow+i];
                    float[] a1 = A[aRow+i+1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k+=2) {

                        float[] b0 = B[bRow+k];

                        s00 += a0[aCol+k]   * b0[bCol+j];
                        s10 += a1[aCol+k]   * b0[bCol+j];
                        s01 += a0[aCol+k]   * b0[bCol+j+1];
                        s11 += a1[aCol+k]   * b0[bCol+j+1];

                        float[] b1 = B[bRow+k+1];

                        s00 += a0[aCol+k+1] * b1[bCol+j];
                        s10 += a1[aCol+k+1] * b1[bCol+j];
                        s01 += a0[aCol+k+1] * b1[bCol+j+1];
                        s11 += a1[aCol+k+1] * b1[bCol+j+1];
                    }

                    C[cRow+i]  [cCol+j]   += s00;
                    C[cRow+i]  [cCol+j+1] += s01;
                    C[cRow+i+1][cCol+j]   += s10;
                    C[cRow+i+1][cCol+j+1] += s11;
                }
            }
        }

    }

    /** Returns a task that runs {@code task1} and then {@code task2}. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** A task that invokes two subtasks one after the other, in order. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }
        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }


}