ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/FJSums.java
(Generate patch)

Comparing jsr166/src/test/loops/FJSums.java (file contents):
Revision 1.8 by jsr166, Sun Oct 21 06:14:12 2012 UTC vs.
Revision 1.9 by dl, Sat Sep 12 18:40:09 2015 UTC

# Line 11 | Line 11 | import java.util.concurrent.atomic.*;
11   // parallel sums and cumulations
12  
13   public class FJSums {
14    static final long NPS = (1000L * 1000 * 1000);
14      static int THRESHOLD;
15 +    static final int MIN_PARTITION = 64;
16 +
17 +    interface LongByLongToLong { long apply(long a, long b); }
18 +
19 +    static final class Add implements LongByLongToLong {
20 +        public long apply(long a, long b) { return a + b; }
21 +    }
22 +
23 +    static final Add ADD = new Add();
24  
25      public static void main(String[] args) throws Exception {
18        int procs = 0;
26          int n = 1 << 25;
27          int reps = 10;
28          try {
29              if (args.length > 0)
30 <                procs = Integer.parseInt(args[0]);
30 >                n = Integer.parseInt(args[0]);
31              if (args.length > 1)
32 <                n = Integer.parseInt(args[1]);
26 <            if (args.length > 2)
27 <                reps = Integer.parseInt(args[2]);
32 >                reps = Integer.parseInt(args[1]);
33          }
34          catch (Exception e) {
35 <            System.out.println("Usage: java FJSums threads n reps");
35 >            System.out.println("Usage: java FJSums n reps");
36              return;
37          }
38 <        ForkJoinPool g = (procs == 0) ? new ForkJoinPool() :
39 <            new ForkJoinPool(procs);
40 <        System.out.println("Number of procs=" + g.getParallelism());
41 <        // for now hardwire Cumulate threshold to 8 * #CPUs leaf tasks
37 <        THRESHOLD = 1 + ((n + 7) >>> 3) / g.getParallelism();
38 >        int par = ForkJoinPool.getCommonPoolParallelism();
39 >        System.out.println("Number of procs=" + par);
40 >        int p;
41 >        THRESHOLD = (p = n / (par << 3)) <= MIN_PARTITION ? MIN_PARTITION : p;
42  
43          long[] a = new long[n];
44          for (int i = 0; i < n; ++i)
45              a[i] = i;
46          long expected = ((long)n * (long)(n - 1)) / 2;
43        for (int i = 0; i < 2; ++i) {
44            System.out.print("Seq: ");
45            long last = System.nanoTime();
46            long ss = seqSum(a, 0, n);
47            double elapsed = elapsedTime(last);
48            System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
49            if (ss != expected)
50                throw new Error("expected " + expected + " != " + ss);
51        }
47          for (int i = 0; i < reps; ++i) {
48 <            System.out.print("Par: ");
49 <            long last = System.nanoTime();
55 <            Summer s = new Summer(a, 0, a.length, null);
56 <            g.invoke(s);
57 <            long ss = s.result;
58 <            double elapsed = elapsedTime(last);
59 <            System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
60 <            if (i == 0 && ss != expected)
61 <                throw new Error("expected " + expected + " != " + ss);
62 <            System.out.print("Cum: ");
63 <            last = System.nanoTime();
64 <            g.invoke(new Cumulater(null, a, 0, n));
65 <            long sc = a[n - 1];
66 <            elapsed = elapsedTime(last);
67 <            System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
68 <            if (sc != ss)
69 <                throw new Error("expected " + ss + " != " + sc);
48 >            seqTest(a, i, expected);
49 >            parTest(a, i, expected);
50          }
51 <        System.out.println(g);
72 <        g.shutdown();
51 >        System.out.println(ForkJoinPool.commonPool());
52      }
53  
54 <    static double elapsedTime(long startTime) {
55 <        return (double)(System.nanoTime() - startTime) / NPS;
54 >    static void seqTest(long[] a, int index, long expected) {
55 >        System.out.print("Seq ");
56 >        long last = System.nanoTime();
57 >        int n = a.length;
58 >        long ss = seqSum(ADD, 0L, a, 0, n);
59 >        double elapsed = elapsedTime(last);
60 >        System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
61 >        if (index == 0 && ss != expected)
62 >            throw new Error("expected " + expected + " != " + ss);
63      }
64  
65 <    static long seqSum(long[] array, int l, int h) {
66 <        long sum = 0;
67 <        for (int i = l; i < h; ++i)
68 <            sum += array[i];
69 <        return sum;
65 >    static void parTest(long[] a, int index, long expected) {
66 >        System.out.print("Par ");
67 >        long last = System.nanoTime();
68 >        int n = a.length;
69 >        Summer s = new Summer(null, ADD, 0L, a, 0, n, null);
70 >        s.invoke();
71 >        long ss = s.result;
72 >        double elapsed = elapsedTime(last);
73 >        System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
74 >        if (index == 0 && ss != expected)
75 >            throw new Error("expected " + expected + " != " + ss);
76 >        System.out.print("Par ");
77 >        last = System.nanoTime();
78 >        new Cumulater(null, ADD, a, 0, n).invoke();
79 >        long sc = a[n - 1];
80 >        elapsed = elapsedTime(last);
81 >        System.out.printf("cum = %24d  time:  %7.3f\n", ss, elapsed);
82 >        if (sc != ss)
83 >            throw new Error("expected " + ss + " != " + sc);
84 >        if (index == 0) {
85 >            long cs = 0L;
86 >            for (int j = 0; j < n; ++j) {
87 >                if ((cs += j) != a[j])
88 >                    throw new Error("wrong element value");
89 >            }
90 >        }
91      }
92  
93 <    static long seqCumulate(long[] array, int lo, int hi, long base) {
94 <        long sum = base;
88 <        for (int i = lo; i < hi; ++i)
89 <            array[i] = sum += array[i];
90 <        return sum;
93 >    static double elapsedTime(long startTime) {
94 >        return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000);
95      }
96  
97 <    /**
98 <     * Adapted from Applyer demo in RecursiveAction docs
99 <     */
100 <    static final class Summer extends RecursiveAction {
101 <        final long[] array;
102 <        final int lo, hi;
99 <        long result;
100 <        Summer next; // keeps track of right-hand-side tasks
101 <        Summer(long[] array, int lo, int hi, Summer next) {
102 <            this.array = array; this.lo = lo; this.hi = hi;
103 <            this.next = next;
104 <        }
105 <
106 <        protected void compute() {
107 <            int l = lo;
108 <            int h = hi;
109 <            Summer right = null;
110 <            while (h - l > 1 && getSurplusQueuedTaskCount() <= 3) {
111 <                int mid = (l + h) >>> 1;
112 <                right = new Summer(array, mid, h, right);
113 <                right.fork();
114 <                h = mid;
115 <            }
116 <            long sum = seqSum(array, l, h);
117 <            while (right != null) {
118 <                if (right.tryUnfork()) // directly calculate if not stolen
119 <                    sum += seqSum(array, right.lo, right.hi);
120 <                else {
121 <                    right.join();
122 <                    sum += right.result;
123 <                }
124 <                right = right.next;
125 <            }
126 <            result = sum;
127 <        }
97 >    static long seqSum(LongByLongToLong fn, long basis,
98 >                       long[] a, int l, int h) {
99 >        long sum = basis;
100 >        for (int i = l; i < h; ++i)
101 >            sum = fn.apply(sum, a[i]);
102 >        return sum;
103      }
104  
105      /**
# Line 137 | Line 112 | public class FJSums {
112       * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html
113       *
114       * This version improves performance within FJ framework mainly by
115 <     * allowing second pass of ready left-hand sides to proceed even
116 <     * if some right-hand side first passes are still executing.  It
117 <     * also combines first and second pass for leftmost segment, and
118 <     * for cumulate (not precumulate) also skips first pass for
119 <     * rightmost segment (whose result is not needed for second pass).
115 >     * allowing the second pass of ready left-hand sides to proceed
116 >     * even if some right-hand side first passes are still executing.
117 >     * It also combines first and second pass for leftmost segment,
118 >     * and skips the first pass for rightmost segment (whose result is
119 >     * not needed for second pass).
120       *
121 <     * To manage this, it relies on "phase" phase/state control field
122 <     * maintaining bits CUMULATE, SUMMED, and FINISHED. CUMULATE is
121 >     * Managing this relies on ORing some bits in the pendingCount for
122 >     * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
123       * main phase bit. When false, segments compute only their sum.
124       * When true, they cumulate array elements. CUMULATE is set at
125       * root at beginning of second pass and then propagated down. But
126 <     * it may also be set earlier for subtrees with lo==0 (the
127 <     * left spine of tree). SUMMED is a one bit join count. For leafs,
128 <     * set when summed. For internal nodes, becomes true when one
129 <     * child is summed.  When second child finishes summing, it then
130 <     * moves up tree to trigger cumulate phase. FINISHED is also a one
131 <     * bit join count. For leafs, it is set when cumulated. For
132 <     * internal nodes, it becomes true when one child is cumulated.
133 <     * When second child finishes cumulating, it then moves up tree,
134 <     * executing complete() at the root.
126 >     * it may also be set earlier for subtrees with lo==0 (the left
127 >     * spine of tree). SUMMED is a one bit join count. For leafs, it
128 >     * is set when summed. For internal nodes, it becomes true when
129 >     * one child is summed.  When the second child finishes summing,
130 >     * it then moves up tree to trigger the cumulate phase. FINISHED
131 >     * is also a one bit join count. For leafs, it is set when
132 >     * cumulated. For internal nodes, it becomes true when one child
133 >     * is cumulated.  When the second child finishes cumulating, it
134 >     * then moves up tree, completing at the root.
135 >     *
136 >     * To better exploit locality and reduce overhead, the compute
137 >     * method loops starting with the current task, moving if possible
138 >     * to one of its subtasks rather than forking.
139       */
140 <    static final class Cumulater extends ForkJoinTask<Void> {
141 <        static final short CUMULATE = (short)1;
142 <        static final short SUMMED   = (short)2;
143 <        static final short FINISHED = (short)4;
140 >    static final class Cumulater extends CountedCompleter<Void> {
141 >        static final int CUMULATE = 1;
142 >        static final int SUMMED   = 2;
143 >        static final int FINISHED = 4;
144  
166        final Cumulater parent;
145          final long[] array;
146 +        final LongByLongToLong function;
147          Cumulater left, right;
148 <        final int lo;
149 <        final int hi;
171 <        volatile int phase;  // phase/state
172 <        long in, out; // initially zero
173 <
174 <        static final AtomicIntegerFieldUpdater<Cumulater> phaseUpdater =
175 <            AtomicIntegerFieldUpdater.newUpdater(Cumulater.class, "phase");
176 <
177 <        Cumulater(Cumulater parent, long[] array, int lo, int hi) {
178 <            this.parent = parent;
179 <            this.array = array;
180 <            this.lo = lo;
181 <            this.hi = hi;
182 <        }
183 <
184 <        public final Void getRawResult() { return null; }
185 <        protected final void setRawResult(Void mustBeNull) { }
148 >        final int lo, hi;
149 >        long in, out;
150  
151 <        /** Returns true if can CAS CUMULATE bit true */
152 <        final boolean transitionToCumulate() {
153 <            int c;
154 <            while (((c = phase) & CUMULATE) == 0)
155 <                if (phaseUpdater.compareAndSet(this, c, c | CUMULATE))
192 <                    return true;
193 <            return false;
151 >        Cumulater(Cumulater parent, LongByLongToLong function,
152 >                  long[] array, int lo, int hi) {
153 >            super(parent);
154 >            this.function = function; this.array = array;
155 >            this.lo = lo; this.hi = hi;
156          }
157  
158 <        public final boolean exec() {
159 <            if (hi - lo > THRESHOLD) {
160 <                if (left == null) { // first pass
161 <                    int mid = (lo + hi) >>> 1;
162 <                    left =  new Cumulater(this, array, lo, mid);
163 <                    right = new Cumulater(this, array, mid, hi);
164 <                }
165 <
166 <                boolean cumulate = (phase & CUMULATE) != 0;
167 <                if (cumulate) {
168 <                    long pin = in;
169 <                    left.in = pin;
170 <                    right.in = pin + left.out;
171 <                }
172 <
173 <                if (!cumulate || right.transitionToCumulate())
174 <                    right.fork();
175 <                if (!cumulate || left.transitionToCumulate())
176 <                    left.exec();
177 <            }
178 <            else {
179 <                int cb;
180 <                for (;;) { // Establish action: sum, cumulate, or both
181 <                    int b = phase;
182 <                    if ((b & FINISHED) != 0) // already done
183 <                        return false;
184 <                    if ((b & CUMULATE) != 0)
185 <                        cb = FINISHED;
186 <                    else if (lo == 0) // combine leftmost
187 <                        cb = (SUMMED|FINISHED);
188 <                    else
189 <                        cb = SUMMED;
190 <                    if (phaseUpdater.compareAndSet(this, b, b|cb))
191 <                        break;
158 >        public final void compute() {
159 >            final LongByLongToLong fn;
160 >            final long[] a;
161 >            if ((fn = this.function) == null || (a = this.array) == null)
162 >                throw new NullPointerException();    // hoist checks
163 >            int l, h;
164 >            Cumulater t = this;
165 >            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
166 >                if (h - l > THRESHOLD) {
167 >                    Cumulater lt = t.left, rt = t.right, f;
168 >                    if (lt == null) {                // first pass
169 >                        int mid = (l + h) >>> 1;
170 >                        f = rt = t.right = new Cumulater(t, fn, a, mid, h);
171 >                        t = lt = t.left  = new Cumulater(t, fn, a, l, mid);
172 >                    }
173 >                    else {                           // possibly refork
174 >                        long pin = t.in;
175 >                        lt.in = pin;
176 >                        f = t = null;
177 >                        if (rt != null) {
178 >                            rt.in = fn.apply(pin, lt.out);
179 >                            for (int c;;) {
180 >                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
181 >                                    break;
182 >                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
183 >                                    t = rt;
184 >                                    break;
185 >                                }
186 >                            }
187 >                        }
188 >                        for (int c;;) {
189 >                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
190 >                                break;
191 >                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
192 >                                if (t != null)
193 >                                    f = t;
194 >                                t = lt;
195 >                                break;
196 >                            }
197 >                        }
198 >                        if (t == null)
199 >                            break;
200 >                    }
201 >                    if (f != null)
202 >                        f.fork();
203                  }
204 +                else {
205 +                    int state; // Transition to sum, cumulate, or both
206 +                    for (int b;;) {
207 +                        if (((b = t.getPendingCount()) & FINISHED) != 0)
208 +                            break outer;                      // already done
209 +                        state = ((b & CUMULATE) != 0? FINISHED :
210 +                                 (l > 0) ? SUMMED : (SUMMED|FINISHED));
211 +                        if (t.compareAndSetPendingCount(b, b|state))
212 +                            break;
213 +                    }
214  
215 <                if (cb == SUMMED)
216 <                    out = seqSum(array, lo, hi);
217 <                else if (cb == FINISHED)
218 <                    seqCumulate(array, lo, hi, in);
236 <                else if (cb == (SUMMED|FINISHED))
237 <                    out = seqCumulate(array, lo, hi, 0L);
238 <
239 <                // propagate up
240 <                Cumulater ch = this;
241 <                Cumulater par = parent;
242 <                for (;;) {
243 <                    if (par == null) {
244 <                        if ((cb & FINISHED) != 0)
245 <                            ch.complete(null);
246 <                        break;
215 >                    long sum = t.in;
216 >                    if (state != SUMMED) {
217 >                        for (int i = l; i < h; ++i)           // cumulate
218 >                            a[i] = sum = fn.apply(sum, a[i]);
219                      }
220 <                    int pb = par.phase;
221 <                    if ((pb & cb & FINISHED) != 0) { // both finished
222 <                        ch = par;
251 <                        par = par.parent;
220 >                    else if (h < a.length) {                  // skip rightmost
221 >                        for (int i = l; i < h; ++i)           // sum only
222 >                            sum = fn.apply(sum, a[i]);
223                      }
224 <                    else if ((pb & cb & SUMMED) != 0) { // both summed
225 <                        par.out = par.left.out + par.right.out;
226 <                        int refork =
227 <                            ((pb & CUMULATE) == 0 &&
228 <                             par.lo == 0) ? CUMULATE : 0;
229 <                        int nextPhase = pb|cb|refork;
259 <                        if (pb == nextPhase ||
260 <                            phaseUpdater.compareAndSet(par, pb, nextPhase)) {
261 <                            if (refork != 0)
262 <                                par.fork();
263 <                            cb = SUMMED; // drop finished bit
264 <                            ch = par;
265 <                            par = par.parent;
224 >                    t.out = sum;
225 >                    for (Cumulater par;;) {                   // propagate
226 >                        if ((par = (Cumulater)t.getCompleter()) == null) {
227 >                            if ((state & FINISHED) != 0)      // enable join
228 >                                t.quietlyComplete();
229 >                            break outer;
230                          }
231 +                        int b = par.getPendingCount();
232 +                        if ((b & state & FINISHED) != 0)
233 +                            t = par;                          // both done
234 +                        else if ((b & state & SUMMED) != 0) { // both summed
235 +                            int nextState; Cumulater lt, rt;
236 +                            if ((lt = par.left) != null &&
237 +                                (rt = par.right) != null)
238 +                                par.out = fn.apply(lt.out, rt.out);
239 +                            int refork = (((b & CUMULATE) == 0 &&
240 +                                           par.lo == 0) ? CUMULATE : 0);
241 +                            if ((nextState = b|state|refork) == b ||
242 +                                par.compareAndSetPendingCount(b, nextState)) {
243 +                                state = SUMMED;               // drop finished
244 +                                t = par;
245 +                                if (refork != 0)
246 +                                    par.fork();
247 +                            }
248 +                        }
249 +                        else if (par.compareAndSetPendingCount(b, b|state))
250 +                            break outer;                      // sib not ready
251                      }
268                    else if (phaseUpdater.compareAndSet(par, pb, pb|cb))
269                        break;
252                  }
253              }
272            return false;
254          }
255 +    }
256  
257 +    // Uses CountedCompleter reduction via firstComplete/nextComplete
258 +    static final class Summer extends CountedCompleter<Void> {
259 +        final long[] array;
260 +        final LongByLongToLong function;
261 +        final int lo, hi;
262 +        final long basis;
263 +        long result;
264 +        Summer forks, next; // keeps track of right-hand-side tasks
265 +        Summer(Summer parent, LongByLongToLong function, long basis,
266 +               long[] array, int lo, int hi, Summer next) {
267 +            super(parent);
268 +            this.function = function; this.basis = basis;
269 +            this.array = array; this.lo = lo; this.hi = hi;
270 +            this.next = next;
271 +        }
272 +
273 +        public final void compute() {
274 +            final long id = basis;
275 +            final LongByLongToLong fn;
276 +            final long[] a;
277 +            if ((fn = this.function) == null || (a = this.array) == null)
278 +                throw new NullPointerException();
279 +            int l = lo,  h = hi;
280 +            while (h - l >= THRESHOLD) {
281 +                int mid = (l + h) >>> 1;
282 +                addToPendingCount(1);
283 +                (forks = new Summer(this, fn, id, a, mid, h, forks)).fork();
284 +                h = mid;
285 +            }
286 +            long sum = id;
287 +            if (l < h && l >= 0 && h <= a.length) {
288 +                for (int i = l; i < h; ++i)
289 +                    sum = fn.apply(sum, a[i]);
290 +            }
291 +            result = sum;
292 +            CountedCompleter<?> c;
293 +            for (c = firstComplete(); c != null; c = c.nextComplete()) {
294 +                Summer t = (Summer)c, s = t.forks;
295 +                while (s != null) {
296 +                    t.result = fn.apply(t.result, s.result);
297 +                    s = t.forks = s.next;
298 +                }
299 +            }
300 +        }
301      }
302  
303   }

Diff Legend

Removed lines
+ Added lines
< Changed lines
> Changed lines