/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */
6 |
|
7 |
import java.util.*; |
8 |
import java.util.concurrent.*; |
9 |
import java.util.concurrent.atomic.*; |
10 |
|
11 |
// parallel sums and cumulations |
12 |
|
13 |
/**
 * Micro-benchmark comparing a sequential sum of a {@code long[]} against
 * two fork/join computations built on {@link CountedCompleter}: a parallel
 * reduction ({@link Summer}) and an in-place parallel prefix scan
 * ({@link Cumulater}).
 *
 * <p>Usage: {@code java FJSums n reps}, where {@code n} is the (positive)
 * array length and {@code reps} the number of timed repetitions.
 */
public class FJSums {
    /** Lower bound applied by {@link #main} when it computes {@link #THRESHOLD}. */
    static final int MIN_PARTITION = 64;

    /**
     * Segment size that bounds task splitting: {@link Summer} keeps
     * splitting while a segment has at least this many elements, and
     * {@link Cumulater} while it has more than this many.
     *
     * <p>Defaults to {@link #MIN_PARTITION} so that the task classes are
     * usable even when {@link #main} has not run; {@code main} overwrites
     * it based on the array size and common-pool parallelism. The value
     * must be at least 2: with smaller values a one-element segment
     * would be split into a child covering the same range, so splitting
     * would never terminate.
     */
    static int THRESHOLD = MIN_PARTITION;

    /** A binary operation on two {@code long} values. */
    @FunctionalInterface
    interface LongByLongToLong { long apply(long a, long b); }

    /** {@link LongByLongToLong} implementation that adds its arguments. */
    static final class Add implements LongByLongToLong {
        @Override
        public long apply(long a, long b) { return a + b; }
    }

    /** Shared adder instance used by all tests. */
    static final Add ADD = new Add();

    /**
     * Runs the benchmark.
     *
     * @param args optional {@code n} (array length, must be positive;
     *        default {@code 1 << 25}) and {@code reps} (repetitions;
     *        default 10). Malformed or non-positive {@code n} prints a
     *        usage line and returns.
     * @throws Exception declared for compatibility with existing callers
     */
    public static void main(String[] args) throws Exception {
        int n = 1 << 25;
        int reps = 10;
        try {
            if (args.length > 0)
                n = Integer.parseInt(args[0]);
            if (args.length > 1)
                reps = Integer.parseInt(args[1]);
        }
        catch (NumberFormatException e) {
            printUsage();
            return;
        }
        // parTest reads a[n - 1], and new long[n] rejects negative sizes,
        // so an empty or negative length cannot be benchmarked.
        if (n <= 0) {
            printUsage();
            return;
        }
        int par = ForkJoinPool.getCommonPoolParallelism();
        System.out.println("Number of procs=" + par);
        // Aim for roughly eight segments per worker, but never go below
        // MIN_PARTITION elements per segment.
        int p;
        THRESHOLD = (p = n / (par << 3)) <= MIN_PARTITION ? MIN_PARTITION : p;

        long[] a = new long[n];
        for (int i = 0; i < n; ++i)
            a[i] = i;
        // Sum of 0..n-1, computed in long arithmetic.
        long expected = ((long)n * (long)(n - 1)) / 2;
        for (int i = 0; i < reps; ++i) {
            seqTest(a, i, expected);
            parTest(a, i, expected);
        }
        System.out.println(ForkJoinPool.commonPool());
    }

    /** Prints the command-line usage line. */
    private static void printUsage() {
        System.out.println("Usage: java FJSums n reps");
    }

    /**
     * Times a sequential sum of {@code a} and prints the result.
     *
     * @param a the array to sum
     * @param index repetition number; the sum is checked against
     *        {@code expected} only when this is 0, because
     *        {@link #parTest} overwrites the array contents afterwards
     * @param expected the sum expected on the first repetition
     * @throws AssertionError if the first-repetition sum is wrong
     */
    static void seqTest(long[] a, int index, long expected) {
        System.out.print("Seq ");
        long last = System.nanoTime();
        int n = a.length;
        long ss = seqSum(ADD, 0L, a, 0, n);
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new AssertionError("expected " + expected + " != " + ss);
    }

