ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/MatrixMultiply.java
Revision: 1.8
Committed: Thu Jan 15 18:34:19 2015 UTC (9 years, 4 months ago) by jsr166
Branch: MAIN
Changes since 1.7: +0 -4 lines
Log Message:
delete extraneous blank lines

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 //import jsr166y.*;
8 import java.util.concurrent.*;
9
10 /**
11 * Divide and Conquer matrix multiply demo
12 */
public class MatrixMultiply {

    /** for time conversion */
    static final long NPS = (1000L * 1000 * 1000);

    static final int DEFAULT_GRANULARITY = 32;

    /**
     * The quadrant size at which to stop recursing down
     * and instead directly multiply the matrices.
     * Must be a power of two. Minimum value is 2.
     */
    static int granularity = DEFAULT_GRANULARITY;

    /**
     * Runs the benchmark.
     *
     * <p>Optional positional arguments: number of worker threads
     * (0 selects the no-argument ForkJoinPool constructor), matrix size,
     * granularity, and number of timed runs. If any argument is not a
     * valid integer, or the size/granularity are not powers of two that
     * are at least 2, the usage message is printed and nothing is run.
     */
    public static void main(String[] args) throws Exception {

        final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] [<runs>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

        int procs = 0;
        int n = 2048;
        int runs = 5;
        try {
            if (args.length > 0)
                procs = Integer.parseInt(args[0]);
            if (args.length > 1)
                n = Integer.parseInt(args[1]);
            if (args.length > 2)
                granularity = Integer.parseInt(args[2]);
            if (args.length > 3)
                runs = Integer.parseInt(args[3]); // 4th argument, not the granularity
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // multiplyStride2 works on 2x2 blocks, so both the matrix size and the
        // granularity must be powers of two that are at least 2. A negative
        // thread count would make the ForkJoinPool constructor throw.
        if (procs < 0 ||
            n < 2 || ((n & (n - 1)) != 0) ||
            granularity < 2 || ((granularity & (granularity - 1)) != 0)) {
            System.out.println(usage);
            return;
        }

        ForkJoinPool pool = (procs == 0) ? new ForkJoinPool() :
            new ForkJoinPool(procs);
        System.out.println("procs: " + pool.getParallelism() +
                           " n: " + n + " granularity: " + granularity +
                           " runs: " + runs);

        float[][] a = new float[n][n];
        float[][] b = new float[n][n];
        float[][] c = new float[n][n];

        for (int i = 0; i < runs; ++i) {
            init(a, b, n);
            // Multiplier adds into C, so the previous run's result must be
            // cleared or it would be accumulated into this run's result.
            clear(c, n);
            long start = System.nanoTime();
            pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
            long time = System.nanoTime() - start;
            double secs = ((double)time) / NPS;
            Thread.sleep(100);
            System.out.printf("\tTime: %7.3f\n", secs);
            // check(c, n);
        }
        System.out.println(pool.toString());
        pool.shutdown();
    }

    // To simplify checking, fill with all 1's. Answer should be all n's.
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    /**
     * Resets the leading n x n region of c to zero, so that the
     * accumulating Multiplier starts from an empty result.
     */
    static void clear(float[][] c, int n) {
        for (int i = 0; i < n; ++i) {
            java.util.Arrays.fill(c[i], 0, n, 0.0F);
        }
    }

    /**
     * Verifies that every element of the n x n result equals n, which is
     * the expected product of two all-ones matrices (see {@code init}).
     *
     * @throws Error naming the first mismatching element
     */
    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++ ) {
            for (int j = 0; j < n; j++ ) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
                }
            }
        }
    }

    /**
     * Multiply matrices AxB by dividing into quadrants, using algorithm:
     * <pre>
     *      A      x      B
     *
     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
     * |----+----| x |----+----| = |--------+--------| + |---------+---------|
     *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22
     * </pre>
     * The two products that target the same quadrant of C are run one
     * after the other (see {@link Seq2}) because both add into that quadrant;
     * the four quadrants of C are computed in parallel.
     */
    static class Multiplier extends RecursiveAction {
        final float[][] A;   // Matrix A
        final int aRow;      // first row    of current quadrant of A
        final int aCol;      // first column of current quadrant of A

        final float[][] B;   // Similarly for B
        final int bRow;
        final int bCol;

        final float[][] C;   // Similarly for result matrix C
        final int cRow;
        final int cCol;

        // side length (number of rows == number of columns) of the current
        // quadrant; expected to be an even number (a power of two >= 2)
        final int size;

        Multiplier(float[][] A, int aRow, int aCol,
                   float[][] B, int bRow, int bCol,
                   float[][] C, int cRow, int cCol,
                   int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        /**
         * Multiplies directly once the quadrant is no larger than
         * {@code granularity}; otherwise splits into the eight half-size
         * products shown in the class comment.
         */
        @Override
        public void compute() {

            if (size <= granularity) {
                multiplyStride2();
            }

            else {
                int h = size / 2;

                invokeAll(new Seq2[] {
                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol,    // B11
                                       C, cRow,   cCol,    // C11
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol,    // B21
                                       C, cRow,   cCol,    // C11
                                       h)),

                    seq(new Multiplier(A, aRow,   aCol,    // A11
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow,   cCol+h,  // C12
                                       h),
                        new Multiplier(A, aRow,   aCol+h,  // A12
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow,   cCol+h,  // C12
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol,    // B11
                                       C, cRow+h, cCol,    // C21
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol,    // B21
                                       C, cRow+h, cCol,    // C21
                                       h)),

                    seq(new Multiplier(A, aRow+h, aCol,    // A21
                                       B, bRow,   bCol+h,  // B12
                                       C, cRow+h, cCol+h,  // C22
                                       h),
                        new Multiplier(A, aRow+h, aCol+h,  // A22
                                       B, bRow+h, bCol+h,  // B22
                                       C, cRow+h, cCol+h,  // C22
                                       h))
                });
            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns
         * at a time. Adapted from Cilk demos.
         * Note that the results are added into C, not just set into C,
         * so C must be zeroed before the top-level multiply starts
         * ({@code main} does this before every run).
         * Requires {@code size} to be even.
         */
        void multiplyStride2() {
            for (int j = 0; j < size; j+=2) {
                for (int i = 0; i < size; i +=2) {

                    float[] a0 = A[aRow+i];
                    float[] a1 = A[aRow+i+1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k+=2) {

                        float[] b0 = B[bRow+k];

                        s00 += a0[aCol+k]   * b0[bCol+j];
                        s10 += a1[aCol+k]   * b0[bCol+j];
                        s01 += a0[aCol+k]   * b0[bCol+j+1];
                        s11 += a1[aCol+k]   * b0[bCol+j+1];

                        float[] b1 = B[bRow+k+1];

                        s00 += a0[aCol+k+1] * b1[bCol+j];
                        s10 += a1[aCol+k+1] * b1[bCol+j];
                        s01 += a0[aCol+k+1] * b1[bCol+j+1];
                        s11 += a1[aCol+k+1] * b1[bCol+j+1];
                    }

                    C[cRow+i]  [cCol+j]   += s00;
                    C[cRow+i]  [cCol+j+1] += s01;
                    C[cRow+i+1][cCol+j]   += s10;
                    C[cRow+i+1][cCol+j+1] += s11;
                }
            }
        }

    }

    /** Convenience factory for a {@link Seq2} running task1 then task2. */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /**
     * A task that runs {@code fst} to completion and only then runs
     * {@code snd}. Used to serialize the two products that add into the
     * same quadrant of C.
     */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;
        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }
        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }
}