ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/main/java/util/ArrayPrefixHelpers.java
Revision: 1.8
Committed: Sun Sep 20 17:29:14 2015 UTC (8 years, 7 months ago) by jsr166
Branch: MAIN
Changes since 1.7: +4 -8 lines
Log Message:
rollback previous change; CumulateTask classes must be package private

File Contents

# User Rev Content
1 dl 1.1 /*
2     * Written by Doug Lea with assistance from members of JCP JSR-166
3     * Expert Group and released to the public domain, as explained at
4     * http://creativecommons.org/publicdomain/zero/1.0/
5     */
6    
7     package java.util;
8 jsr166 1.4
9     import java.util.concurrent.CountedCompleter;
10 dl 1.1 import java.util.function.BinaryOperator;
11 jsr166 1.4 import java.util.function.DoubleBinaryOperator;
12 dl 1.1 import java.util.function.IntBinaryOperator;
13     import java.util.function.LongBinaryOperator;
14    
15     /**
16     * ForkJoin tasks to perform Arrays.parallelPrefix operations.
17     *
18     * @author Doug Lea
19     * @since 1.8
20     */
21     class ArrayPrefixHelpers {
22 jsr166 1.3 private ArrayPrefixHelpers() {} // non-instantiable: this class only holds shared constants and the nested CumulateTask classes
23 dl 1.1
24     /*
25     * Parallel prefix (aka cumulate, scan) task classes
26     * are based loosely on Guy Blelloch's original
27     * algorithm (http://www.cs.cmu.edu/~scandal/alg/scan.html):
28     * Keep dividing by two to threshold segment size, and then:
29     * Pass 1: Create tree of partial sums for each segment
30     * Pass 2: For each segment, cumulate with offset of left sibling
31     *
32     * This version improves performance within FJ framework mainly by
33     * allowing the second pass of ready left-hand sides to proceed
34     * even if some right-hand side first passes are still executing.
35     * It also combines first and second pass for leftmost segment,
36     * and skips the first pass for rightmost segment (whose result is
37     * not needed for second pass). It similarly manages to avoid
38     * requiring that users supply an identity basis for accumulations
39     * by tracking those segments/subtasks for which the first
40     * existing element is used as base.
41     *
42     * Managing this relies on ORing some bits in the pendingCount for
43     * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
44     * main phase bit. When false, segments compute only their sum.
45     * When true, they cumulate array elements. CUMULATE is set at
46     * root at beginning of second pass and then propagated down. But
47     * it may also be set earlier for subtrees with lo==0 (the left
48     * spine of tree). SUMMED is a one bit join count. For leaves, it
49     * is set when summed. For internal nodes, it becomes true when
50     * one child is summed. When the second child finishes summing,
51     * it then moves up the tree to trigger the cumulate phase. FINISHED
52     * is also a one bit join count. For leaves, it is set when
53     * cumulated. For internal nodes, it becomes true when one child
54     * is cumulated. When the second child finishes cumulating, it
55     * then moves up the tree, completing at the root.
56     *
57     * To better exploit locality and reduce overhead, the compute
58     * method loops starting with the current task, moving if possible
59     * to one of its subtasks rather than forking.
60     *
61     * As usual for this sort of utility, there are 4 versions that
62     * are simple copy/paste/adapt variants of each other. (The
63     * double and int versions differ from the long version solely by
64     * replacing "long" (with case-matching).)
65     */
66    
67     // Phase/state bits ORed into each task's pendingCount; see the algorithm overview above
68     static final int CUMULATE = 1;   // main phase bit: set when the task cumulates array elements rather than only summing
69     static final int SUMMED = 2;     // one-bit join count for the sum phase (leaf: summed; internal node: one child summed)
70     static final int FINISHED = 4;   // one-bit join count for the cumulate phase (leaf: cumulated; internal node: one child cumulated)
71    
72     /** The smallest subtask array partition size to use as threshold */
73     static final int MIN_PARTITION = 16;  // NOTE(review): not referenced in this file; presumably used by callers choosing threshold — confirm
74    
/**
 * Parallel prefix task for object arrays. It cumulates array[origin, fence)
 * in place as an inclusive scan: each element becomes function folded left
 * to right over itself and every preceding element of the range, and
 * array[origin] itself is left unchanged.
 *
 * Each task owns the segment [lo, hi). Segments longer than threshold are
 * split at the midpoint into left and right subtasks, and this task is their
 * completer. Phase progress is tracked by ORing the CUMULATE, SUMMED and
 * FINISHED bits into the pending count, as described in the algorithm
 * overview above.
 *
 * Segment sums are combined in tree order, so the result equals a
 * sequential left fold only when function is associative.
 */
75 jsr166 1.8 static final class CumulateTask<T> extends CountedCompleter<Void> {
76 dl 1.1 final T[] array;                   // cumulated in place
77     final BinaryOperator<T> function;     // combining operator
78     CumulateTask<T> left, right;          // subtasks; null until this task is split
// in:  combined value of array[origin, lo), set by the parent; unused when lo == origin
// out: value reported to the parent; during the sum pass, the combined value of [lo, hi)
//      (the parent ignores out for a segment that ends at fence)
79     T in, out;
// [lo, hi): this segment; [origin, fence): the whole range; threshold: largest unsplit segment
80     final int lo, hi, origin, fence, threshold;
81    
/**
 * Creates a task for segment [lo, hi) of the range [origin, fence).
 * parent becomes this task's completer; a null parent marks the root.
 */
82     CumulateTask(CumulateTask<T> parent, BinaryOperator<T> function,
83     T[] array, int origin, int fence, int threshold,
84     int lo, int hi) {
85     super(parent);
86     this.function = function; this.array = array;
87     this.origin = origin; this.fence = fence;
88     this.threshold = threshold;
89     this.lo = lo; this.hi = hi;
90     }
91    
/**
 * Runs the task state machine. It works on a current task t and moves
 * directly to a child or an ancestor instead of forking whenever it can.
 *
 * @throws NullPointerException if function or array is null
 */
92     public final void compute() {
93     final BinaryOperator<T> fn;
94     final T[] a;
95     if ((fn = this.function) == null || (a = this.array) == null)
96     throw new NullPointerException(); // hoist checks: fn and a are non-null from here on
97     int th = threshold, org = origin, fnc = fence, l, h;
98     CumulateTask<T> t = this;
// Keep going while t's bounds lie inside the array; a failed check exits quietly.
// NOTE(review): presumably there so the JIT can elide per-access bounds checks — confirm
99     outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
100     if (h - l > th) {                    // segment too large: internal node
101     CumulateTask<T> lt = t.left, rt = t.right, f;   // f: task to fork, if any
102     if (lt == null) { // first pass: split at mid, fork the right half, continue with the left
103     int mid = (l + h) >>> 1;
104     f = rt = t.right =
105     new CumulateTask<T>(t, fn, a, org, fnc, th, mid, h);
106     t = lt = t.left =
107     new CumulateTask<T>(t, fn, a, org, fnc, th, l, mid);
108     }
109     else { // possibly refork: pass prefixes down, then start the children's cumulate phase
110     T pin = t.in;
111     lt.in = pin;                      // left child shares this segment's prefix
112     f = t = null;
113     if (rt != null) {
114     T lout = lt.out;
115     rt.in = (l == org ? lout :        // right prefix = this prefix combined with the left sum
116     fn.apply(pin, lout));
117     for (int c;;) {                   // set CUMULATE on rt; take rt only if this call set it
118     if (((c = rt.getPendingCount()) & CUMULATE) != 0)
119     break;
120     if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
121     t = rt;
122     break;
123     }
124     }
125     }
126     for (int c;;) {                   // same for lt; if both were set here, fork rt and continue with lt
127     if (((c = lt.getPendingCount()) & CUMULATE) != 0)
128     break;
129     if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
130     if (t != null)
131     f = t;
132     t = lt;
133     break;
134     }
135     }
136     if (t == null)                    // neither child newly switched to CUMULATE:
137     break;                            // leave the outer loop
138     }
139     if (f != null)
140     f.fork();
141     }
142     else {                               // leaf segment
143     int state; // Transition to sum, cumulate, or both: FINISHED if CUMULATE is set, SUMMED|FINISHED for the leftmost leaf, else SUMMED
144     for (int b;;) {
145     if (((b = t.getPendingCount()) & FINISHED) != 0)
146     break outer; // already done
147     state = ((b & CUMULATE) != 0 ? FINISHED :
148     (l > org) ? SUMMED : (SUMMED|FINISHED));
149     if (t.compareAndSetPendingCount(b, b|state))
150     break;
151     }
152    
153     T sum;
154     if (state != SUMMED) {               // cumulate pass (combined with summing for the leftmost leaf)
155     int first;
156     if (l == org) { // leftmost; no in: start from array[origin], which stays unchanged
157     sum = a[org];
158     first = org + 1;
159     }
160     else {
161     sum = t.in;
162     first = l;
163     }
164     for (int i = first; i < h; ++i) // cumulate in place
165     a[i] = sum = fn.apply(sum, a[i]);
166     }
167     else if (h < fnc) { // skip rightmost: only segments not ending at fence are summed
168     sum = a[l];
169     for (int i = l + 1; i < h; ++i) // sum only; the array is not modified
170     sum = fn.apply(sum, a[i]);
171     }
172     else                                 // rightmost segment in the sum pass: the parent ignores its out
173     sum = t.in;
174     t.out = sum;
175     for (CumulateTask<T> par;;) { // propagate completion bits toward the root
176 jsr166 1.2 @SuppressWarnings("unchecked") CumulateTask<T> partmp   // completer is the CumulateTask<T> passed as parent
177     = (CumulateTask<T>)t.getCompleter();
178     if ((par = partmp) == null) {
179 dl 1.1 if ((state & FINISHED) != 0) // enable join: the root is complete once its cumulate phase finished
180     t.quietlyComplete();
181     break outer;
182     }
183     int b = par.getPendingCount();
184     if ((b & state & FINISHED) != 0)     // sibling has already finished as well
185     t = par; // both done: continue with the parent
186     else if ((b & state & SUMMED) != 0) { // both summed: compute the parent's out
187     int nextState; CumulateTask<T> lt, rt;
188     if ((lt = par.left) != null &&
189     (rt = par.right) != null) {
190     T lout = lt.out;
191     par.out = (rt.hi == fnc ? lout :     // right sum is ignored when rt ends at fence
192     fn.apply(lout, rt.out));
193     }
194     int refork = (((b & CUMULATE) == 0 &&  // a left-spine parent (lo == origin) may start cumulating now
195     par.lo == org) ? CUMULATE : 0);
196     if ((nextState = b|state|refork) == b ||      // skip the CAS when no bit would change;
197     par.compareAndSetPendingCount(b, nextState)) { // on CAS failure, retry with a fresh count
198     state = SUMMED; // drop finished: continue upward with only the sum signal
199     t = par;
200     if (refork != 0)
201     par.fork();
202     }
203     }
204     else if (par.compareAndSetPendingCount(b, b|state))  // record our bit for the sibling; retry on CAS failure
205     break outer; // sib not ready
206     }
207     }
208     }
209     }
210     private static final long serialVersionUID = 5293554502939613543L;
211     }
212    
/**
 * Parallel prefix task for long arrays. It cumulates array[origin, fence)
 * in place as an inclusive scan: each element becomes function folded left
 * to right over itself and every preceding element of the range, and
 * array[origin] itself is left unchanged. The algorithm is the same as
 * CumulateTask, specialized to long.
 *
 * Each task owns the segment [lo, hi). Segments longer than threshold are
 * split at the midpoint into left and right subtasks, and this task is their
 * completer. Phase progress is tracked by ORing the CUMULATE, SUMMED and
 * FINISHED bits into the pending count.
 *
 * Segment sums are combined in tree order, so the result equals a
 * sequential left fold only when function is associative.
 */
