--- jsr166/src/test/loops/FJSums.java 2012/10/21 06:14:12 1.8 +++ jsr166/src/test/loops/FJSums.java 2015/09/12 18:40:09 1.9 @@ -11,120 +11,95 @@ import java.util.concurrent.atomic.*; // parallel sums and cumulations public class FJSums { - static final long NPS = (1000L * 1000 * 1000); static int THRESHOLD; + static final int MIN_PARTITION = 64; + + interface LongByLongToLong { long apply(long a, long b); } + + static final class Add implements LongByLongToLong { + public long apply(long a, long b) { return a + b; } + } + + static final Add ADD = new Add(); public static void main(String[] args) throws Exception { - int procs = 0; int n = 1 << 25; int reps = 10; try { if (args.length > 0) - procs = Integer.parseInt(args[0]); + n = Integer.parseInt(args[0]); if (args.length > 1) - n = Integer.parseInt(args[1]); - if (args.length > 2) - reps = Integer.parseInt(args[2]); + reps = Integer.parseInt(args[1]); } catch (Exception e) { - System.out.println("Usage: java FJSums threads n reps"); + System.out.println("Usage: java FJSums n reps"); return; } - ForkJoinPool g = (procs == 0) ? new ForkJoinPool() : - new ForkJoinPool(procs); - System.out.println("Number of procs=" + g.getParallelism()); - // for now hardwire Cumulate threshold to 8 * #CPUs leaf tasks - THRESHOLD = 1 + ((n + 7) >>> 3) / g.getParallelism(); + int par = ForkJoinPool.getCommonPoolParallelism(); + System.out.println("Number of procs=" + par); + int p; + THRESHOLD = (p = n / (par << 3)) <= MIN_PARTITION ? 
MIN_PARTITION : p; long[] a = new long[n]; for (int i = 0; i < n; ++i) a[i] = i; long expected = ((long)n * (long)(n - 1)) / 2; - for (int i = 0; i < 2; ++i) { - System.out.print("Seq: "); - long last = System.nanoTime(); - long ss = seqSum(a, 0, n); - double elapsed = elapsedTime(last); - System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed); - if (ss != expected) - throw new Error("expected " + expected + " != " + ss); - } for (int i = 0; i < reps; ++i) { - System.out.print("Par: "); - long last = System.nanoTime(); - Summer s = new Summer(a, 0, a.length, null); - g.invoke(s); - long ss = s.result; - double elapsed = elapsedTime(last); - System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed); - if (i == 0 && ss != expected) - throw new Error("expected " + expected + " != " + ss); - System.out.print("Cum: "); - last = System.nanoTime(); - g.invoke(new Cumulater(null, a, 0, n)); - long sc = a[n - 1]; - elapsed = elapsedTime(last); - System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed); - if (sc != ss) - throw new Error("expected " + ss + " != " + sc); + seqTest(a, i, expected); + parTest(a, i, expected); } - System.out.println(g); - g.shutdown(); + System.out.println(ForkJoinPool.commonPool()); } - static double elapsedTime(long startTime) { - return (double)(System.nanoTime() - startTime) / NPS; + static void seqTest(long[] a, int index, long expected) { + System.out.print("Seq "); + long last = System.nanoTime(); + int n = a.length; + long ss = seqSum(ADD, 0L, a, 0, n); + double elapsed = elapsedTime(last); + System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed); + if (index == 0 && ss != expected) + throw new Error("expected " + expected + " != " + ss); } - static long seqSum(long[] array, int l, int h) { - long sum = 0; - for (int i = l; i < h; ++i) - sum += array[i]; - return sum; + static void parTest(long[] a, int index, long expected) { + System.out.print("Par "); + long last = System.nanoTime(); + int n = a.length; + Summer s = new 
Summer(null, ADD, 0L, a, 0, n, null); + s.invoke(); + long ss = s.result; + double elapsed = elapsedTime(last); + System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed); + if (index == 0 && ss != expected) + throw new Error("expected " + expected + " != " + ss); + System.out.print("Par "); + last = System.nanoTime(); + new Cumulater(null, ADD, a, 0, n).invoke(); + long sc = a[n - 1]; + elapsed = elapsedTime(last); + System.out.printf("cum = %24d time: %7.3f\n", ss, elapsed); + if (sc != ss) + throw new Error("expected " + ss + " != " + sc); + if (index == 0) { + long cs = 0L; + for (int j = 0; j < n; ++j) { + if ((cs += j) != a[j]) + throw new Error("wrong element value"); + } + } } - static long seqCumulate(long[] array, int lo, int hi, long base) { - long sum = base; - for (int i = lo; i < hi; ++i) - array[i] = sum += array[i]; - return sum; + static double elapsedTime(long startTime) { + return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000); } - /** - * Adapted from Applyer demo in RecursiveAction docs - */ - static final class Summer extends RecursiveAction { - final long[] array; - final int lo, hi; - long result; - Summer next; // keeps track of right-hand-side tasks - Summer(long[] array, int lo, int hi, Summer next) { - this.array = array; this.lo = lo; this.hi = hi; - this.next = next; - } - - protected void compute() { - int l = lo; - int h = hi; - Summer right = null; - while (h - l > 1 && getSurplusQueuedTaskCount() <= 3) { - int mid = (l + h) >>> 1; - right = new Summer(array, mid, h, right); - right.fork(); - h = mid; - } - long sum = seqSum(array, l, h); - while (right != null) { - if (right.tryUnfork()) // directly calculate if not stolen - sum += seqSum(array, right.lo, right.hi); - else { - right.join(); - sum += right.result; - } - right = right.next; - } - result = sum; - } + static long seqSum(LongByLongToLong fn, long basis, + long[] a, int l, int h) { + long sum = basis; + for (int i = l; i < h; ++i) + sum = fn.apply(sum, 
a[i]); + return sum; } /** @@ -137,141 +112,192 @@ public class FJSums { * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html * * This version improves performance within FJ framework mainly by - * allowing second pass of ready left-hand sides to proceed even - * if some right-hand side first passes are still executing. It - * also combines first and second pass for leftmost segment, and - * for cumulate (not precumulate) also skips first pass for - * rightmost segment (whose result is not needed for second pass). + * allowing the second pass of ready left-hand sides to proceed + * even if some right-hand side first passes are still executing. + * It also combines first and second pass for leftmost segment, + * and skips the first pass for rightmost segment (whose result is + * not needed for second pass). * - * To manage this, it relies on "phase" phase/state control field - * maintaining bits CUMULATE, SUMMED, and FINISHED. CUMULATE is + * Managing this relies on ORing some bits in the pendingCount for + * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the * main phase bit. When false, segments compute only their sum. * When true, they cumulate array elements. CUMULATE is set at * root at beginning of second pass and then propagated down. But - * it may also be set earlier for subtrees with lo==0 (the - * left spine of tree). SUMMED is a one bit join count. For leafs, - * set when summed. For internal nodes, becomes true when one - * child is summed. When second child finishes summing, it then - * moves up tree to trigger cumulate phase. FINISHED is also a one - * bit join count. For leafs, it is set when cumulated. For - * internal nodes, it becomes true when one child is cumulated. - * When second child finishes cumulating, it then moves up tree, - * executing complete() at the root. + * it may also be set earlier for subtrees with lo==0 (the left + * spine of tree). SUMMED is a one bit join count. For leafs, it + * is set when summed. 
For internal nodes, it becomes true when + * one child is summed. When the second child finishes summing, + * it then moves up tree to trigger the cumulate phase. FINISHED + * is also a one bit join count. For leafs, it is set when + * cumulated. For internal nodes, it becomes true when one child + * is cumulated. When the second child finishes cumulating, it + * then moves up tree, completing at the root. + * + * To better exploit locality and reduce overhead, the compute + * method loops starting with the current task, moving if possible + * to one of its subtasks rather than forking. */ - static final class Cumulater extends ForkJoinTask { - static final short CUMULATE = (short)1; - static final short SUMMED = (short)2; - static final short FINISHED = (short)4; + static final class Cumulater extends CountedCompleter { + static final int CUMULATE = 1; + static final int SUMMED = 2; + static final int FINISHED = 4; - final Cumulater parent; final long[] array; + final LongByLongToLong function; Cumulater left, right; - final int lo; - final int hi; - volatile int phase; // phase/state - long in, out; // initially zero - - static final AtomicIntegerFieldUpdater phaseUpdater = - AtomicIntegerFieldUpdater.newUpdater(Cumulater.class, "phase"); - - Cumulater(Cumulater parent, long[] array, int lo, int hi) { - this.parent = parent; - this.array = array; - this.lo = lo; - this.hi = hi; - } - - public final Void getRawResult() { return null; } - protected final void setRawResult(Void mustBeNull) { } + final int lo, hi; + long in, out; - /** Returns true if can CAS CUMULATE bit true */ - final boolean transitionToCumulate() { - int c; - while (((c = phase) & CUMULATE) == 0) - if (phaseUpdater.compareAndSet(this, c, c | CUMULATE)) - return true; - return false; + Cumulater(Cumulater parent, LongByLongToLong function, + long[] array, int lo, int hi) { + super(parent); + this.function = function; this.array = array; + this.lo = lo; this.hi = hi; } - public final boolean 
exec() { - if (hi - lo > THRESHOLD) { - if (left == null) { // first pass - int mid = (lo + hi) >>> 1; - left = new Cumulater(this, array, lo, mid); - right = new Cumulater(this, array, mid, hi); - } - - boolean cumulate = (phase & CUMULATE) != 0; - if (cumulate) { - long pin = in; - left.in = pin; - right.in = pin + left.out; - } - - if (!cumulate || right.transitionToCumulate()) - right.fork(); - if (!cumulate || left.transitionToCumulate()) - left.exec(); - } - else { - int cb; - for (;;) { // Establish action: sum, cumulate, or both - int b = phase; - if ((b & FINISHED) != 0) // already done - return false; - if ((b & CUMULATE) != 0) - cb = FINISHED; - else if (lo == 0) // combine leftmost - cb = (SUMMED|FINISHED); - else - cb = SUMMED; - if (phaseUpdater.compareAndSet(this, b, b|cb)) - break; + public final void compute() { + final LongByLongToLong fn; + final long[] a; + if ((fn = this.function) == null || (a = this.array) == null) + throw new NullPointerException(); // hoist checks + int l, h; + Cumulater t = this; + outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) { + if (h - l > THRESHOLD) { + Cumulater lt = t.left, rt = t.right, f; + if (lt == null) { // first pass + int mid = (l + h) >>> 1; + f = rt = t.right = new Cumulater(t, fn, a, mid, h); + t = lt = t.left = new Cumulater(t, fn, a, l, mid); + } + else { // possibly refork + long pin = t.in; + lt.in = pin; + f = t = null; + if (rt != null) { + rt.in = fn.apply(pin, lt.out); + for (int c;;) { + if (((c = rt.getPendingCount()) & CUMULATE) != 0) + break; + if (rt.compareAndSetPendingCount(c, c|CUMULATE)){ + t = rt; + break; + } + } + } + for (int c;;) { + if (((c = lt.getPendingCount()) & CUMULATE) != 0) + break; + if (lt.compareAndSetPendingCount(c, c|CUMULATE)) { + if (t != null) + f = t; + t = lt; + break; + } + } + if (t == null) + break; + } + if (f != null) + f.fork(); } + else { + int state; // Transition to sum, cumulate, or both + for (int b;;) { + if (((b = t.getPendingCount()) & 
FINISHED) != 0) + break outer; // already done + state = ((b & CUMULATE) != 0? FINISHED : + (l > 0) ? SUMMED : (SUMMED|FINISHED)); + if (t.compareAndSetPendingCount(b, b|state)) + break; + } - if (cb == SUMMED) - out = seqSum(array, lo, hi); - else if (cb == FINISHED) - seqCumulate(array, lo, hi, in); - else if (cb == (SUMMED|FINISHED)) - out = seqCumulate(array, lo, hi, 0L); - - // propagate up - Cumulater ch = this; - Cumulater par = parent; - for (;;) { - if (par == null) { - if ((cb & FINISHED) != 0) - ch.complete(null); - break; + long sum = t.in; + if (state != SUMMED) { + for (int i = l; i < h; ++i) // cumulate + a[i] = sum = fn.apply(sum, a[i]); } - int pb = par.phase; - if ((pb & cb & FINISHED) != 0) { // both finished - ch = par; - par = par.parent; + else if (h < a.length) { // skip rightmost + for (int i = l; i < h; ++i) // sum only + sum = fn.apply(sum, a[i]); } - else if ((pb & cb & SUMMED) != 0) { // both summed - par.out = par.left.out + par.right.out; - int refork = - ((pb & CUMULATE) == 0 && - par.lo == 0) ? CUMULATE : 0; - int nextPhase = pb|cb|refork; - if (pb == nextPhase || - phaseUpdater.compareAndSet(par, pb, nextPhase)) { - if (refork != 0) - par.fork(); - cb = SUMMED; // drop finished bit - ch = par; - par = par.parent; + t.out = sum; + for (Cumulater par;;) { // propagate + if ((par = (Cumulater)t.getCompleter()) == null) { + if ((state & FINISHED) != 0) // enable join + t.quietlyComplete(); + break outer; } + int b = par.getPendingCount(); + if ((b & state & FINISHED) != 0) + t = par; // both done + else if ((b & state & SUMMED) != 0) { // both summed + int nextState; Cumulater lt, rt; + if ((lt = par.left) != null && + (rt = par.right) != null) + par.out = fn.apply(lt.out, rt.out); + int refork = (((b & CUMULATE) == 0 && + par.lo == 0) ? 
CUMULATE : 0); + if ((nextState = b|state|refork) == b || + par.compareAndSetPendingCount(b, nextState)) { + state = SUMMED; // drop finished + t = par; + if (refork != 0) + par.fork(); + } + } + else if (par.compareAndSetPendingCount(b, b|state)) + break outer; // sib not ready } - else if (phaseUpdater.compareAndSet(par, pb, pb|cb)) - break; } } - return false; } + } + // Uses CC reduction via firstComplete/nextComplete + static final class Summer extends CountedCompleter { + final long[] array; + final LongByLongToLong function; + final int lo, hi; + final long basis; + long result; + Summer forks, next; // keeps track of right-hand-side tasks + Summer(Summer parent, LongByLongToLong function, long basis, + long[] array, int lo, int hi, Summer next) { + super(parent); + this.function = function; this.basis = basis; + this.array = array; this.lo = lo; this.hi = hi; + this.next = next; + } + + public final void compute() { + final long id = basis; + final LongByLongToLong fn; + final long[] a; + if ((fn = this.function) == null || (a = this.array) == null) + throw new NullPointerException(); + int l = lo, h = hi; + while (h - l >= THRESHOLD) { + int mid = (l + h) >>> 1; + addToPendingCount(1); + (forks = new Summer(this, fn, id, a, mid, h, forks)).fork(); + h = mid; + } + long sum = id; + if (l < h && l >= 0 && h <= a.length) { + for (int i = l; i < h; ++i) + sum = fn.apply(sum, a[i]); + } + result = sum; + CountedCompleter c; + for (c = firstComplete(); c != null; c = c.nextComplete()) { + Summer t = (Summer)c, s = t.forks; + while (s != null) { + t.result = fn.apply(t.result, s.result); + s = t.forks = s.next; + } + } + } } }