--- jsr166/src/test/loops/ScalarLongSort.java 2010/09/01 07:47:27 1.3 +++ jsr166/src/test/loops/ScalarLongSort.java 2010/09/19 12:55:37 1.4 @@ -4,22 +4,37 @@ * http://creativecommons.org/licenses/publicdomain */ -import java.util.*; import java.util.concurrent.*; - -// Based very loosely on cilksort +import java.util.*; class ScalarLongSort { - static final long NPS = (1000L * 1000 * 1000); // time conversion + static final long NPS = (1000L * 1000 * 1000); static int THRESHOLD; static final boolean warmup = true; - public static void main(String[] args) throws Exception { + public static void main (String[] args) throws Exception { + int procs = 0; int n = 1 << 22; - int sreps = 2; int reps = 20; + int sreps = 2; + try { + if (args.length > 0) + procs = Integer.parseInt(args[0]); + if (args.length > 1) + n = Integer.parseInt(args[1]); + if (args.length > 2) + reps = Integer.parseInt(args[1]); + } + catch (Exception e) { + System.out.println("Usage: java ScalarLongSort threads n reps"); + return; + } + ForkJoinPool pool = procs == 0? new ForkJoinPool() : + new ForkJoinPool(procs); + long[] a = new long[n]; + seqRandomFill(a, 0, n); if (warmup) { System.out.printf("Sorting %d longs, %d replications\n", n, 1); @@ -31,7 +46,6 @@ class ScalarLongSort { checkSorted(a); } - ForkJoinPool pool = new ForkJoinPool(); // for now hardwire 8 * #CPUs leaf tasks THRESHOLD = 1 + ((n + 7) >>> 3) / pool.getParallelism(); // THRESHOLD = 1 + ((n + 15) >>> 4) / pool.getParallelism(); @@ -66,9 +80,9 @@ class ScalarLongSort { } static final class Sorter extends RecursiveAction { - final long[] a; + final long[] a; final long[] w; - final int origin; + final int origin; final int n; Sorter(long[] a, long[] w, int origin, int n) { this.a = a; this.w = w; this.origin = origin; this.n = n; @@ -79,24 +93,25 @@ class ScalarLongSort { if (n <= THRESHOLD) Arrays.sort(a, l, l+n); else { // divide in quarters to ensure sorted array in a not w - SubSorter rs; int h = n >>> 1; int q = n >>> 2; int u = h + q; - (rs = new SubSorter - (new Sorter(a, w, l+h, q), - new Sorter(a, w, l+u, n-u), - new Merger(a, w, l+h, q, l+u, n-u, l+h, null))).fork(); - (new SubSorter - (new Sorter(a, w, l, q), - new Sorter(a, w, l+q, h-q), - new Merger(a, w, l, q, l+q, h-q, l, null))).compute(); + SubSorter rs = new SubSorter + (new Sorter(a, w, l+h, q), + new Sorter(a, w, l+u, n-u), + new Merger(a, w, l+h, q, l+u, n-u, l+h, null)); + rs.fork(); + Sorter rl = new Sorter(a, w, l+q, h-q); + rl.fork(); + (new Sorter(a, w, l, q)).compute(); + rl.join(); + (new Merger(a, w, l, q, l+q, h-q, l, null)).compute(); rs.join(); new Merger(w, a, l, h, l+h, n-h, l, null).compute(); } } } - + static final class SubSorter extends RecursiveAction { final Sorter left; final Sorter right; @@ -115,7 +130,7 @@ class ScalarLongSort { static final class Merger extends RecursiveAction { final long[] a; final long[] w; final int lo; final int ln; final int ro; final int rn; final int wo; - final Merger next; + Merger next; Merger(long[] a, long[] w, int lo, int ln, int ro, int rn, int wo, Merger next) { this.a = a; this.w = w; @@ -130,9 +145,9 @@ class ScalarLongSort { * and finding index of right closest to split point. * Uses left-spine decomposition to generate all * merge tasks before bottomming out at base case. - * + * */ - public void compute() { + public final void compute() { Merger rights = null; int nleft = ln; int nright = rn; @@ -154,8 +169,14 @@ class ScalarLongSort { nleft = lh; nright = rh; } + + merge(nleft, nright); + if (rights != null) + collectRights(rights); + + } - // Base case -- merge left and right + final void merge(int nleft, int nright) { int l = lo; int lFence = lo + nleft; int r = ro; @@ -172,20 +193,24 @@ class ScalarLongSort { w[k++] = a[l++]; while (r < rFence) w[k++] = a[r++]; + } - while (rights != null) { - rights.join(); - rights = rights.next; + static void collectRights(Merger rt) { + while (rt != null) { + Merger next = rt.next; + rt.next = null; + if (rt.tryUnfork()) rt.compute(); else rt.join(); + rt = next; } } } - static void checkSorted(long[] a) { + static void checkSorted (long[] a) { int n = a.length; for (int i = 0; i < n - 1; i++) { if (a[i] > a[i+1]) { - throw new Error("Unsorted at " + i + ": " + + throw new Error("Unsorted at " + i + ": " + a[i] + " / " + a[i+1]); } } @@ -204,8 +229,12 @@ class ScalarLongSort { this.array = array; this.lo = lo; this.hi = hi; } public void compute() { - if (hi - lo <= THRESHOLD) - seqRandomFill(array, lo, hi); + if (hi - lo <= THRESHOLD) { + long[] a = array; + ThreadLocalRandom rng = ThreadLocalRandom.current(); + for (int i = lo; i < hi; ++i) + a[i] = rng.nextLong(); + } else { int mid = (lo + hi) >>> 1; RandomFiller r = new RandomFiller(array, mid, hi);