213 jsr166 1.8 static final class LongCumulateTask extends CountedCompleter<Void> {
214 dl 1.1 final long[] array;                 // cumulated in place
215     final LongBinaryOperator function;    // combining operator
216     LongCumulateTask left, right;         // subtasks; null until this task is split
// in:  combined value of array[origin, lo), set by the parent; unused when lo == origin
// out: value reported to the parent; during the sum pass, the combined value of [lo, hi)
//      (the parent ignores out for a segment that ends at fence)
217     long in, out;
// [lo, hi): this segment; [origin, fence): the whole range; threshold: largest unsplit segment
218     final int lo, hi, origin, fence, threshold;
219    
/**
 * Creates a task for segment [lo, hi) of the range [origin, fence).
 * parent becomes this task's completer; a null parent marks the root.
 */
220     LongCumulateTask(LongCumulateTask parent, LongBinaryOperator function,
221     long[] array, int origin, int fence, int threshold,
222     int lo, int hi) {
223     super(parent);
224     this.function = function; this.array = array;
225     this.origin = origin; this.fence = fence;
226     this.threshold = threshold;
227     this.lo = lo; this.hi = hi;
228     }
229    
/**
 * Runs the task state machine. It works on a current task t and moves
 * directly to a child or an ancestor instead of forking whenever it can.
 *
 * @throws NullPointerException if function or array is null
 */
230     public final void compute() {
231     final LongBinaryOperator fn;
232     final long[] a;
233     if ((fn = this.function) == null || (a = this.array) == null)
234     throw new NullPointerException(); // hoist checks: fn and a are non-null from here on
235     int th = threshold, org = origin, fnc = fence, l, h;
236     LongCumulateTask t = this;
// Keep going while t's bounds lie inside the array; a failed check exits quietly.
// NOTE(review): presumably there so the JIT can elide per-access bounds checks — confirm
237     outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
238     if (h - l > th) {                    // segment too large: internal node
239     LongCumulateTask lt = t.left, rt = t.right, f;   // f: task to fork, if any
240     if (lt == null) { // first pass: split at mid, fork the right half, continue with the left
241     int mid = (l + h) >>> 1;
242     f = rt = t.right =
243     new LongCumulateTask(t, fn, a, org, fnc, th, mid, h);
244     t = lt = t.left =
245     new LongCumulateTask(t, fn, a, org, fnc, th, l, mid);
246     }
247     else { // possibly refork: pass prefixes down, then start the children's cumulate phase
248     long pin = t.in;
249     lt.in = pin;                      // left child shares this segment's prefix
250     f = t = null;
251     if (rt != null) {
252     long lout = lt.out;
253     rt.in = (l == org ? lout :        // right prefix = this prefix combined with the left sum
254     fn.applyAsLong(pin, lout));
255     for (int c;;) {                   // set CUMULATE on rt; take rt only if this call set it
256     if (((c = rt.getPendingCount()) & CUMULATE) != 0)
257     break;
258     if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
259     t = rt;
260     break;
261     }
262     }
263     }
264     for (int c;;) {                   // same for lt; if both were set here, fork rt and continue with lt
265     if (((c = lt.getPendingCount()) & CUMULATE) != 0)
266     break;
267     if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
268     if (t != null)
269     f = t;
270     t = lt;
271     break;
272     }
273     }
274     if (t == null)                    // neither child newly switched to CUMULATE:
275     break;                            // leave the outer loop
276     }
277     if (f != null)
278     f.fork();
279     }
280     else {                               // leaf segment
281     int state; // Transition to sum, cumulate, or both: FINISHED if CUMULATE is set, SUMMED|FINISHED for the leftmost leaf, else SUMMED
282     for (int b;;) {
283     if (((b = t.getPendingCount()) & FINISHED) != 0)
284     break outer; // already done
285     state = ((b & CUMULATE) != 0 ? FINISHED :
286     (l > org) ? SUMMED : (SUMMED|FINISHED));
287     if (t.compareAndSetPendingCount(b, b|state))
288     break;
289     }
290    
291     long sum;
292     if (state != SUMMED) {               // cumulate pass (combined with summing for the leftmost leaf)
293     int first;
294     if (l == org) { // leftmost; no in: start from array[origin], which stays unchanged
295     sum = a[org];
296     first = org + 1;
297     }
298     else {
299     sum = t.in;
300     first = l;
301     }
302     for (int i = first; i < h; ++i) // cumulate in place
303     a[i] = sum = fn.applyAsLong(sum, a[i]);
304     }
305     else if (h < fnc) { // skip rightmost: only segments not ending at fence are summed
306     sum = a[l];
307     for (int i = l + 1; i < h; ++i) // sum only; the array is not modified
308     sum = fn.applyAsLong(sum, a[i]);
309     }
310     else                                 // rightmost segment in the sum pass: the parent ignores its out
311     sum = t.in;
312     t.out = sum;
313     for (LongCumulateTask par;;) { // propagate completion bits toward the root
314     if ((par = (LongCumulateTask)t.getCompleter()) == null) {   // completer is the LongCumulateTask passed as parent
315     if ((state & FINISHED) != 0) // enable join: the root is complete once its cumulate phase finished
316     t.quietlyComplete();
317     break outer;
318     }
319     int b = par.getPendingCount();
320     if ((b & state & FINISHED) != 0)     // sibling has already finished as well
321     t = par; // both done: continue with the parent
322     else if ((b & state & SUMMED) != 0) { // both summed: compute the parent's out
323     int nextState; LongCumulateTask lt, rt;
324     if ((lt = par.left) != null &&
325     (rt = par.right) != null) {
326     long lout = lt.out;
327     par.out = (rt.hi == fnc ? lout :     // right sum is ignored when rt ends at fence
328     fn.applyAsLong(lout, rt.out));
329     }
330     int refork = (((b & CUMULATE) == 0 &&  // a left-spine parent (lo == origin) may start cumulating now
331     par.lo == org) ? CUMULATE : 0);
332     if ((nextState = b|state|refork) == b ||      // skip the CAS when no bit would change;
333     par.compareAndSetPendingCount(b, nextState)) { // on CAS failure, retry with a fresh count
334     state = SUMMED; // drop finished: continue upward with only the sum signal
335     t = par;
336     if (refork != 0)
337     par.fork();
338     }
339     }
340     else if (par.compareAndSetPendingCount(b, b|state))  // record our bit for the sibling; retry on CAS failure
341     break outer; // sib not ready
342     }
343     }
344     }
345     }
346     private static final long serialVersionUID = -5074099945909284273L;
347     }
348    
/**
 * Parallel prefix task for double arrays. It cumulates array[origin, fence)
 * in place as an inclusive scan: each element becomes function folded left
 * to right over itself and every preceding element of the range, and
 * array[origin] itself is left unchanged. The algorithm is the same as
 * CumulateTask, specialized to double.
 *
 * Each task owns the segment [lo, hi). Segments longer than threshold are
 * split at the midpoint into left and right subtasks, and this task is their
 * completer. Phase progress is tracked by ORing the CUMULATE, SUMMED and
 * FINISHED bits into the pending count.
 *
 * Segment sums are combined in tree order, so the result equals a
 * sequential left fold only when function is associative. Floating-point
 * addition, for example, is not exactly associative, so results may differ
 * slightly from a sequential scan.
 */