    /**
     * Times a parallel sum ({@link Summer}) followed by an in-place
     * parallel cumulation ({@link Cumulater}) of {@code a}, printing both
     * results. The array is left holding its running totals.
     *
     * @param a the non-empty array to sum and then cumulate in place
     * @param index repetition number; checks that depend on the initial
     *        array contents run only when this is 0
     * @param expected the sum expected on the first repetition
     * @throws AssertionError if a computed value fails verification
     */
    static void parTest(long[] a, int index, long expected) {
        System.out.print("Par ");
        long last = System.nanoTime();
        int n = a.length;
        Summer s = new Summer(null, ADD, 0L, a, 0, n, null);
        s.invoke();
        long ss = s.result;
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new AssertionError("expected " + expected + " != " + ss);
        System.out.print("Par ");
        last = System.nanoTime();
        new Cumulater(null, ADD, a, 0, n).invoke();
        // After cumulation the last element holds the total of the
        // previous contents, so it must equal the sum computed above.
        long sc = a[n - 1];
        elapsed = elapsedTime(last);
        // Report the value the scan actually produced (sc), not ss.
        System.out.printf("cum = %24d time: %7.3f\n", sc, elapsed);
        if (sc != ss)
            throw new AssertionError("expected " + ss + " != " + sc);
        if (index == 0) {
            // First repetition only: the array started as 0..n-1, so each
            // element must now equal the running total 0 + 1 + ... + j.
            long cs = 0L;
            for (int j = 0; j < n; ++j) {
                if ((cs += j) != a[j])
                    throw new AssertionError("wrong element value");
            }
        }
    }

