/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

import java.util.*;
import java.util.function.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;

// parallel sums and cumulations

/**
 * Micro-benchmark comparing a sequential sum of a {@code long[]} against a
 * parallel sum built on {@link CountedCompleter}, and timing
 * {@link Arrays#parallelPrefix} on primitive and boxed arrays.
 *
 * <p>Usage: {@code java FJSums2 [n [reps]]} where {@code n} is the number of
 * {@code long} elements (default {@code 1 << 25}) and {@code reps} is the
 * number of benchmark rounds (default 10).
 *
 * <p>The arrays are modified in place by every round (each
 * {@code parallelPrefix} call overwrites its input), so result values are
 * only verified against closed-form expectations during round 0.
 */
public class FJSums2 {
    /** Lower bound for {@link #THRESHOLD}. */
    static final int MIN_PARTITION = 64;

    /**
     * Subrange length at or above which {@link Summer} keeps splitting.
     * {@code main} recomputes it from the array size and pool parallelism.
     * It starts at {@link #MIN_PARTITION} rather than 0 because a
     * non-positive value would make the split loop in
     * {@link Summer#compute} fork forever.
     */
    static int THRESHOLD = MIN_PARTITION;

    /**
     * Smallest accepted {@code n}: the boxed array has {@code n >> 4}
     * elements and the boxed tests read its last element, so it must not
     * be empty.
     */
    private static final int MIN_ELEMENTS = 16;

    private static final String USAGE = "Usage: java FJSums2 n reps";

    /**
     * Parses the optional {@code n} and {@code reps} arguments, builds the
     * input arrays and runs every test {@code reps} times.
     *
     * @param args optional element count followed by optional round count
     * @throws Exception declared for compatibility; failures in the checks
     *         surface as {@link Error}
     */
    public static void main(String[] args) throws Exception {
        int n = 1 << 25;
        int reps = 10;
        try {
            if (args.length > 0)
                n = Integer.parseInt(args[0]);
            if (args.length > 1)
                reps = Integer.parseInt(args[1]);
        }
        catch (Exception e) {
            System.out.println(USAGE);
            return;
        }
        // Reject sizes that would leave the boxed array empty (or negative).
        if (n < MIN_ELEMENTS) {
            System.out.println(USAGE);
            System.out.println("n must be at least " + MIN_ELEMENTS);
            return;
        }
        int par = ForkJoinPool.getCommonPoolParallelism();
        System.out.println("Number of procs=" + par);
        // Aim for roughly eight leaf tasks per worker, but never split
        // below MIN_PARTITION elements.
        THRESHOLD = Math.max(MIN_PARTITION, n / (par << 3));

        long[] a = new long[n];
        for (int i = 0; i < n; ++i)
            a[i] = i + 1;
        int bn = n >> 4;
        Long[] boxed = new Long[bn];
        for (int i = 0; i < bn; ++i)
            boxed[i] = Long.valueOf(i + 1);

        // Widen before adding 1 so the increment cannot overflow int.
        long expected = ((long)n * ((long)n + 1)) / 2;
        long bexpected = ((long)bn * ((long)bn + 1)) / 2;
        for (int i = 0; i < reps; ++i) {
            seqTest(a, i, expected);
            parTest(a, i, expected);
            boxTest(boxed, i, bexpected);
            minTest(boxed, i, bexpected);
        }
        System.out.println(ForkJoinPool.commonPool());
    }

    /**
     * Times a sequential sum of {@code a} and prints the result.
     *
     * @param a        array to sum (not modified)
     * @param index    round number; the sum is verified only when 0
     * @param expected sum required in round 0
     * @throws Error if round 0 produces a sum other than {@code expected}
     */
    static void seqTest(long[] a, int index, long expected) {
        System.out.print("Seq ");
        long last = System.nanoTime();
        int n = a.length;
        long ss = seqSum(Long::sum, 0L, a, 0, n);
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new Error("expected " + expected + " != " + ss);
    }

