ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/FJSums.java
Revision: 1.9
Committed: Sat Sep 12 18:40:09 2015 UTC (8 years, 8 months ago) by dl
Branch: MAIN
Changes since 1.8: +227 -201 lines
Log Message:
Use CC prefix algorithms

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4 jsr166 1.4 * http://creativecommons.org/publicdomain/zero/1.0/
5 dl 1.1 */
6    
7     import java.util.*;
8     import java.util.concurrent.*;
9     import java.util.concurrent.atomic.*;
10    
11     // parallel sums and cumulations
12    
/**
 * Benchmark and self-check for parallel sums and in-place cumulations
 * (prefix sums) of a {@code long[]} using {@link CountedCompleter} tasks
 * running in the {@link ForkJoinPool#commonPool() common pool}.
 *
 * <p>Usage: {@code java FJSums [n [reps]]}, where {@code n} (default
 * {@code 1 << 25}, must be positive) is the array length and {@code reps}
 * (default 10) is the number of timed repetitions.
 */
public class FJSums {
    /** Lower bound for {@link #THRESHOLD}. */
    static final int MIN_PARTITION = 64;

    /**
     * Segment length at or below which tasks stop splitting and process
     * elements sequentially. Recomputed by {@link #main} from the array
     * length and common-pool parallelism. It defaults to MIN_PARTITION so
     * the tasks still split sensibly when used without going through main.
     * Must be at least 1; otherwise a one-element range is re-split into
     * itself forever.
     */
    static int THRESHOLD = MIN_PARTITION;

    /** Binary operator on longs used as the reduction/cumulation function. */
    interface LongByLongToLong { long apply(long a, long b); }

    /** Addition; its identity 0 is what the tasks below rely on. */
    static final class Add implements LongByLongToLong {
        @Override
        public long apply(long a, long b) { return a + b; }
    }

    static final Add ADD = new Add();

    /**
     * Parses {@code [n [reps]]}, fills an array with {@code 0 .. n-1}, and
     * runs {@code reps} rounds of sequential and parallel tests on it.
     * Prints a usage line and returns if the arguments are not integers
     * or if {@code n} is not positive.
     */
    public static void main(String[] args) throws Exception {
        int n = 1 << 25;
        int reps = 10;
        try {
            if (args.length > 0)
                n = Integer.parseInt(args[0]);
            if (args.length > 1)
                reps = Integer.parseInt(args[1]);
        }
        catch (NumberFormatException e) {
            System.out.println("Usage: java FJSums n reps");
            return;
        }
        // parTest reads a[n - 1], and a negative length cannot be allocated,
        // so reject these up front instead of failing with an obscure error.
        if (n <= 0) {
            System.out.println("Usage: java FJSums n reps");
            return;
        }
        int par = ForkJoinPool.getCommonPoolParallelism();
        System.out.println("Number of procs=" + par);
        // Aim for about 8 leaf segments per worker, but never below MIN_PARTITION.
        THRESHOLD = Math.max(MIN_PARTITION, n / (par << 3));

        long[] a = new long[n];
        for (int i = 0; i < n; ++i)
            a[i] = i;
        long expected = ((long)n * (long)(n - 1)) / 2;
        for (int i = 0; i < reps; ++i) {
            seqTest(a, i, expected);
            parTest(a, i, expected);
        }
        System.out.println(ForkJoinPool.commonPool());
    }

    /**
     * Times a sequential sum of {@code a}. The result is checked against
     * {@code expected} only on the first repetition ({@code index == 0}),
     * because parTest overwrites {@code a} with its cumulation.
     */
    static void seqTest(long[] a, int index, long expected) {
        System.out.print("Seq ");
        long last = System.nanoTime();
        int n = a.length;
        long ss = seqSum(ADD, 0L, a, 0, n);
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new Error("expected " + expected + " != " + ss);
    }