349 jsr166 1.8 static final class DoubleCumulateTask extends CountedCompleter<Void> {
350 dl 1.1 final double[] array;                 // cumulated in place
351     final DoubleBinaryOperator function;    // combining operator
352     DoubleCumulateTask left, right;         // subtasks; null until this task is split
// in:  combined value of array[origin, lo), set by the parent; unused when lo == origin
// out: value reported to the parent; during the sum pass, the combined value of [lo, hi)
//      (the parent ignores out for a segment that ends at fence)
353     double in, out;
// [lo, hi): this segment; [origin, fence): the whole range; threshold: largest unsplit segment
354     final int lo, hi, origin, fence, threshold;
355    
/**
 * Creates a task for segment [lo, hi) of the range [origin, fence).
 * parent becomes this task's completer; a null parent marks the root.
 */
356     DoubleCumulateTask(DoubleCumulateTask parent, DoubleBinaryOperator function,
357 jsr166 1.5 double[] array, int origin, int fence, int threshold,
358     int lo, int hi) {
359 dl 1.1 super(parent);
360     this.function = function; this.array = array;
361     this.origin = origin; this.fence = fence;
362     this.threshold = threshold;
363     this.lo = lo; this.hi = hi;
364     }
365    
/**
 * Runs the task state machine. It works on a current task t and moves
 * directly to a child or an ancestor instead of forking whenever it can.
 *
 * @throws NullPointerException if function or array is null
 */
366     public final void compute() {
367     final DoubleBinaryOperator fn;
368     final double[] a;
369     if ((fn = this.function) == null || (a = this.array) == null)
370     throw new NullPointerException(); // hoist checks: fn and a are non-null from here on
371     int th = threshold, org = origin, fnc = fence, l, h;
372     DoubleCumulateTask t = this;
// Keep going while t's bounds lie inside the array; a failed check exits quietly.
// NOTE(review): presumably there so the JIT can elide per-access bounds checks — confirm
373     outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
374     if (h - l > th) {                    // segment too large: internal node
375     DoubleCumulateTask lt = t.left, rt = t.right, f;   // f: task to fork, if any
376     if (lt == null) { // first pass: split at mid, fork the right half, continue with the left
377     int mid = (l + h) >>> 1;
378     f = rt = t.right =
379     new DoubleCumulateTask(t, fn, a, org, fnc, th, mid, h);
380     t = lt = t.left =
381     new DoubleCumulateTask(t, fn, a, org, fnc, th, l, mid);
382     }
383     else { // possibly refork: pass prefixes down, then start the children's cumulate phase
384     double pin = t.in;
385     lt.in = pin;                      // left child shares this segment's prefix
386     f = t = null;
387     if (rt != null) {
388     double lout = lt.out;
389     rt.in = (l == org ? lout :        // right prefix = this prefix combined with the left sum
390     fn.applyAsDouble(pin, lout));
391     for (int c;;) {                   // set CUMULATE on rt; take rt only if this call set it
392     if (((c = rt.getPendingCount()) & CUMULATE) != 0)
393     break;
394     if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
395     t = rt;
396     break;
397     }
398     }
399     }
400     for (int c;;) {                   // same for lt; if both were set here, fork rt and continue with lt
401     if (((c = lt.getPendingCount()) & CUMULATE) != 0)
402     break;
403     if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
404     if (t != null)
405     f = t;
406     t = lt;
407     break;
408     }
409     }
410     if (t == null)                    // neither child newly switched to CUMULATE:
411     break;                            // leave the outer loop
412     }
413     if (f != null)
414     f.fork();
415     }
416     else {                               // leaf segment
417     int state; // Transition to sum, cumulate, or both: FINISHED if CUMULATE is set, SUMMED|FINISHED for the leftmost leaf, else SUMMED
418     for (int b;;) {
419     if (((b = t.getPendingCount()) & FINISHED) != 0)
420     break outer; // already done
421     state = ((b & CUMULATE) != 0 ? FINISHED :
422     (l > org) ? SUMMED : (SUMMED|FINISHED));
423     if (t.compareAndSetPendingCount(b, b|state))
424     break;
425     }
426    
427     double sum;
428     if (state != SUMMED) {               // cumulate pass (combined with summing for the leftmost leaf)
429     int first;
430     if (l == org) { // leftmost; no in: start from array[origin], which stays unchanged
431     sum = a[org];
432     first = org + 1;
433     }
434     else {
435     sum = t.in;
436     first = l;
437     }
438     for (int i = first; i < h; ++i) // cumulate in place
439     a[i] = sum = fn.applyAsDouble(sum, a[i]);
440     }
441     else if (h < fnc) { // skip rightmost: only segments not ending at fence are summed
442     sum = a[l];
443     for (int i = l + 1; i < h; ++i) // sum only; the array is not modified
444     sum = fn.applyAsDouble(sum, a[i]);
445     }
446     else                                 // rightmost segment in the sum pass: the parent ignores its out
447     sum = t.in;
448     t.out = sum;
449     for (DoubleCumulateTask par;;) { // propagate completion bits toward the root
450     if ((par = (DoubleCumulateTask)t.getCompleter()) == null) {   // completer is the DoubleCumulateTask passed as parent
451     if ((state & FINISHED) != 0) // enable join: the root is complete once its cumulate phase finished
452     t.quietlyComplete();
453     break outer;
454     }
455     int b = par.getPendingCount();
456     if ((b & state & FINISHED) != 0)     // sibling has already finished as well
457     t = par; // both done: continue with the parent
458     else if ((b & state & SUMMED) != 0) { // both summed: compute the parent's out
459     int nextState; DoubleCumulateTask lt, rt;
460     if ((lt = par.left) != null &&
461     (rt = par.right) != null) {
462     double lout = lt.out;
463     par.out = (rt.hi == fnc ? lout :     // right sum is ignored when rt ends at fence
464     fn.applyAsDouble(lout, rt.out));
465     }
466     int refork = (((b & CUMULATE) == 0 &&  // a left-spine parent (lo == origin) may start cumulating now
467     par.lo == org) ? CUMULATE : 0);
468     if ((nextState = b|state|refork) == b ||      // skip the CAS when no bit would change;
469     par.compareAndSetPendingCount(b, nextState)) { // on CAS failure, retry with a fresh count
470     state = SUMMED; // drop finished: continue upward with only the sum signal
471     t = par;
472     if (refork != 0)
473     par.fork();
474     }
475     }
476     else if (par.compareAndSetPendingCount(b, b|state))  // record our bit for the sibling; retry on CAS failure
477     break outer; // sib not ready
478     }
479     }
480     }
481     }
482     private static final long serialVersionUID = -586947823794232033L;
483     }
484    
/**
 * Parallel prefix task for int arrays. It cumulates array[origin, fence)
 * in place as an inclusive scan: each element becomes function folded left
 * to right over itself and every preceding element of the range, and
 * array[origin] itself is left unchanged. The algorithm is the same as
 * CumulateTask, specialized to int.
 *
 * Each task owns the segment [lo, hi). Segments longer than threshold are
 * split at the midpoint into left and right subtasks, and this task is their
 * completer. Phase progress is tracked by ORing the CUMULATE, SUMMED and
 * FINISHED bits into the pending count.
 *
 * Segment sums are combined in tree order, so the result equals a
 * sequential left fold only when function is associative.
 */