    /**
     * Returns the time elapsed since {@code startTime}, in seconds.
     *
     * @param startTime an earlier {@link System#nanoTime()} reading
     */
    static double elapsedTime(long startTime) {
        return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000);
    }

    /**
     * Sequentially folds {@code a[l..h)} with {@code fn}, starting from
     * {@code basis}.
     *
     * @param fn the combining function
     * @param basis the initial accumulator value
     * @param a the source array
     * @param l inclusive lower index
     * @param h exclusive upper index
     * @return the folded value, or {@code basis} if the range is empty
     */
    static long seqSum(LongByLongToLong fn, long basis,
                       long[] a, int l, int h) {
        long sum = basis;
        for (int i = l; i < h; ++i)
            sum = fn.apply(sum, a[i]);
        return sum;
    }

    /**
     * Cumulative scan, adapted from ParallelArray code
     *
     * A basic version of scan is straightforward.
     *  Keep dividing by two to threshold segment size, and then:
     *   Pass 1: Create tree of partial sums for each segment
     *   Pass 2: For each segment, cumulate with offset of left sibling
     * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html
     *
     * This version improves performance within FJ framework mainly by
     * allowing the second pass of ready left-hand sides to proceed
     * even if some right-hand side first passes are still executing.
     * It also combines first and second pass for leftmost segment,
     * and skips the first pass for rightmost segment (whose result is
     * not needed for second pass).
     *
     * Managing this relies on ORing some bits in the pendingCount for
     * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
     * main phase bit. When false, segments compute only their sum.
     * When true, they cumulate array elements. CUMULATE is set at
     * root at beginning of second pass and then propagated down. But
     * it may also be set earlier for subtrees with lo==0 (the left
     * spine of tree). SUMMED is a one bit join count. For leafs, it
     * is set when summed. For internal nodes, it becomes true when
     * one child is summed.  When the second child finishes summing,
     * it then moves up the tree to trigger the cumulate phase.
     * FINISHED is also a one bit join count. For leafs, it is set
     * when cumulated. For internal nodes, it becomes true when one
     * child is cumulated.  When the second child finishes cumulating,
     * it then moves up tree, completing at the root.
     *
     * To better exploit locality and reduce overhead, the compute
     * method loops starting with the current task, moving if possible
     * to one of its subtasks rather than forking.
     */
    static final class Cumulater extends CountedCompleter<Void> {
        // Phase/state bits stored in the pending count (see class comment).
        static final int CUMULATE = 1;
        static final int SUMMED   = 2;
        static final int FINISHED = 4;

        final long[] array;
        final LongByLongToLong function;
        /** Child tasks; both null until this node is first split. */
        Cumulater left, right;
        /** Segment bounds: inclusive lo, exclusive hi. */
        final int lo, hi;
        /** in: offset carried in from the left; out: result for this segment. */
        long in, out;

        /**
         * @param parent the parent task, or null for the root
         * @param function the combining function
         * @param array the array to cumulate in place
         * @param lo inclusive lower bound of this task's segment
         * @param hi exclusive upper bound of this task's segment
         */
        Cumulater(Cumulater parent, LongByLongToLong function,
                  long[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = lo; this.hi = hi;
        }

        /**
         * Runs whichever phase is pending for this node and then keeps
         * working on a child or ancestor where possible instead of
         * returning.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final LongByLongToLong fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int l, h;
            Cumulater t = this;
            // The loop condition doubles as a bounds check on the segment.
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > THRESHOLD) {
                    // Internal node: f is the task to fork (if any), and
                    // t becomes the task to continue with in this thread.
                    Cumulater lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        int mid = (l + h) >>> 1;
                        f = rt = t.right = new Cumulater(t, fn, a, mid, h);
                        t = lt = t.left  = new Cumulater(t, fn, a, l, mid);
                    }
                    else {                           // possibly refork
                        // Push incoming offsets down to the children and
                        // claim each child whose CUMULATE bit we set.
                        long pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            rt.in = fn.apply(pin, lt.out);
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                // If the right child was claimed too, fork
                                // it and continue with the left one here.
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)               // neither child claimed
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        // CUMULATE set: cumulate now. Otherwise sum only,
                        // except for the leftmost leaf (l == 0), which
                        // sums and cumulates in a single pass.
                        state = (((b & CUMULATE) != 0)
                                 ? FINISHED
                                 : (l > 0) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    long sum = t.in;
                    if (state != SUMMED) {
                        for (int i = l; i < h; ++i)           // cumulate
                            a[i] = sum = fn.apply(sum, a[i]);
                    }
                    else if (h < a.length) {                  // skip rightmost
                        for (int i = l; i < h; ++i)           // sum only
                            sum = fn.apply(sum, a[i]);
                    }
                    t.out = sum;
                    for (Cumulater par;;) {                   // propagate
                        if ((par = (Cumulater)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; Cumulater lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null)
                                par.out = fn.apply(lt.out, rt.out);
                            // A left-spine parent (lo == 0) can start its
                            // cumulate phase as soon as it is summed.
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == 0) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
    }

    /**
     * Parallel reduction of {@code array[lo..hi)} with {@code function}.
     * Uses CC reduction via firstComplete/nextComplete: each task forks
     * right-hand halves, sums its remaining left part, and whichever
     * task completes a node folds the forked subtasks' results into it.
     * The final value is left in {@link #result} of the root task.
     */
    static final class Summer extends CountedCompleter<Void> {
        final long[] array;
        final LongByLongToLong function;
        /** Segment bounds: inclusive lo, exclusive hi. */
        final int lo, hi;
        /** Initial accumulator value for every leaf fold. */
        final long basis;
        /** Result for this task's segment, valid once the task completes. */
        long result;
        Summer forks, next; // keeps track of right-hand-side tasks

        /**
         * @param parent the parent task, or null for the root
         * @param function the combining function
         * @param basis the initial accumulator value
         * @param array the array to reduce
         * @param lo inclusive lower bound of this task's segment
         * @param hi exclusive upper bound of this task's segment
         * @param next the previously forked sibling in the parent's
         *        list of forked tasks, or null
         */
        Summer(Summer parent, LongByLongToLong function, long basis,
               long[] array, int lo, int hi, Summer next) {
            super(parent);
            this.function = function; this.basis = basis;
            this.array = array; this.lo = lo; this.hi = hi;
            this.next = next;
        }

        /**
         * Forks right halves until the remaining segment is below
         * {@link FJSums#THRESHOLD}, sums that segment, then merges the
         * results of every task this completion finishes.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final long id = basis;
            final LongByLongToLong fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();
            int l = lo, h = hi;
            while (h - l >= THRESHOLD) {
                int mid = (l + h) >>> 1;
                addToPendingCount(1);
                // Push the new right-hand task onto this task's list.
                (forks = new Summer(this, fn, id, a, mid, h, forks)).fork();
                h = mid;
            }
            long sum = id;
            if (l < h && l >= 0 && h <= a.length) {
                for (int i = l; i < h; ++i)
                    sum = fn.apply(sum, a[i]);
            }
            result = sum;
            // Walk up through every task completed by this one, folding
            // each task's forked subtasks' results into its own.
            CountedCompleter<?> c;
            for (c = firstComplete(); c != null; c = c.nextComplete()) {
                Summer t = (Summer)c, s = t.forks;
                while (s != null) {
                    t.result = fn.apply(t.result, s.result);
                    s = t.forks = s.next;
                }
            }
        }
    }

}