ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/FJSums.java
Revision: 1.10
Committed: Sat Sep 12 18:59:49 2015 UTC (8 years, 7 months ago) by jsr166
Branch: MAIN
CVS Tags: HEAD
Changes since 1.9: +4 -3 lines
Log Message:
whitespace

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 import java.util.*;
8 import java.util.concurrent.*;
9 import java.util.concurrent.atomic.*;
10
11 // parallel sums and cumulations
12
13 public class FJSums {
14 static int THRESHOLD;
15 static final int MIN_PARTITION = 64;
16
17 interface LongByLongToLong { long apply(long a, long b); }
18
19 static final class Add implements LongByLongToLong {
20 public long apply(long a, long b) { return a + b; }
21 }
22
23 static final Add ADD = new Add();
24
25 public static void main(String[] args) throws Exception {
26 int n = 1 << 25;
27 int reps = 10;
28 try {
29 if (args.length > 0)
30 n = Integer.parseInt(args[0]);
31 if (args.length > 1)
32 reps = Integer.parseInt(args[1]);
33 }
34 catch (Exception e) {
35 System.out.println("Usage: java FJSums n reps");
36 return;
37 }
38 int par = ForkJoinPool.getCommonPoolParallelism();
39 System.out.println("Number of procs=" + par);
40 int p;
41 THRESHOLD = (p = n / (par << 3)) <= MIN_PARTITION ? MIN_PARTITION : p;
42
43 long[] a = new long[n];
44 for (int i = 0; i < n; ++i)
45 a[i] = i;
46 long expected = ((long)n * (long)(n - 1)) / 2;
47 for (int i = 0; i < reps; ++i) {
48 seqTest(a, i, expected);
49 parTest(a, i, expected);
50 }
51 System.out.println(ForkJoinPool.commonPool());
52 }
53
54 static void seqTest(long[] a, int index, long expected) {
55 System.out.print("Seq ");
56 long last = System.nanoTime();
57 int n = a.length;
58 long ss = seqSum(ADD, 0L, a, 0, n);
59 double elapsed = elapsedTime(last);
60 System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
61 if (index == 0 && ss != expected)
62 throw new Error("expected " + expected + " != " + ss);
63 }
64
65 static void parTest(long[] a, int index, long expected) {
66 System.out.print("Par ");
67 long last = System.nanoTime();
68 int n = a.length;
69 Summer s = new Summer(null, ADD, 0L, a, 0, n, null);
70 s.invoke();
71 long ss = s.result;
72 double elapsed = elapsedTime(last);
73 System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
74 if (index == 0 && ss != expected)
75 throw new Error("expected " + expected + " != " + ss);
76 System.out.print("Par ");
77 last = System.nanoTime();
78 new Cumulater(null, ADD, a, 0, n).invoke();
79 long sc = a[n - 1];
80 elapsed = elapsedTime(last);
81 System.out.printf("cum = %24d time: %7.3f\n", ss, elapsed);
82 if (sc != ss)
83 throw new Error("expected " + ss + " != " + sc);
84 if (index == 0) {
85 long cs = 0L;
86 for (int j = 0; j < n; ++j) {
87 if ((cs += j) != a[j])
88 throw new Error("wrong element value");
89 }
90 }
91 }
92
93 static double elapsedTime(long startTime) {
94 return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000);
95 }
96
97 static long seqSum(LongByLongToLong fn, long basis,
98 long[] a, int l, int h) {
99 long sum = basis;
100 for (int i = l; i < h; ++i)
101 sum = fn.apply(sum, a[i]);
102 return sum;
103 }
104
105 /**
106 * Cumulative scan, adapted from ParallelArray code
107 *
108 * A basic version of scan is straightforward.
109 * Keep dividing by two to threshold segment size, and then:
110 * Pass 1: Create tree of partial sums for each segment
111 * Pass 2: For each segment, cumulate with offset of left sibling
112 * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html
113 *
114 * This version improves performance within FJ framework mainly by
115 * allowing the second pass of ready left-hand sides to proceed
116 * even if some right-hand side first passes are still executing.
117 * It also combines first and second pass for leftmost segment,
118 * and skips the first pass for rightmost segment (whose result is
119 * not needed for second pass).
120 *
121 * Managing this relies on ORing some bits in the pendingCount for
122 * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
123 * main phase bit. When false, segments compute only their sum.
124 * When true, they cumulate array elements. CUMULATE is set at
125 * root at beginning of second pass and then propagated down. But
126 * it may also be set earlier for subtrees with lo==0 (the left
127 * spine of tree). SUMMED is a one bit join count. For leafs, it
128 * is set when summed. For internal nodes, it becomes true when
129 * one child is summed. When the second child finishes summing,
130 * we then moves up tree to trigger the cumulate phase. FINISHED
131 * is also a one bit join count. For leafs, it is set when
132 * cumulated. For internal nodes, it becomes true when one child
133 * is cumulated. When the second child finishes cumulating, it
134 * then moves up tree, completing at the root.
135 *
136 * To better exploit locality and reduce overhead, the compute
137 * method loops starting with the current task, moving if possible
138 * to one of its subtasks rather than forking.
139 */
140 static final class Cumulater extends CountedCompleter<Void> {
141 static final int CUMULATE = 1;
142 static final int SUMMED = 2;
143 static final int FINISHED = 4;
144
145 final long[] array;
146 final LongByLongToLong function;
147 Cumulater left, right;
148 final int lo, hi;
149 long in, out;
150
151 Cumulater(Cumulater parent, LongByLongToLong function,
152 long[] array, int lo, int hi) {
153 super(parent);
154 this.function = function; this.array = array;
155 this.lo = lo; this.hi = hi;
156 }
157
158 public final void compute() {
159 final LongByLongToLong fn;
160 final long[] a;
161 if ((fn = this.function) == null || (a = this.array) == null)
162 throw new NullPointerException(); // hoist checks
163 int l, h;
164 Cumulater t = this;
165 outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
166 if (h - l > THRESHOLD) {
167 Cumulater lt = t.left, rt = t.right, f;
168 if (lt == null) { // first pass
169 int mid = (l + h) >>> 1;
170 f = rt = t.right = new Cumulater(t, fn, a, mid, h);
171 t = lt = t.left = new Cumulater(t, fn, a, l, mid);
172 }
173 else { // possibly refork
174 long pin = t.in;
175 lt.in = pin;
176 f = t = null;
177 if (rt != null) {
178 rt.in = fn.apply(pin, lt.out);
179 for (int c;;) {
180 if (((c = rt.getPendingCount()) & CUMULATE) != 0)
181 break;
182 if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
183 t = rt;
184 break;
185 }
186 }
187 }
188 for (int c;;) {
189 if (((c = lt.getPendingCount()) & CUMULATE) != 0)
190 break;
191 if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
192 if (t != null)
193 f = t;
194 t = lt;
195 break;
196 }
197 }
198 if (t == null)
199 break;
200 }
201 if (f != null)
202 f.fork();
203 }
204 else {
205 int state; // Transition to sum, cumulate, or both
206 for (int b;;) {
207 if (((b = t.getPendingCount()) & FINISHED) != 0)
208 break outer; // already done
209 state = (((b & CUMULATE) != 0)
210 ? FINISHED
211 : (l > 0) ? SUMMED : (SUMMED|FINISHED));
212 if (t.compareAndSetPendingCount(b, b|state))
213 break;
214 }
215
216 long sum = t.in;
217 if (state != SUMMED) {
218 for (int i = l; i < h; ++i) // cumulate
219 a[i] = sum = fn.apply(sum, a[i]);
220 }
221 else if (h < a.length) { // skip rightmost
222 for (int i = l; i < h; ++i) // sum only
223 sum = fn.apply(sum, a[i]);
224 }
225 t.out = sum;
226 for (Cumulater par;;) { // propagate
227 if ((par = (Cumulater)t.getCompleter()) == null) {
228 if ((state & FINISHED) != 0) // enable join
229 t.quietlyComplete();
230 break outer;
231 }
232 int b = par.getPendingCount();
233 if ((b & state & FINISHED) != 0)
234 t = par; // both done
235 else if ((b & state & SUMMED) != 0) { // both summed
236 int nextState; Cumulater lt, rt;
237 if ((lt = par.left) != null &&
238 (rt = par.right) != null)
239 par.out = fn.apply(lt.out, rt.out);
240 int refork = (((b & CUMULATE) == 0 &&
241 par.lo == 0) ? CUMULATE : 0);
242 if ((nextState = b|state|refork) == b ||
243 par.compareAndSetPendingCount(b, nextState)) {
244 state = SUMMED; // drop finished
245 t = par;
246 if (refork != 0)
247 par.fork();
248 }
249 }
250 else if (par.compareAndSetPendingCount(b, b|state))
251 break outer; // sib not ready
252 }
253 }
254 }
255 }
256 }
257
258 // Uses CC reduction via firstComplete/nextComplete
259 static final class Summer extends CountedCompleter<Void> {
260 final long[] array;
261 final LongByLongToLong function;
262 final int lo, hi;
263 final long basis;
264 long result;
265 Summer forks, next; // keeps track of right-hand-side tasks
266 Summer(Summer parent, LongByLongToLong function, long basis,
267 long[] array, int lo, int hi, Summer next) {
268 super(parent);
269 this.function = function; this.basis = basis;
270 this.array = array; this.lo = lo; this.hi = hi;
271 this.next = next;
272 }
273
274 public final void compute() {
275 final long id = basis;
276 final LongByLongToLong fn;
277 final long[] a;
278 if ((fn = this.function) == null || (a = this.array) == null)
279 throw new NullPointerException();
280 int l = lo, h = hi;
281 while (h - l >= THRESHOLD) {
282 int mid = (l + h) >>> 1;
283 addToPendingCount(1);
284 (forks = new Summer(this, fn, id, a, mid, h, forks)).fork();
285 h = mid;
286 }
287 long sum = id;
288 if (l < h && l >= 0 && h <= a.length) {
289 for (int i = l; i < h; ++i)
290 sum = fn.apply(sum, a[i]);
291 }
292 result = sum;
293 CountedCompleter<?> c;
294 for (c = firstComplete(); c != null; c = c.nextComplete()) {
295 Summer t = (Summer)c, s = t.forks;
296 while (s != null) {
297 t.result = fn.apply(t.result, s.result);
298 s = t.forks = s.next;
299 }
300 }
301 }
302 }
303
304 }