    /**
     * Times a parallel sum via {@link Summer}, then times an in-place
     * {@link Arrays#parallelPrefix(long[], LongBinaryOperator)} and checks
     * that its last element equals the sum just computed.
     *
     * @param a        non-empty array; replaced by its running sums
     * @param index    round number; closed-form checks run only when 0
     * @param expected sum required in round 0
     * @throws Error if any check fails
     */
    static void parTest(long[] a, int index, long expected) {
        System.out.print("Par ");
        long last = System.nanoTime();
        int n = a.length;
        Summer s = new Summer(null, Long::sum, 0L, a, 0, n, null);
        s.invoke();
        long ss = s.result;
        double elapsed = elapsedTime(last);
        System.out.printf("sum = %24d  time:  %7.3f\n", ss, elapsed);
        if (index == 0 && ss != expected)
            throw new Error("expected " + expected + " != " + ss);
        System.out.print("Par ");
        last = System.nanoTime();
        Arrays.parallelPrefix(a, Long::sum);
        long sc = a[n - 1];
        elapsed = elapsedTime(last);
        System.out.printf("cum = %24d  time:  %7.3f\n", sc, elapsed);
        // The final running sum must equal the total computed above.
        if (sc != ss)
            throw new Error("expected " + ss + " != " + sc);
        if (index == 0) {
            // In round 0 the input was 1..n, so element j must be 1+...+(j+1).
            long cs = 0L;
            for (int j = 0; j < n; ++j) {
                if ((cs += (j + 1)) != a[j])
                    throw new Error("wrong element value" + j);
            }
        }
    }

    /**
     * Times an in-place additive {@code parallelPrefix} over boxed values.
     *
     * @param a        non-empty array; replaced by its running sums
     * @param index    round number; results are verified only when 0
     * @param expected last running sum required in round 0
     * @throws Error if a round-0 check fails
     */
    static void boxTest(Long[] a, int index, long expected) {
        System.out.print("Box ");
        int n = a.length;
        long last = System.nanoTime();
        Arrays.parallelPrefix(a, (Long x, Long y) -> Long.valueOf(x.longValue() +
                                                                  y.longValue()));
        long sc = a[n - 1].longValue();
        double elapsed = elapsedTime(last);
        System.out.printf("cum = %24d  time:  %7.3f\n", sc, elapsed);
        if (index == 0) {
            if (sc != expected)
                throw new Error("expected " + expected + " != " + sc);
            long cs = 0L;
            for (int j = 0; j < n; ++j) {
                if ((cs += (j + 1)) != a[j].longValue())
                    throw new Error("wrong element value");
            }
        }
    }

    /**
     * Times an in-place running-minimum {@code parallelPrefix} over boxed
     * values. In round 0 the result must be 1 everywhere, which holds when
     * the array's first element is 1 and is its minimum (as it is after
     * {@link #boxTest} has run on the array built by {@code main}).
     *
     * @param a        non-empty array; replaced by its running minima
     * @param index    round number; results are verified only when 0
     * @param expected unused by the checks; retained so the signature
     *                 matches the other test methods
     * @throws Error if a round-0 check fails
     */
    static void minTest(Long[] a, int index, long expected) {
        System.out.print("Min ");
        int n = a.length;
        long last = System.nanoTime();
        Arrays.parallelPrefix(a, (Long x, Long y) ->
                              x.longValue() <= y.longValue() ? x : y);
        long sc = a[n - 1].longValue();
        double elapsed = elapsedTime(last);
        System.out.printf("cum = %24d  time:  %7.3f\n", sc, elapsed);
        if (index == 0) {
            // Report the value actually being compared against (1), not
            // the box-test sum passed in as `expected`.
            if (sc != 1)
                throw new Error("expected 1 != " + sc);
            long cs = a[0].longValue();
            for (int j = 0; j < n; ++j) {
                if (cs != a[j].longValue())
                    throw new Error("wrong element value");
            }
        }
    }

