/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

import java.util.*;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.Phaser;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.SynchronousQueue;

/**
 * Producer/consumer throughput loops over {@link SynchronousQueue}, with all
 * producers and consumers executed as tasks of a single {@link ForkJoinPool}.
 *
 * <p>Each run starts {@code n} producers and {@code n} consumers that share one
 * queue. Every producer puts {@code iters} elements and every consumer takes
 * {@code iters} elements. The sums of the transferred values are compared
 * afterwards as a sanity check, and the average time per transfer is printed.
 *
 * <p>Usage: {@code java FJSQPCL [maxPairs]}; {@code maxPairs} defaults to
 * {@code NCPUS * 3 / 2}.
 */
public class FJSQPCL {
    /** Number of processors reported by the runtime; used for the default pair count. */
    static final int NCPUS = Runtime.getRuntime().availableProcessors();

    /** Pool (default configuration) that executes every producer and consumer task. */
    static final ForkJoinPool pool = new ForkJoinPool();

    /** When false (during warmup), per-run result lines are suppressed. */
    static boolean print = false;

    // Running totals over ALL runs; they are never reset. Both are only
    // touched through the static synchronized methods below, i.e. under the
    // class lock. Wrap-around on int overflow affects both sides identically,
    // so the equality check remains meaningful.
    static int producerSum;
    static int consumerSum;

    /** Adds one producer's local sum to the global producer total. */
    static synchronized void addProducerSum(int x) {
        producerSum += x;
    }

    /** Adds one consumer's local sum to the global consumer total. */
    static synchronized void addConsumerSum(int x) {
        consumerSum += x;
    }

    /**
     * Verifies that everything that was produced so far was also consumed.
     *
     * @throws Error if the producer and consumer totals differ
     */
    static synchronized void checkSum() {
        if (producerSum != consumerSum) {
            throw new Error("CheckSum mismatch");
        }
    }

    // Number of elements passed around -- must be power of two
    // (POOL_MASK is used to reduce a random int to a valid index).
    // Elements are reused from pool to minimize alloc impact
    static final int POOL_SIZE = 1 << 7;
    static final int POOL_MASK = POOL_SIZE - 1;
    static final Integer[] intPool = new Integer[POOL_SIZE];
    static {
        for (int i = 0; i < POOL_SIZE; ++i) {
            intPool[i] = Integer.valueOf(i);
        }
    }

    // Number of puts by producers or takes by consumers
    static final int ITERS = 1 << 20;

    /**
     * Runs the warmup, then two repetitions of the timed tests over an
     * increasing number of producer/consumer pairs.
     *
     * @param args optional; {@code args[0]} overrides the maximum number of pairs
     * @throws Exception if a run fails or {@code args[0]} is not an integer
     */
    public static void main(String[] args) throws Exception {
        int maxPairs = NCPUS * 3 / 2;

        if (args.length > 0) {
            maxPairs = Integer.parseInt(args[0]);
        }

        warmup();
        print = true;
        for (int reps = 0; reps < 2; ++reps) {
            // Pair counts alternate between a power of two (k) and 1.5 times
            // that power: 1, 1, 2, 3, 4, 6, 8, 12, ...
            // NOTE: the 1-pair case runs twice because 1 + (1 >>> 1) == 1.
            for (int k = 1, i = 1; i <= maxPairs;) {
                System.out.println("Pairs:" + i);
                oneTest(i, ITERS);
                if (i == k) {
                    k = i << 1;
                    i = i + (i >>> 1);
                } else {
                    i = k;
                }
            }
            System.out.println(pool);
        }
        pool.shutdown();
    }

    /**
     * Executes two series of short, unprinted runs (5 down to 1 pairs each,
     * with growing iteration counts) before the timed measurements start.
     * Leaves {@link #print} set to false.
     *
     * @throws Exception if a run fails
     */
    static void warmup() throws Exception {
        print = false;
        System.out.print("Warmup ");
        int it = 2000;
        for (int j = 5; j > 0; --j) {
            oneTest(j, it);
            System.out.print(".");
            it += 1000;
        }
        System.gc();
        it = 20000;
        for (int j = 5; j > 0; --j) {
            oneTest(j, it);
            System.out.print(".");
            it += 10000;
        }
        System.gc();
        System.out.println();
    }