    /**
     * Times a parallel sum, then a parallel in-place cumulation of
     * {@code a}, which must be non-empty. The cumulation's last element
     * must equal the sum on every repetition. The sum against
     * {@code expected}, and every cumulated element, are checked only when
     * {@code index == 0}, because later repetitions see already-cumulated
     * data.
     */
    static void parTest(long[] a, int index, long expected) {
        System.out.print("Par ");
        long last = System.nanoTime();
        int n = a.length;
        Summer s = new Summer(null, ADD, 0L, a, 0, n, null);
        s.invoke();
        long ss = s.result;
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new Error("expected " + expected + " != " + ss);
        System.out.print("Par ");
        last = System.nanoTime();
        new Cumulater(null, ADD, a, 0, n).invoke();
        long sc = a[n - 1];
        elapsed = elapsedTime(last);
        // Report the cumulated value itself (previously printed ss by mistake).
        System.out.printf("cum = %24d time: %7.3f\n", sc, elapsed);
        if (sc != ss)
            throw new Error("expected " + ss + " != " + sc);
        if (index == 0) {
            long cs = 0L;
            for (int j = 0; j < n; ++j) {
                if ((cs += j) != a[j])
                    throw new Error("wrong element value");
            }
        }
    }

    /** Returns the seconds elapsed since {@code startTime} (a nanoTime value). */
    static double elapsedTime(long startTime) {
        return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000);
    }

    /**
     * Sequentially folds {@code a[l, h)} left to right with {@code fn},
     * starting from {@code basis}.
     */
    static long seqSum(LongByLongToLong fn, long basis,
                       long[] a, int l, int h) {
        long sum = basis;
        for (int i = l; i < h; ++i)
            sum = fn.apply(sum, a[i]);
        return sum;
    }

    /**
     * Cumulative scan (in-place prefix fold under {@code function}),
     * adapted from ParallelArray code.
     *
     * <p>A basic version of scan is straightforward. Keep dividing by two
     * down to the threshold segment size, and then:
     * <ol>
     * <li>Pass 1: create a tree of partial sums for each segment.
     * <li>Pass 2: for each segment, cumulate with the offset of its left
     * sibling.
     * </ol>
     * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html
     *
     * <p>This version improves performance within the FJ framework mainly
     * by allowing the second pass of ready left-hand sides to proceed even
     * if some right-hand-side first passes are still executing. It also
     * combines the first and second passes for the leftmost segment, and
     * skips the first pass for the rightmost segment, whose result is not
     * needed for the second pass.
     *
     * <p>Managing this relies on ORing phase/state bits into the pending
     * count: CUMULATE, SUMMED, and FINISHED.
     * <ul>
     * <li>CUMULATE is the main phase bit. When clear, segments compute only
     * their sum; when set, they cumulate array elements. CUMULATE is set at
     * the root at the beginning of the second pass and then propagated
     * down. It may also be set earlier for subtrees with {@code lo == 0}
     * (the left spine of the tree).
     * <li>SUMMED is a one-bit join count. For leaves, it is set when
     * summed. For internal nodes, it becomes true when one child is summed.
     * When the second child finishes summing, control moves up the tree to
     * trigger the cumulate phase.
     * <li>FINISHED is also a one-bit join count. For leaves, it is set when
     * cumulated. For internal nodes, it becomes true when one child is
     * cumulated. When the second child finishes cumulating, control moves
     * up the tree, completing at the root.
     * </ul>
     *
     * <p>To better exploit locality and reduce overhead, the compute method
     * loops starting with the current task, moving if possible to one of
     * its subtasks rather than forking.
     *
     * <p>Assumptions:
     * <ul>
     * <li>Every running value starts from {@code in}, which for the root
     * (and hence the left spine) is never assigned and stays 0. Results are
     * therefore correct only when 0 is an identity of {@code function}, as
     * for {@link #ADD}.
     * <li>The leftmost-segment shortcuts test {@code lo == 0}, so the root
     * task is expected to start at index 0.
     * </ul>
     */
    static final class Cumulater extends CountedCompleter<Void> {
        static final int CUMULATE = 1;
        static final int SUMMED   = 2;
        static final int FINISHED = 4;

        final long[] array;
        final LongByLongToLong function;
        Cumulater left, right;  // children, created on first split
        final int lo, hi;       // segment bounds [lo, hi)
        long in, out;           // incoming offset and segment total

        Cumulater(Cumulater parent, LongByLongToLong function,
                  long[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = lo; this.hi = hi;
        }

        @Override
        public final void compute() {
            final LongByLongToLong fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException(); // hoist checks
            int l, h;
            Cumulater t = this;
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > THRESHOLD) {
                    Cumulater lt = t.left, rt = t.right, f;
                    if (lt == null) { // first pass
                        int mid = (l + h) >>> 1;
                        f = rt = t.right = new Cumulater(t, fn, a, mid, h);
                        t = lt = t.left = new Cumulater(t, fn, a, l, mid);
                    }
                    else { // possibly refork
                        long pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            rt.in = fn.apply(pin, lt.out);
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer; // already done
                        state = ((b & CUMULATE) != 0 ? FINISHED :
                                 (l > 0) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    long sum = t.in;
                    if (state != SUMMED) {
                        for (int i = l; i < h; ++i) // cumulate
                            a[i] = sum = fn.apply(sum, a[i]);
                    }
                    else if (h < a.length) { // skip rightmost
                        for (int i = l; i < h; ++i) // sum only
                            sum = fn.apply(sum, a[i]);
                    }
                    t.out = sum;
                    for (Cumulater par;;) { // propagate
                        if ((par = (Cumulater)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0) // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par; // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; Cumulater lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null)
                                par.out = fn.apply(lt.out, rt.out);
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == 0) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED; // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer; // sib not ready
                    }
                }
            }
        }
    }

    /**
     * Parallel reduction of {@code array[lo, hi)} under {@code function}.
     * Each leaf starts from {@code basis}, so {@code basis} should be an
     * identity of {@code function}. Forked right-hand subtasks are chained
     * through {@code forks}/{@code next}, and their results are merged in
     * completion order using firstComplete/nextComplete. The final value is
     * left in {@code result}.
     */
    static final class Summer extends CountedCompleter<Void> {
        final long[] array;
        final LongByLongToLong function;
        final int lo, hi;
        final long basis;
        long result;
        Summer forks, next; // keeps track of right-hand-side tasks

        Summer(Summer parent, LongByLongToLong function, long basis,
               long[] array, int lo, int hi, Summer next) {
            super(parent);
            this.function = function; this.basis = basis;
            this.array = array; this.lo = lo; this.hi = hi;
            this.next = next;
        }

        @Override
        public final void compute() {
            final long id = basis;
            final LongByLongToLong fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();
            int l = lo, h = hi;
            // Split with '>' like Cumulater. With '>=' and THRESHOLD == 1,
            // a one-element range would fork itself forever.
            while (h - l > THRESHOLD) {
                int mid = (l + h) >>> 1;
                addToPendingCount(1);
                (forks = new Summer(this, fn, id, a, mid, h, forks)).fork();
                h = mid;
            }
            long sum = id;
            if (l < h && l >= 0 && h <= a.length) {
                for (int i = l; i < h; ++i)
                    sum = fn.apply(sum, a[i]);
            }
            result = sum;
            // Fold each completed task's forked children into its result,
            // walking up the tree as pending counts reach zero.
            CountedCompleter<?> c;
            for (c = firstComplete(); c != null; c = c.nextComplete()) {
                Summer t = (Summer)c, s = t.forks;
                while (s != null) {
                    t.result = fn.apply(t.result, s.result);
                    s = t.forks = s.next;
                }
            }
        }
    }

}