/*
 * Written by Doug Lea and released to the public domain, as explained at
 * http://creativecommons.org/licenses/publicdomain
 */

import java.util.concurrent.atomic.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.*;
import java.util.*;

public class RNGLoopsV3 {

    // Number of timed passes that main() runs after the warm-up phase.
    static final int TRIALS = 2;
    // Calibration length for each timed run, in seconds; main() multiplies
    // it by 1000 before handing it to runCalibration as milliseconds.
    static final long BASE_SECS_PER_RUN = 3;
    // Processor count reported by the runtime; report() uses it to scale
    // the per-iteration figure when more threads than processors were used.
    static final int NCPUS = Runtime.getRuntime().availableProcessors();
    // Largest thread count to test; main() overrides it from args[0].
    static int maxThreads = 1;

    // Indexed by thread count (1..maxThreads, slot 0 unused by the loops in
    // main): the largest per-thread iteration count that runCalibration has
    // observed for that thread count. Allocated in main().
    static long[] loopIters;

    /**
     * Benchmark entry point. Performs two warm-up passes over every thread
     * count from 1 to {@code maxThreads}, clears the calibration results
     * gathered while warming up, and then runs {@code TRIALS} reported
     * passes, calibrating each thread count before timing it.
     *
     * @param args optional; when present, {@code args[0]} is parsed as the
     *             maximum number of threads to test
     * @throws Exception if calibration or a timed run fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length > 0) {
            maxThreads = Integer.parseInt(args[0]);
        }

        loopIters = new long[maxThreads + 1];

        System.out.println("Warmup...");
        for (int pass = 0; pass < 2; ++pass) {
            for (int nthreads = 1; nthreads <= maxThreads; ++nthreads) {
                runCalibration(nthreads, 1000);
                // Warm-up runs use an eighth of the calibrated count and print nothing.
                oneRun(nthreads, loopIters[nthreads] / 8, false);
                System.out.print(".");
            }
        }

        // Forget warm-up calibrations so the timed trials start from scratch.
        for (int nthreads = 1; nthreads <= maxThreads; ++nthreads) {
            loopIters[nthreads] = 0;
        }

        for (int trial = 0; trial < TRIALS; ++trial) {
            System.out.println("Trial " + trial);
            for (int nthreads = 1; nthreads <= maxThreads; ++nthreads) {
                runCalibration(nthreads, BASE_SECS_PER_RUN * 1000L);
                oneRun(nthreads, loopIters[nthreads], true);
            }
        }
    }

    // Iterations completed by all calibration threads of the current
    // calibration; runCalibration resets it to zero before starting threads.
    static final AtomicLong totalIters = new AtomicLong(0);
    // Running total of the generated values. Each worker adds its local
    // result here, and runCalibration/runR read it afterwards (presumably so
    // the generator calls cannot be optimized away -- confirm).
    static final AtomicInteger sum = new AtomicInteger(0);


    // All these versions are copy-paste-hacked to avoid
    // contamination with virtual call resolution etc.

    // Use fixed-length unrollable inner loops to reduce safepoint checks.
    // This is the trip count of the innermost loop in NACalibrationLoop.
    static final int innerPerOuter = 16;

    /**
     * Timed worker. Waits on the barrier, draws {@code iters} ints from the
     * supplied generator, adds the accumulated value to {@code sum}, and
     * waits on the barrier again to signal completion.
     */
    static final class RLoop implements Runnable {
        final long iters;
        final Random rng;
        final CyclicBarrier barrier;

        /**
         * @param iters number of values to draw; non-positive means none
         * @param rng   generator under test (may be shared between workers)
         * @param b     barrier used to line up the start and the end of the run
         */
        RLoop(long iters, Random rng, CyclicBarrier b) {
            this.iters = iters;
            this.rng = rng;
            this.barrier = b;
        }

        public void run() {
            try {
                barrier.await();
                int y = 0;
                // Count with a long: the calibrated iteration count is a long
                // and may exceed Integer.MAX_VALUE, in which case narrowing it
                // to int would run the wrong number of iterations (possibly
                // none) while the caller still divides the elapsed time by
                // the full count.
                for (long remaining = iters; remaining > 0; --remaining) {
                    y += rng.nextInt();
                }
                sum.getAndAdd(y);
                barrier.await();
            }
            catch (Exception ignored) {
                // Best effort, as before: an interrupted or broken barrier
                // simply ends this worker.
                return;
            }
        }
    }

