/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

import java.util.concurrent.*;
import java.util.*;

/**
 * Benchmark that times {@link java.util.Arrays#sort(long[])} against
 * {@link java.util.Arrays#parallelSort(long[])} on arrays of random longs.
 * Arrays are filled and verified in parallel by the {@link RandomPacker} and
 * {@link OrderChecker} tasks, whose leaf size is governed by {@link #THRESHOLD}.
 */
class ScalarArraysSort {
    /** Nanoseconds per second, used to convert {@code System.nanoTime} deltas. */
    static final long NPS = (1000L * 1000 * 1000);

    /**
     * Leaf size for {@link OrderChecker} ({@link RandomPacker} uses twice this).
     * Recomputed by {@link #main} from the array length and common-pool
     * parallelism. The default matches the floor used by {@code main}, so the
     * tasks keep a sane granularity even when {@code main} has not run.
     */
    static int THRESHOLD = 1 << 13;

    // NOTE(review): not referenced anywhere in this class; kept in case other
    // code in the package uses it.
    static long[] numbers;

    /**
     * Runs the benchmark. It performs one sequential sort, then {@code reps}
     * parallel sorts, then {@code sreps} sequential sorts. The common pool is
     * printed after each phase.
     *
     * @param args optional {@code n} (array length, default 2^26) followed by
     *             optional {@code reps} (parallel-sort repetitions, default 30)
     * @throws Exception propagated from the timed runs
     */
    public static void main(String[] args) throws Exception {
        int n = 1 << 26;
        int reps = 30;
        int sreps = 2; // sequential repetitions after the parallel phase
        try {
            // Arguments are zero-based: args[0] is n, args[1] is reps.
            if (args.length > 0) {
                n = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                reps = Integer.parseInt(args[1]);
            }
        } catch (NumberFormatException e) {
            usage();
            return;
        }
        if (n < 0 || reps < 0) {
            // A negative length would otherwise fail later in new long[n].
            usage();
            return;
        }

        // Leaf size for packing/checking: ceil(n/8) split across the pool's
        // workers, but never below 2^13 elements. The unsigned shift keeps
        // n + 7 correct even if it wraps past Integer.MAX_VALUE.
        int thr = ((n + 7) >>> 3) / ForkJoinPool.getCommonPoolParallelism();
        THRESHOLD = Math.max(thr, 1 << 13);

        long[] a = new long[n];
        ForkJoinPool pool = ForkJoinPool.commonPool();
        seqTest(a, n, 1);
        System.out.println(pool);
        parTest(a, n, reps);
        System.out.println(pool);
        seqTest(a, n, sreps);
        System.out.println(pool);
    }

    /** Prints the command-line usage line. */
    private static void usage() {
        System.out.println("Usage: java ScalarArraysSort n reps");
    }

    /**
     * Times {@code reps} rounds of {@link java.util.Arrays#sort(long[])} on
     * {@code a}. Before each round, {@code a[0, n)} is refilled with random
     * values; after it, {@code a[0, n)} is verified to be sorted.
     */
    static void seqTest(long[] a, int n, int reps) {
        timeSorts("Arrays.sort  ", a, n, reps, java.util.Arrays::sort);
    }

    /**
     * Times {@code reps} rounds of {@link java.util.Arrays#parallelSort(long[])}
     * on {@code a}. Before each round, {@code a[0, n)} is refilled with random
     * values; after it, {@code a[0, n)} is verified to be sorted.
     */
    static void parTest(long[] a, int n, int reps) throws Exception {
        timeSorts("Parallel sort", a, n, reps, java.util.Arrays::parallelSort);
    }

    /**
     * Shared timing loop for {@link #seqTest} and {@link #parTest}. Each round
     * refills, times one sort, prints the timing, then verifies order.
     *
     * @param label  text printed before " time:"; padded so columns line up
     * @param a      array to sort (in place)
     * @param n      number of leading elements to fill and check
     * @param reps   number of rounds
     * @param sorter sort under test, applied to the whole of {@code a}
     */
    private static void timeSorts(String label, long[] a, int n, int reps,
                                  java.util.function.Consumer<long[]> sorter) {
        System.out.printf("Sorting %d longs, %d replications\n", n, reps);
        long start = System.nanoTime();
        for (int i = 0; i < reps; ++i) {
            new RandomPacker(null, new SplittableRandom(), a, 0, n, n).invoke();
            long last = System.nanoTime();
            sorter.accept(a);
            long now = System.nanoTime();
            // "total" includes packing and checking time since start; "elapsed"
            // covers only this round's sort.
            double total = (double) (now - start) / NPS;
            double elapsed = (double) (now - last) / NPS;
            System.out.printf("%s time:  %7.3f total %9.3f\n", label, elapsed, total);
            new OrderChecker(null, a, 0, n, n).invoke();
        }
    }

    /**
     * Fills {@code dst[lo, hi)} with random longs. The range is split in halves
     * until each piece is at most {@code 2 * THRESHOLD} long. Every forked half
     * receives its own generator from {@link SplittableRandom#split()}, so no
     * generator instance is used by more than one task.
     */
    static final class RandomPacker extends CountedCompleter<Void> {
        final SplittableRandom rng;
        final long[] dst;
        final int lo, hi, size;

        RandomPacker(CountedCompleter<?> par,
                     SplittableRandom rng,
                     long[] dst,
                     int lo, int hi, int size) {
            super(par);
            this.rng = rng;
            this.dst = dst;
            this.lo = lo;
            this.hi = hi;
            this.size = size;
        }

        @Override
        public final void compute() {
            SplittableRandom r = rng;
            long[] d = dst;
            int l = lo, h = hi, n = size;
            // Fork off the upper half until the remaining piece is small enough.
            while (h - l > THRESHOLD << 1) {
                int m = (l + h) >>> 1;
                addToPendingCount(1);
                new RandomPacker(this, r.split(), d, m, h, n).fork();
                h = m;
            }
            for (int i = l; i < h; ++i) {
                d[i] = r.nextLong();
            }
            tryComplete();
        }
    }

    /**
     * Verifies that {@code array[lo, hi)} is in non-decreasing order. The range
     * is split in halves until each piece is at most {@code THRESHOLD} long.
     * Each leaf also compares its last element with {@code array[h]} when
     * {@code h < size}, so order across piece boundaries is checked too.
     * Empty ranges (for example a zero-length array) pass trivially.
     *
     * @throws Error if an adjacent pair is out of order
     */
    static final class OrderChecker extends CountedCompleter<Void> {
        final long[] array;
        final int lo, hi, size;

        OrderChecker(CountedCompleter<?> par, long[] a, int lo, int hi, int size) {
            super(par);
            this.array = a;
            this.lo = lo;
            this.hi = hi;
            this.size = size;
        }

        @Override
        public final void compute() {
            long[] a = this.array;
            int l = lo, h = hi, n = size;
            while (h - l > THRESHOLD) {
                int m = (l + h) >>> 1;
                addToPendingCount(1);
                new OrderChecker(this, a, m, h, n).fork();
                h = m;
            }
            // Without this guard an empty piece would read a[l] out of bounds.
            if (l < h) {
                // Scan through a[h] to cover the boundary with the next piece,
                // except for the final piece, which stops at a[n - 1].
                int bound = h < n ? h : n - 1;
                long x = a[l];
                for (int i = l + 1; i <= bound; ++i) {
                    long y = a[i];
                    if (x > y) {
                        throw new Error("Unsorted " + x + " / " + y);
                    }
                    x = y;
                }
            }
            tryComplete();
        }
    }

}
