
import java.net.URL;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

/**
 * Micro-benchmark that runs {@link #CHAINS} independent "chains" of fork/join tasks on a
 * {@link ForkJoinPool}. Each task decrements its chain's counter and forks exactly one successor,
 * so every chain is a strictly sequential series of tiny tasks.
 *
 * <p>{@link #main} prints the configuration and the location of the {@code ForkJoinPool} class in
 * use, times five chain runs, and then runs a starvation probe. The probe measures how long an
 * unrelated task, submitted while the chains are running, takes to get executed.
 *
 * <p>All shared state lives in static fields, so the methods here are not meant to be run
 * concurrently with one another.
 */
public class FJChains2 {
    /** Number of independent task chains started per run. */
    public static final int CHAINS = 2;
    /** Starting counter value of each chain; a chain runs {@code LENGTH + 1} tasks. */
    public static final long LENGTH = 20_000_000;
    /**
     * Pool parallelism. It is hard-coded to 2; {@code Runtime.getRuntime().availableProcessors()}
     * is the machine-dependent alternative.
     */
    public static final int PARALLELISM = 2;

    /**
     * Spacing, in array slots, between consecutive chains' counters in {@link #array}. This is
     * presumably padding so that counters of different chains do not share a cache line
     * (20 longs = 160 bytes). NOTE(review): confirm before changing.
     */
    private static final int W = 20;
    /** Per-chain counters: chain {@code i} uses slot {@code i * W}; all other slots stay unused. */
    static long[] array = new long[CHAINS * W];
    /**
     * Counted down once by the final task of each chain. It is reassigned before every run, before
     * that run's head tasks are submitted.
     */
    static CountDownLatch latch;

    /**
     * The pool under test. The last constructor argument enables async mode, which gives FIFO
     * scheduling of forked tasks that are never joined. {@link Action} uses {@code fork()} this
     * way.
     */
    static ForkJoinPool fjp =
            new ForkJoinPool(PARALLELISM, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, true);

    /**
     * Prints the benchmark configuration and where {@code ForkJoinPool} was loaded from. It then
     * times five chain runs and finally runs the starvation probe.
     *
     * @param args ignored
     * @throws Exception if the main thread is interrupted while sleeping or waiting for a run
     */
    public static void main(String[] args) throws Exception {
        System.out.println("CHAINS: " + CHAINS);
        System.out.println("LENGTH: " + LENGTH);
        System.out.println("PARALLELISM: " + fjp.getParallelism());
        System.out.println("ForkJoinPool: " + whereIs(ForkJoinPool.class));

        // Repeated timed runs; presumably the later ones reflect JIT-compiled code.
        for (int i = 0; i < 5; i++) {
            run();
        }

        System.out.println("========");
        starvation();
    }

    /**
     * Starts all chains with {@link #LENGTH}. Prints the elapsed wall-clock time, in milliseconds,
     * once every chain has finished.
     *
     * @throws Exception if interrupted while waiting for the chains
     */
    static void run() throws Exception {
        latch = new CountDownLatch(CHAINS);
        long start = System.nanoTime();

        startChains(LENGTH);

        latch.await();
        System.out.println((System.nanoTime() - start) / 1_000_000);
    }

    /**
     * Starts all chains, sleeps 10 ms, then submits an unrelated ("innocent") task to the same
     * pool. It prints two figures, both measured from the chains' start:
     * <ul>
     *   <li>milliseconds until the innocent task has run (this includes the 10 ms sleep);
     *   <li>milliseconds until all chains have finished.
     * </ul>
     *
     * @throws Exception if interrupted while sleeping or waiting
     */
    static void starvation() throws Exception {
        latch = new CountDownLatch(CHAINS);
        final CountDownLatch innocentDone = new CountDownLatch(1);
        long start = System.nanoTime();

        startChains(LENGTH);

        Thread.sleep(10);
        fjp.submit(new RecursiveAction() {
            @Override
            protected void compute() {
                innocentDone.countDown();
            }
        });

        innocentDone.await();
        System.out.println("Innocent: " + (System.nanoTime() - start) / 1_000_000);
        latch.await();
        System.out.println((System.nanoTime() - start) / 1_000_000);
    }

    /**
     * Resets every chain's counter to {@code length} and submits the head task of each chain.
     * The caller must install a fresh {@link #latch} with count {@link #CHAINS} first, because the
     * final task of every chain counts it down.
     *
     * @param length starting counter value; for a non-negative value each chain runs
     *     {@code length + 1} tasks
     */
    static void startChains(long length) {
        for (int chain = 0; chain < CHAINS; chain++) {
            array[chain * W] = length;
            fjp.submit(new Action(chain));
        }
    }

    /**
     * One link of a chain. It decrements its chain's counter and, if the counter was still
     * positive, forks a successor for the same chain. The task that observes a counter value of
     * zero or less counts down {@link #latch} instead, ending the chain.
     */
    static class Action extends RecursiveAction {
        final int chain;

        public Action(int chain) {
            this.chain = chain;
        }

        @Override
        protected void compute() {
            // Post-decrement: the comparison uses the value *before* this task's decrement.
            // Each task forks its successor only after decrementing, so a chain's tasks
            // never run at the same time as one another.
            if (array[chain * W]-- <= 0) {
                latch.countDown();
            } else {
                new Action(chain).fork();
            }
        }
    }

    /**
     * Returns the URL, as a string, of the class-file resource for {@code clazz}. {@link #main}
     * uses it to show which {@code ForkJoinPool} implementation is being measured.
     *
     * @param clazz class to locate; may be {@code null}
     * @return the URL string, or {@code null} if {@code clazz} is {@code null} or no class-file
     *     resource could be found
     */
    public static String whereIs(Class<?> clazz) {
        if (clazz == null) {
            return null;
        }
        final String resource = clazz.getName().replace('.', '/') + ".class";
        // Class.getResource resolves names without a leading '/' relative to the class's own
        // package, so the fully qualified path must be made absolute.
        URL url = clazz.getResource("/" + resource);
        if (url == null) {
            ClassLoader loader = clazz.getClassLoader();
            if (loader == null) {
                // Bootstrap classes report a null loader; fall back to the system loader.
                loader = ClassLoader.getSystemClassLoader();
            }
            url = loader.getResource(resource);
        }
        return url != null ? url.toString() : null;
    }
}