    // Number of generator calls a calibration thread makes between two reads
    // of the clock; a whole multiple of innerPerOuter (2048 = 128 * 16).
    static final int loopsPerTimeCheck = 2048;

    /**
     * Calibration worker. After the start barrier it keeps drawing ints in
     * batches of {@code loopsPerTimeCheck} until the wall clock reaches
     * {@code endTime}, then publishes how many draws it made to
     * {@code totalIters}, adds its accumulated value to {@code sum}, and
     * arrives at the barrier a second time.
     */
    static final class NACalibrationLoop implements Runnable {
        final long endTime;
        final Random rng;
        final CyclicBarrier barrier;

        /**
         * @param endTime deadline, compared against {@code System.currentTimeMillis()}
         * @param rng     generator to draw from
         * @param b       barrier used to line up the start and the end of the run
         */
        NACalibrationLoop(long endTime, Random rng, CyclicBarrier b) {
            this.endTime = endTime;
            this.rng = rng;
            this.barrier = b;
        }

        public void run() {
            try {
                barrier.await();
                long completed = 0;
                int acc = 0;
                do {
                    // One batch: loopsPerTimeCheck draws, innerPerOuter at a time.
                    for (int left = loopsPerTimeCheck; left > 0; left -= innerPerOuter) {
                        for (int k = 0; k < innerPerOuter; ++k) {
                            acc += rng.nextInt();
                        }
                    }
                    completed += loopsPerTimeCheck;
                } while (System.currentTimeMillis() < endTime);
                totalIters.getAndAdd(completed);
                sum.getAndAdd(acc);
                barrier.await();
            }
            catch (Exception ignored) {
                // Any failure just ends this worker.
                return;
            }
        }
    }

    /**
     * Runs {@code n} calibration threads for roughly {@code nms}
     * milliseconds against one shared {@code UnsynchedRandom} and records
     * the average number of iterations per thread in {@code loopIters[n]}
     * when it exceeds the value already stored there.
     *
     * @param n   number of calibration threads
     * @param nms calibration duration in milliseconds
     * @throws Exception if waiting on the barrier fails
     */
    static void runCalibration(int n, long nms) throws Exception {
        final long deadline = System.currentTimeMillis() + nms;
        final CyclicBarrier gate = new CyclicBarrier(n + 1);
        totalIters.set(0);
        final Random shared = new UnsynchedRandom();
        for (int started = 0; started < n; ++started) {
            Thread worker = new Thread(new NACalibrationLoop(deadline, shared, gate));
            worker.start();
        }
        gate.await();   // release the workers
        gate.await();   // wait until every worker has finished
        final long perThread = totalIters.get() / n;
        loopIters[n] = Math.max(loopIters[n], perThread);
        if (sum.get() == 0) {
            System.out.print(" ");
        }
    }

    /**
     * Starts {@code n} {@code RLoop} workers that all draw from the same
     * generator, lets them run between two barrier trips, and returns the
     * time reported by the {@code LoopHelpers.BarrierTimer} installed as the
     * barrier action (presumably the span between the two trips, in
     * nanoseconds -- confirm against LoopHelpers).
     *
     * @param n     number of worker threads
     * @param iters number of draws each worker performs
     * @param r     generator shared by all workers
     * @return the value of {@code BarrierTimer.getTime()} after the run
     * @throws Exception if waiting on the barrier fails
     */
    static long runR(int n, long iters, Random r) throws Exception {
        final LoopHelpers.BarrierTimer stopwatch = new LoopHelpers.BarrierTimer();
        final CyclicBarrier gate = new CyclicBarrier(n + 1, stopwatch);
        int launched = 0;
        while (launched < n) {
            new Thread(new RLoop(iters, r, gate)).start();
            ++launched;
        }
        gate.await();   // start of the timed section
        gate.await();   // end of the timed section
        if (sum.get() == 0) {
            System.out.print(" ");
        }
        return stopwatch.getTime();
    }

