/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;

/**
 * Parallel sums and cumulations.
 *
 * <p>Times a sequential sum, a parallel sum ({@link Summer}) and a
 * parallel in-place cumulation ({@link Cumulater}) over a {@code long[]}
 * and cross-checks the results against each other.
 */
public class FJSums {
    /** Smallest value {@link #THRESHOLD} is ever set to by {@link #main}. */
    static final int MIN_PARTITION = 64;

    /**
     * Segment size at which tasks stop splitting. Defaults to
     * {@link #MIN_PARTITION} so that the tasks can be used before
     * {@link #main} has run; a value of zero would make the splitting
     * loop in {@link Summer#compute} spin forever.
     */
    static int THRESHOLD = MIN_PARTITION;

    /** A binary function over primitive longs. */
    @FunctionalInterface
    interface LongByLongToLong { long apply(long a, long b); }

    /** Addition as a {@link LongByLongToLong}. */
    static final class Add implements LongByLongToLong {
        @Override
        public long apply(long a, long b) { return a + b; }
    }

    /** Shared adder instance used by all tests. */
    static final Add ADD = new Add();

    /**
     * Runs the benchmark.
     *
     * @param args optional {@code n} (array length, default {@code 1 << 25})
     *        and {@code reps} (number of repetitions, default 10)
     * @throws Exception declared for compatibility with existing callers
     */
    public static void main(String[] args) throws Exception {
        int n = 1 << 25;
        int reps = 10;
        try {
            if (args.length > 0)
                n = Integer.parseInt(args[0]);
            if (args.length > 1)
                reps = Integer.parseInt(args[1]);
        }
        catch (NumberFormatException e) {
            System.out.println("Usage: java FJSums n reps");
            return;
        }
        if (n < 0) {
            // A negative length cannot be allocated; report it as a usage
            // error rather than failing with NegativeArraySizeException.
            System.out.println("Usage: java FJSums n reps");
            return;
        }
        int par = ForkJoinPool.getCommonPoolParallelism();
        System.out.println("Number of procs=" + par);
        // n / (par * 8) elements per leaf, but never below MIN_PARTITION.
        THRESHOLD = Math.max(n / (par << 3), MIN_PARTITION);

        long[] a = new long[n];
        for (int i = 0; i < n; ++i)
            a[i] = i;
        // Sum of 0 .. n-1.
        long expected = ((long)n * (long)(n - 1)) / 2;
        for (int i = 0; i < reps; ++i) {
            seqTest(a, i, expected);
            parTest(a, i, expected);
        }
        System.out.println(ForkJoinPool.commonPool());
    }

    /**
     * Times and prints a sequential sum of {@code a}.
     *
     * @param a the array to sum
     * @param index repetition number; the result is checked only when 0,
     *        because later repetitions see the array already cumulated
     *        by {@link #parTest}
     * @param expected the sum expected on repetition 0
     * @throws Error if the checked sum differs from {@code expected}
     */
    static void seqTest(long[] a, int index, long expected) {
        System.out.print("Seq ");
        long last = System.nanoTime();
        int n = a.length;
        long ss = seqSum(ADD, 0L, a, 0, n);
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new Error("expected " + expected + " != " + ss);
    }

    /**
     * Times and prints a parallel sum of {@code a}, then cumulates
     * {@code a} in place in parallel and checks that the last cumulated
     * element equals the sum.
     *
     * @param a the array to sum and then overwrite with its running totals
     * @param index repetition number; the sum and the individual cumulated
     *        elements are checked only when 0
     * @param expected the sum expected on repetition 0
     * @throws Error if any check fails
     */
    static void parTest(long[] a, int index, long expected) {
        System.out.print("Par ");
        long last = System.nanoTime();
        int n = a.length;
        Summer s = new Summer(null, ADD, 0L, a, 0, n, null);
        s.invoke();
        long ss = s.result;
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new Error("expected " + expected + " != " + ss);
        System.out.print("Par ");
        last = System.nanoTime();
        new Cumulater(null, ADD, a, 0, n).invoke();
        // An empty array has no last element; its running total is 0.
        long sc = (n > 0) ? a[n - 1] : 0L;
        elapsed = elapsedTime(last);
        // Report the cumulated total itself, not the earlier sum.
        System.out.printf("cum = %24d  time:  %7.3f\n", sc, elapsed);
        if (sc != ss)
            throw new Error("expected " + ss + " != " + sc);
        if (index == 0) {
            // Original contents were a[j] == j, so a[j] must now be 0+1+..+j.
            long cs = 0L;
            for (int j = 0; j < n; ++j) {
                if ((cs += j) != a[j])
                    throw new Error("wrong element value");
            }
        }
    }