485 jsr166 1.8 static final class IntCumulateTask extends CountedCompleter<Void> {
486 dl 1.1 final int[] array;                  // cumulated in place
487     final IntBinaryOperator function;     // combining operator
488     IntCumulateTask left, right;          // subtasks; null until this task is split
// in:  combined value of array[origin, lo), set by the parent; unused when lo == origin
// out: value reported to the parent; during the sum pass, the combined value of [lo, hi)
//      (the parent ignores out for a segment that ends at fence)
489     int in, out;
// [lo, hi): this segment; [origin, fence): the whole range; threshold: largest unsplit segment
490     final int lo, hi, origin, fence, threshold;
491    
/**
 * Creates a task for segment [lo, hi) of the range [origin, fence).
 * parent becomes this task's completer; a null parent marks the root.
 */
492     IntCumulateTask(IntCumulateTask parent, IntBinaryOperator function,
493 jsr166 1.5 int[] array, int origin, int fence, int threshold,
494     int lo, int hi) {
495 dl 1.1 super(parent);
496     this.function = function; this.array = array;
497     this.origin = origin; this.fence = fence;
498     this.threshold = threshold;
499     this.lo = lo; this.hi = hi;
500     }
501    
/**
 * Runs the task state machine. It works on a current task t and moves
 * directly to a child or an ancestor instead of forking whenever it can.
 *
 * @throws NullPointerException if function or array is null
 */
502     public final void compute() {
503     final IntBinaryOperator fn;
504     final int[] a;
505     if ((fn = this.function) == null || (a = this.array) == null)
506     throw new NullPointerException(); // hoist checks: fn and a are non-null from here on
507     int th = threshold, org = origin, fnc = fence, l, h;
508     IntCumulateTask t = this;
// Keep going while t's bounds lie inside the array; a failed check exits quietly.
// NOTE(review): presumably there so the JIT can elide per-access bounds checks — confirm
509     outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
510     if (h - l > th) {                    // segment too large: internal node
511     IntCumulateTask lt = t.left, rt = t.right, f;   // f: task to fork, if any
512     if (lt == null) { // first pass: split at mid, fork the right half, continue with the left
513     int mid = (l + h) >>> 1;
514     f = rt = t.right =
515     new IntCumulateTask(t, fn, a, org, fnc, th, mid, h);
516     t = lt = t.left =
517     new IntCumulateTask(t, fn, a, org, fnc, th, l, mid);
518     }
519     else { // possibly refork: pass prefixes down, then start the children's cumulate phase
520     int pin = t.in;
521     lt.in = pin;                      // left child shares this segment's prefix
522     f = t = null;
523     if (rt != null) {
524     int lout = lt.out;
525     rt.in = (l == org ? lout :        // right prefix = this prefix combined with the left sum
526     fn.applyAsInt(pin, lout));
527     for (int c;;) {                   // set CUMULATE on rt; take rt only if this call set it
528     if (((c = rt.getPendingCount()) & CUMULATE) != 0)
529     break;
530     if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
531     t = rt;
532     break;
533     }
534     }
535     }
536     for (int c;;) {                   // same for lt; if both were set here, fork rt and continue with lt
537     if (((c = lt.getPendingCount()) & CUMULATE) != 0)
538     break;
539     if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
540     if (t != null)
541     f = t;
542     t = lt;
543     break;
544     }
545     }
546     if (t == null)                    // neither child newly switched to CUMULATE:
547     break;                            // leave the outer loop
548     }
549     if (f != null)
550     f.fork();
551     }
552     else {                               // leaf segment
553     int state; // Transition to sum, cumulate, or both: FINISHED if CUMULATE is set, SUMMED|FINISHED for the leftmost leaf, else SUMMED
554     for (int b;;) {
555     if (((b = t.getPendingCount()) & FINISHED) != 0)
556     break outer; // already done
557     state = ((b & CUMULATE) != 0 ? FINISHED :
558     (l > org) ? SUMMED : (SUMMED|FINISHED));
559     if (t.compareAndSetPendingCount(b, b|state))
560     break;
561     }
562    
563     int sum;
564     if (state != SUMMED) {               // cumulate pass (combined with summing for the leftmost leaf)
565     int first;
566     if (l == org) { // leftmost; no in: start from array[origin], which stays unchanged
567     sum = a[org];
568     first = org + 1;
569     }
570     else {
571     sum = t.in;
572     first = l;
573     }
574     for (int i = first; i < h; ++i) // cumulate in place
575     a[i] = sum = fn.applyAsInt(sum, a[i]);
576     }
577     else if (h < fnc) { // skip rightmost: only segments not ending at fence are summed
578     sum = a[l];
579     for (int i = l + 1; i < h; ++i) // sum only; the array is not modified
580     sum = fn.applyAsInt(sum, a[i]);
581     }
582     else                                 // rightmost segment in the sum pass: the parent ignores its out
583     sum = t.in;
584     t.out = sum;
585     for (IntCumulateTask par;;) { // propagate completion bits toward the root
586     if ((par = (IntCumulateTask)t.getCompleter()) == null) {   // completer is the IntCumulateTask passed as parent
587     if ((state & FINISHED) != 0) // enable join: the root is complete once its cumulate phase finished
588     t.quietlyComplete();
589     break outer;
590     }
591     int b = par.getPendingCount();
592     if ((b & state & FINISHED) != 0)     // sibling has already finished as well
593     t = par; // both done: continue with the parent
594     else if ((b & state & SUMMED) != 0) { // both summed: compute the parent's out
595     int nextState; IntCumulateTask lt, rt;
596     if ((lt = par.left) != null &&
597     (rt = par.right) != null) {
598     int lout = lt.out;
599     par.out = (rt.hi == fnc ? lout :     // right sum is ignored when rt ends at fence
600     fn.applyAsInt(lout, rt.out));
601     }
602     int refork = (((b & CUMULATE) == 0 &&  // a left-spine parent (lo == origin) may start cumulating now
603     par.lo == org) ? CUMULATE : 0);
604     if ((nextState = b|state|refork) == b ||      // skip the CAS when no bit would change;
605     par.compareAndSetPendingCount(b, nextState)) { // on CAS failure, retry with a fresh count
606     state = SUMMED; // drop finished: continue upward with only the sum signal
607     t = par;
608     if (refork != 0)
609     par.fork();
610     }
611     }
612     else if (par.compareAndSetPendingCount(b, b|state))  // record our bit for the sibling; retry on CAS failure
613     break outer; // sib not ready
614     }
615     }
616     }
617     }
618     private static final long serialVersionUID = 3731755594596840961L;
619     }
620    
621     }