    /**
     * Prints one result line: the tag, the time per iteration (scaled by
     * {@code NCPUS / nthreads} when more threads than processors were used),
     * and the total run time in seconds.
     *
     * @param tag      label printed at the start of the line
     * @param runtime  elapsed time of the run; divided by 1e9 to print seconds
     * @param basetime unused; kept so existing callers keep compiling
     * @param nthreads number of threads that took part in the run
     * @param iters    iterations performed by each thread; when zero the
     *                 per-iteration figure is printed as 0
     */
    static void report(String tag, long runtime, long basetime,
                       int nthreads, long iters) {
        System.out.print(tag);
        // Avoid an ArithmeticException when no iterations were requested.
        long t = (iters != 0) ? runtime / iters : 0;
        if (nthreads > NCPUS)
            t = t * NCPUS / nthreads;
        System.out.print(LoopHelpers.rightJustify(t));
        double secs = (double)(runtime) / 1000000000.0;
        System.out.println("\t " + secs + "s run time");
    }
        

    /**
     * Times every generator variant once with {@code i} threads and
     * {@code iters} draws per thread, pausing 100 ms after each variant.
     * Variants run in the order X4, X1, X8, then the unsynchronized
     * generator; the {@code java.util.Random} variant is currently disabled.
     *
     * @param i     number of threads
     * @param iters draws per thread
     * @param print whether to print the header and the result lines
     * @throws Exception if a timed run or the pause fails
     */
    static void oneRun(int i, long iters, boolean print) throws Exception {
        if (print) {
            String header = "threads : " + i +
                " base iters per thread per run : " +
                LoopHelpers.rightJustify(loopIters[i]);
            System.out.println(header);
        }

        long ntime = runR(i, iters, XorShiftRandomFactory.xorShift4());
        reportIf(print, "X4          : ", ntime, ntime, i, iters);
        Thread.sleep(100L);

        ntime = runR(i, iters, XorShiftRandomFactory.xorShift1());
        reportIf(print, "X1          : ", ntime, ntime, i, iters);
        Thread.sleep(100L);

        ntime = runR(i, iters, XorShiftRandomFactory.xorShift8());
        reportIf(print, "X8          : ", ntime, ntime, i, iters);
        Thread.sleep(100L);

        long stime = runR(i, iters, new UnsynchedRandom());
        reportIf(print, "No Synch    : ", stime, ntime, i, iters);
        Thread.sleep(100L);

        // Flip to true to also time java.util.Random.
        final boolean timeJdkRandom = false;
        if (timeJdkRandom) {
            long vtime = runR(i, iters, new Random());
            reportIf(print, "JURandom    : ", vtime, ntime, i, iters);
        }
    }

    /** Forwards to {@code report} only when {@code print} is set. */
    private static void reportIf(boolean print, String tag, long runtime,
                                 long basetime, int nthreads, long iters) {
        if (print) {
            report(tag, runtime, basetime, nthreads, iters);
        }
    }

    /**
     * A {@code java.util.Random} whose {@code next} step updates a plain
     * {@code long} seed with no atomics or locking, using the multiplier,
     * addend and 48-bit mask declared below. Instances are not safe for
     * concurrent use; the benchmark shares them across threads on purpose.
     */
    public static class UnsynchedRandom extends java.util.Random {
        private final static long multiplier = 0x5DEECE66DL;
        private final static long addend = 0xBL;
        private final static long mask = (1L << 48) - 1;

        // Deliberately has no initializer. Instance initializers run after
        // the superclass constructor returns, so an initializer here would
        // overwrite whatever seed setSeed() stored while Random's
        // constructor was running.
        private long seed;

        /** Creates a generator seeded from {@code System.nanoTime()}. */
        UnsynchedRandom() {
            super();
            this.seed = System.nanoTime();
        }

        /**
         * Creates a generator whose sequence is fully determined by
         * {@code seed}.
         *
         * @param seed initial state of the generator
         */
        UnsynchedRandom(long seed) {
            super(seed);
            this.seed = seed;
        }

        /**
         * Resets both the superclass state and this generator's own state.
         *
         * @param seed new state of the generator
         */
        @Override
        public void setSeed(long seed) {
            super.setSeed(seed);
            this.seed = seed;
        }

        /**
         * Advances the seed by one linear-congruential step and returns its
         * top {@code bits} bits (of 48).
         */
        @Override
        protected final int next(int bits) {
            long nextseed = (seed * multiplier + addend) & mask;
            seed = nextseed;
            return (int)(nextseed >>> (48 - bits));
        }
    }


}