    /**
     * Runs one measurement with a default (non-fair) SynchronousQueue and
     * one with a fair SynchronousQueue.
     *
     * @param n number of producer/consumer pairs
     * @param iters number of puts per producer and takes per consumer
     * @throws Exception if a run fails
     */
    static void oneTest(int n, int iters) throws Exception {
        if (print) {
            System.out.print("SynchronousQueue        ");
        }
        oneRun(new SynchronousQueue<Integer>(), n, iters);

        if (print) {
            System.out.print("SynchronousQueue(fair)  ");
        }
        oneRun(new SynchronousQueue<Integer>(true), n, iters);
    }

    /** Common state of producers and consumers: the shared queue and the iteration count. */
    abstract static class Stage implements Runnable {
        final int iters;
        final BlockingQueue<Integer> queue;

        Stage(BlockingQueue<Integer> q, int iters) {
            queue = q;
            this.iters = iters;
        }
    }

    /** Puts {@code iters} pooled Integers into the queue and records their sum. */
    static class Producer extends Stage {
        Producer(BlockingQueue<Integer> q, int iters) {
            super(q, iters);
        }

        @Override
        public void run() {
            try {
                int ps = 0;
                // Seed differs per producer instance.
                int r = hashCode();
                for (int i = 0; i < iters; ++i) {
                    // LoopHelpers.compute7 is defined elsewhere; presumably a
                    // cheap pseudo-random step -- confirm against LoopHelpers.
                    r = LoopHelpers.compute7(r);
                    Integer v = intPool[r & POOL_MASK];
                    int k = v.intValue();
                    queue.put(v);
                    ps += k;
                }
                addProducerSum(ps);
            } catch (Exception ie) {
                // The local sum is NOT recorded on failure, so a later
                // checkSum() will most likely report the mismatch.
                ie.printStackTrace();
            }
        }
    }

    /** Takes {@code iters} Integers from the queue and records their sum. */
    static class Consumer extends Stage {
        Consumer(BlockingQueue<Integer> q, int iters) {
            super(q, iters);
        }

        @Override
        public void run() {
            try {
                int cs = 0;
                for (int i = 0; i < iters; ++i) {
                    Integer v = queue.take();
                    int k = v.intValue();
                    cs += k;
                }
                addConsumerSum(cs);
            } catch (Exception ie) {
                // The local sum is NOT recorded on failure, so a later
                // checkSum() will most likely report the mismatch.
                ie.printStackTrace();
            }
        }

    }

    /** Action that runs all given tasks via {@code invokeAll} and waits for them. */
    static final class RunTasks extends RecursiveAction {
        ForkJoinTask<?>[] tasks;

        RunTasks(ForkJoinTask<?>[] ts) {
            tasks = ts;
        }

        @Override
        public void compute() {
            invokeAll(tasks);
        }
    }

    /**
     * Runs {@code n} producers and {@code n} consumers, all sharing {@code q},
     * in the pool; verifies the checksums; and, if {@link #print} is set,
     * prints the elapsed nanoseconds divided by {@code iters * n}.
     *
     * @param q the queue shared by all producers and consumers of this run
     * @param n number of producer/consumer pairs
     * @param iters number of puts per producer and takes per consumer
     * @throws Exception if the run fails
     * @throws Error if the checksums do not match afterwards
     */
    static void oneRun(BlockingQueue<Integer> q, int n, int iters) throws Exception {
        ForkJoinTask<?>[] tasks = new ForkJoinTask<?>[n * 2];
        int j = 0;
        for (int i = 0; i < n; ++i) {
            // Producers and consumers are interleaved in the task array.
            tasks[j++] = ForkJoinTask.adapt(new Producer(q, iters));
            tasks[j++] = ForkJoinTask.adapt(new Consumer(q, iters));
        }
        long startTime = System.nanoTime();
        pool.invoke(new RunTasks(tasks));
        long time = System.nanoTime() - startTime;
        checkSum();
        if (print) {
            // Multiply in long arithmetic: iters * n as int wraps around
            // (to a negative value or even 0) for large pair counts.
            long transfers = (long) iters * n;
            System.out.println("\t: " + LoopHelpers.rightJustify(time / transfers) + " ns per transfer");
        }
    }

}