    /**
     * Returns the time elapsed since {@code startTime}, in seconds.
     *
     * @param startTime an earlier {@link System#nanoTime()} reading
     */
    static double elapsedTime(long startTime) {
        return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000);
    }

    /**
     * Sequentially folds {@code a[l..h)} with {@code fn}, starting from
     * {@code basis}.
     *
     * @param fn the combining function
     * @param basis the initial accumulator value
     * @param a the source array
     * @param l first index, inclusive
     * @param h last index, exclusive
     * @return the folded value, or {@code basis} if the range is empty
     */
    static long seqSum(LongByLongToLong fn, long basis,
                       long[] a, int l, int h) {
        long sum = basis;
        for (int i = l; i < h; ++i)
            sum = fn.apply(sum, a[i]);
        return sum;
    }

    /**
     * Cumulative scan, adapted from ParallelArray code
     *
     * A basic version of scan is straightforward.
     *  Keep dividing by two to threshold segment size, and then:
     *   Pass 1: Create tree of partial sums for each segment
     *   Pass 2: For each segment, cumulate with offset of left sibling
     * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html
     *
     * This version improves performance within FJ framework mainly by
     * allowing the second pass of ready left-hand sides to proceed
     * even if some right-hand side first passes are still executing.
     * It also combines first and second pass for leftmost segment,
     * and skips the first pass for rightmost segment (whose result is
     * not needed for second pass).
     *
     * Managing this relies on ORing some bits in the pendingCount for
     * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the
     * main phase bit. When false, segments compute only their sum.
     * When true, they cumulate array elements. CUMULATE is set at
     * root at beginning of second pass and then propagated down. But
     * it may also be set earlier for subtrees with lo==0 (the left
     * spine of tree). SUMMED is a one bit join count. For leafs, it
     * is set when summed. For internal nodes, it becomes true when
     * one child is summed.  When the second child finishes summing,
     * we then move up the tree to trigger the cumulate phase. FINISHED
     * is also a one bit join count. For leafs, it is set when
     * cumulated. For internal nodes, it becomes true when one child
     * is cumulated.  When the second child finishes cumulating, it
     * then moves up tree, completing at the root.
     *
     * To better exploit locality and reduce overhead, the compute
     * method loops starting with the current task, moving if possible
     * to one of its subtasks rather than forking.
     */
    static final class Cumulater extends CountedCompleter<Void> {
        static final int CUMULATE = 1;
        static final int SUMMED   = 2;
        static final int FINISHED = 4;

        final long[] array;
        final LongByLongToLong function;
        Cumulater left, right;
        final int lo, hi;
        // in: total of everything left of this segment;
        // out: total of this segment (written once it has been summed).
        long in, out;

        /**
         * @param parent the completer to notify, or null for the root
         * @param function the combining function
         * @param array the array to cumulate in place
         * @param lo first index of this segment, inclusive
         * @param hi last index of this segment, exclusive
         */
        Cumulater(Cumulater parent, LongByLongToLong function,
                  long[] array, int lo, int hi) {
            super(parent);
            this.function = function; this.array = array;
            this.lo = lo; this.hi = hi;
        }

        @Override
        public final void compute() {
            final LongByLongToLong fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();    // hoist checks
            int l, h;
            Cumulater t = this;
            outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) {
                if (h - l > THRESHOLD) {
                    Cumulater lt = t.left, rt = t.right, f;
                    if (lt == null) {                // first pass
                        int mid = (l + h) >>> 1;
                        f = rt = t.right =
                            new Cumulater(t, fn, a, mid, h);
                        t = lt = t.left  =
                            new Cumulater(t, fn, a, l, mid);
                    }
                    else {                           // possibly refork
                        long pin = t.in;
                        lt.in = pin;
                        f = t = null;
                        if (rt != null) {
                            rt.in = fn.apply(pin, lt.out);
                            for (int c;;) {
                                if (((c = rt.getPendingCount()) & CUMULATE) != 0)
                                    break;
                                if (rt.compareAndSetPendingCount(c, c|CUMULATE)){
                                    t = rt;
                                    break;
                                }
                            }
                        }
                        for (int c;;) {
                            if (((c = lt.getPendingCount()) & CUMULATE) != 0)
                                break;
                            if (lt.compareAndSetPendingCount(c, c|CUMULATE)) {
                                if (t != null)
                                    f = t;
                                t = lt;
                                break;
                            }
                        }
                        if (t == null)
                            break;
                    }
                    if (f != null)
                        f.fork();
                }
                else {
                    int state; // Transition to sum, cumulate, or both
                    for (int b;;) {
                        if (((b = t.getPendingCount()) & FINISHED) != 0)
                            break outer;                      // already done
                        state = (((b & CUMULATE) != 0) ? FINISHED :
                                 (l > 0) ? SUMMED : (SUMMED|FINISHED));
                        if (t.compareAndSetPendingCount(b, b|state))
                            break;
                    }

                    long sum = t.in;
                    if (state != SUMMED) {
                        for (int i = l; i < h; ++i)           // cumulate
                            a[i] = sum = fn.apply(sum, a[i]);
                    }
                    else if (h < a.length) {                  // skip rightmost
                        for (int i = l; i < h; ++i)           // sum only
                            sum = fn.apply(sum, a[i]);
                    }
                    t.out = sum;
                    for (Cumulater par;;) {                   // propagate
                        if ((par = (Cumulater)t.getCompleter()) == null) {
                            if ((state & FINISHED) != 0)      // enable join
                                t.quietlyComplete();
                            break outer;
                        }
                        int b = par.getPendingCount();
                        if ((b & state & FINISHED) != 0)
                            t = par;                          // both done
                        else if ((b & state & SUMMED) != 0) { // both summed
                            int nextState; Cumulater lt, rt;
                            if ((lt = par.left) != null &&
                                (rt = par.right) != null)
                                par.out = fn.apply(lt.out, rt.out);
                            int refork = (((b & CUMULATE) == 0 &&
                                           par.lo == 0) ? CUMULATE : 0);
                            if ((nextState = b|state|refork) == b ||
                                par.compareAndSetPendingCount(b, nextState)) {
                                state = SUMMED;               // drop finished
                                t = par;
                                if (refork != 0)
                                    par.fork();
                            }
                        }
                        else if (par.compareAndSetPendingCount(b, b|state))
                            break outer;                      // sib not ready
                    }
                }
            }
        }
    }

    /**
     * Parallel reduction of an array segment.
     * Uses CC reduction via firstComplete/nextComplete.
     */
    static final class Summer extends CountedCompleter<Void> {
        final long[] array;
        final LongByLongToLong function;
        final int lo, hi;
        final long basis;
        long result;
        Summer forks, next; // keeps track of right-hand-side tasks

        /**
         * @param parent the completer to notify, or null for the root
         * @param function the combining function
         * @param basis the initial accumulator value for each leaf
         * @param array the source array
         * @param lo first index of this segment, inclusive
         * @param hi last index of this segment, exclusive
         * @param next the previously forked sibling, or null
         */
        Summer(Summer parent,
               LongByLongToLong function,
               long basis,
               long[] array,
               int lo,
               int hi,
               Summer next) {
            super(parent);
            this.function = function;
            this.basis = basis;
            this.array = array;
            this.lo = lo;
            this.hi = hi;
            this.next = next;
        }

        @Override
        public final void compute() {
            final long id = basis;
            final LongByLongToLong fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();
            int l = lo, h = hi;
            // Fork off right halves, chaining them through forks/next,
            // and keep the left half for this task.
            while (h - l >= THRESHOLD) {
                int mid = (l + h) >>> 1;
                addToPendingCount(1);
                (forks = new Summer(this, fn, id, a, mid, h, forks)).fork();
                h = mid;
            }
            long sum = id;
            if (l < h && l >= 0 && h <= a.length) {
                for (int i = l; i < h; ++i)
                    sum = fn.apply(sum, a[i]);
            }
            result = sum;
            // Walk up through every completer that this completion
            // finishes, folding each one's forked results into its own.
            CountedCompleter<?> c;
            for (c = firstComplete(); c != null; c = c.nextComplete()) {
                Summer t = (Summer)c, s = t.forks;
                while (s != null) {
                    t.result = fn.apply(t.result, s.result);
                    s = t.forks = s.next;
                }
            }
        }
    }
}