    /**
     * Returns the seconds elapsed since {@code startTime}.
     *
     * @param startTime an earlier {@link System#nanoTime()} reading
     * @return elapsed time in seconds
     */
    static double elapsedTime(long startTime) {
        return (double)(System.nanoTime() - startTime) / (1000L * 1000 * 1000);
    }

    /**
     * Sequentially folds {@code a[l..h)} with {@code fn}, left to right,
     * starting from {@code basis}.
     *
     * @param fn    combining function
     * @param basis initial accumulator value; returned when the range is empty
     * @param a     source array
     * @param l     inclusive lower index
     * @param h     exclusive upper index
     * @return the folded value
     */
    static long seqSum(LongBinaryOperator fn, long basis,
                       long[] a, int l, int h) {
        long sum = basis;
        for (int i = l; i < h; ++i)
            sum = fn.applyAsLong(sum, a[i]);
        return sum;
    }


    /**
     * Parallel reduction of {@code array[lo..hi)} with {@code function}.
     * Uses CountedCompleter reduction via firstComplete/nextComplete:
     * each task repeatedly forks off the right half of its range, reduces
     * the remaining left part itself, and whichever task finishes last
     * folds the forked subtasks' results into their parent.
     * After {@code invoke()} on the root returns, the answer is in
     * {@link #result}.
     */
    static final class Summer extends CountedCompleter<Void> {
        final long[] array;
        final LongBinaryOperator function;
        final int lo, hi;
        final long basis;
        /** Reduction of this task's range once the task has completed. */
        long result;
        /**
         * {@code forks} heads the list of right-hand-side subtasks forked
         * by this task; {@code next} links this task to the sibling that
         * its parent forked immediately before it.
         */
        Summer forks, next; // keeps track of right-hand-side tasks

        /**
         * @param parent   completer to notify, or {@code null} for the root
         * @param function combining function
         * @param basis    value each leaf fold starts from
         * @param array    source array
         * @param lo       inclusive lower index
         * @param hi       exclusive upper index
         * @param next     previously forked sibling, or {@code null}
         */
        Summer(Summer parent, LongBinaryOperator function, long basis,
               long[] array, int lo, int hi, Summer next) {
            super(parent);
            this.function = function; this.basis = basis;
            this.array = array; this.lo = lo; this.hi = hi;
            this.next = next;
        }

        /**
         * Splits, reduces the leftover leaf range sequentially, then
         * combines results for every task whose completion this call
         * triggers.
         *
         * @throws NullPointerException if the function or array is null
         */
        @Override
        public final void compute() {
            final long id = basis;
            final LongBinaryOperator fn;
            final long[] a;
            if ((fn = this.function) == null || (a = this.array) == null)
                throw new NullPointerException();
            int l = lo,  h = hi;
            // Fork the right half and keep the left half until the range
            // is below THRESHOLD. Each new fork is pushed on the front of
            // `forks`, so the list head is the subrange adjacent to the
            // leaf range kept here and later entries lie further right.
            while (h - l >= THRESHOLD) {
                int mid = (l + h) >>> 1;
                addToPendingCount(1);
                (forks = new Summer(this, fn, id, a, mid, h, forks)).fork();
                h = mid;
            }
            long sum = id;
            // Explicit range check before the loop; out-of-range or empty
            // leaves simply yield the basis.
            if (l < h && l >= 0 && h <= a.length) {
                for (int i = l; i < h; ++i)
                    sum = fn.applyAsLong(sum, a[i]);
            }
            result = sum;
            // Walk up through every task that this call completes. For
            // each one, fold in its forked subtasks in list order, i.e.
            // left to right across the array, unlinking them as we go.
            CountedCompleter<?> c;
            for (c = firstComplete(); c != null; c = c.nextComplete()) {
                Summer t = (Summer)c, s = t.forks;
                while (s != null) {
                    t.result = fn.applyAsLong(t.result, s.result);
                    s = t.forks = s.next;
                }
            }
        }
    }

}
