/*
 * Written by Doug Lea with assistance from members of JCP JSR-166
 * Expert Group and released to the public domain, as explained at
 * http://creativecommons.org/publicdomain/zero/1.0/
 */

//import jsr166y.*;
import java.util.concurrent.*;

/**
 * Divide and Conquer matrix multiply demo.
 *
 * <p>Multiplies two square {@code float} matrices by recursively splitting
 * them into quadrants and running the quadrant products as fork/join tasks.
 */
public class MatrixMultiply {

    /** for time conversion */
    static final long NPS = (1000L * 1000 * 1000);

    static final int DEFAULT_GRANULARITY = 16; // 32;

    /**
     * The quadrant size at which to stop recursing down
     * and instead directly multiply the matrices.
     * Must be a power of two. Minimum value is 2.
     */
    static int granularity = DEFAULT_GRANULARITY;

    /**
     * Runs the timing demo.
     *
     * <p>Arguments (all optional, positional): number of threads
     * ({@code 0} selects the common pool), matrix size, granularity and
     * number of runs. Invalid arguments print the usage message and return.
     *
     * @param args command line arguments as described above
     * @throws Exception if the timing loop is interrupted or a task fails
     */
    public static void main(String[] args) throws Exception {
        final String usage = "Usage: java MatrixMultiply <threads> <matrix size> <granularity> [<runs>]\n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

        int procs = 0;
        int n = 2048;
        int grain = granularity;
        int runs = 32;
        try {
            if (args.length > 0) {
                procs = Integer.parseInt(args[0]);
            }
            if (args.length > 1) {
                n = Integer.parseInt(args[1]);
            }
            if (args.length > 2) {
                grain = Integer.parseInt(args[2]);
            }
            if (args.length > 3) {
                runs = Integer.parseInt(args[3]);
            }
        } catch (NumberFormatException e) {
            System.out.println(usage);
            return;
        }

        // multiplyStride2 walks rows and columns in pairs, so both sizes must
        // be at least 2; the (x & (x - 1)) test alone would let 0, 1 and
        // Integer.MIN_VALUE through.
        if (procs < 0
            || n < 2 || (n & (n - 1)) != 0
            || grain < 2 || (grain & (grain - 1)) != 0) {
            System.out.println(usage);
            return;
        }
        // Publish the granularity only once it is known to be valid.
        granularity = grain;

        ForkJoinPool pool = (procs == 0)
            ? ForkJoinPool.commonPool()
            : new ForkJoinPool(procs);
        try {
            System.out.println("procs: " + pool.getParallelism() +
                               " n: " + n + " granularity: " + granularity +
                               " runs: " + runs);

            float[][] a = new float[n][n];
            float[][] b = new float[n][n];
            float[][] c = new float[n][n];

            for (int i = 0; i < runs; ++i) {
                init(a, b, n);
                // Multiplier adds into c, so start every run from zero
                // (done outside the timed region).
                clear(c);
                long start = System.nanoTime();
                // Submit to the selected pool; a bare invoke() from this
                // thread would fork subtasks into the common pool instead.
                pool.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
                long time = System.nanoTime() - start;
                double secs = ((double) time) / NPS;
                Thread.sleep(100);
                System.out.printf("Time: %7.3f ", secs);
                if ((i & 3) == 3) {
                    System.out.println();
                }
                // check(c, n) can be called here to validate the result.
            }
            System.out.println(pool.toString());
        } finally {
            if (pool != ForkJoinPool.commonPool()) {
                pool.shutdown();
            }
        }
        Thread.sleep(100);
    }

    /**
     * Fills the top-left {@code n x n} part of both inputs with ones.
     * To simplify checking, fill with all 1's. Answer should be all n's.
     *
     * @param a first input matrix
     * @param b second input matrix
     * @param n number of rows and columns to fill
     */
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    /** Resets every element of {@code m} to zero. */
    private static void clear(float[][] m) {
        for (float[] row : m) {
            java.util.Arrays.fill(row, 0.0F);
        }
    }

    /**
     * Verifies that every element of the top-left {@code n x n} part of
     * {@code c} equals {@code n}.
     *
     * @param c result matrix to verify
     * @param n expected element value and number of rows/columns checked
     * @throws AssertionError at the first element that differs from {@code n}
     */
    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    throw new AssertionError("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                }
            }
        }
    }

    /**
     * Multiply matrices AxB by dividing into quadrants, using algorithm:
     * <pre>
     *      A      x      B
     *
     *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22
     * |----+----| x |----+----| = |--------+--------| + |---------+-------|
     *  A21 | A22     B21 | B22     A21*B11 | A21*B12     A22*B21 | A22*B22
     * </pre>
     */
    static class Multiplier extends RecursiveAction {
        final float[][] A;   // Matrix A
        final int aRow;      // first row    of current quadrant of A
        final int aCol;      // first column of current quadrant of A

        final float[][] B;   // Similarly for B
        final int bRow;
        final int bCol;

        final float[][] C;   // Similarly for result matrix C
        final int cRow;
        final int cCol;

        final int size;      // number of rows (and columns) in current quadrant

        Multiplier(float[][] A, int aRow, int aCol,
                   float[][] B, int bRow, int bCol,
                   float[][] C, int cRow, int cCol,
                   int size) {
            this.A = A; this.aRow = aRow; this.aCol = aCol;
            this.B = B; this.bRow = bRow; this.bCol = bCol;
            this.C = C; this.cRow = cRow; this.cCol = cCol;
            this.size = size;
        }

        /**
         * Multiplies directly once the quadrant is no larger than
         * {@link #granularity}; otherwise splits into eight half-size
         * products. The two products that add into the same quadrant of C
         * are paired in a {@link Seq2} so they run one after the other; the
         * four pairs target different quadrants of C and are invoked
         * together.
         */
        @Override
        public void compute() {
            if (size <= granularity) {
                multiplyStride2();
                return;
            }
            int h = size / 2;
            invokeAll(new Seq2[] {
                seq(new Multiplier(A, aRow,   aCol,    // A11
                                   B, bRow,   bCol,    // B11
                                   C, cRow,   cCol,    // C11
                                   h),
                    new Multiplier(A, aRow,   aCol+h,  // A12
                                   B, bRow+h, bCol,    // B21
                                   C, cRow,   cCol,    // C11
                                   h)),

                seq(new Multiplier(A, aRow,   aCol,    // A11
                                   B, bRow,   bCol+h,  // B12
                                   C, cRow,   cCol+h,  // C12
                                   h),
                    new Multiplier(A, aRow,   aCol+h,  // A12
                                   B, bRow+h, bCol+h,  // B22
                                   C, cRow,   cCol+h,  // C12
                                   h)),

                seq(new Multiplier(A, aRow+h, aCol,    // A21
                                   B, bRow,   bCol,    // B11
                                   C, cRow+h, cCol,    // C21
                                   h),
                    new Multiplier(A, aRow+h, aCol+h,  // A22
                                   B, bRow+h, bCol,    // B21
                                   C, cRow+h, cCol,    // C21
                                   h)),

                seq(new Multiplier(A, aRow+h, aCol,    // A21
                                   B, bRow,   bCol+h,  // B12
                                   C, cRow+h, cCol+h,  // C22
                                   h),
                    new Multiplier(A, aRow+h, aCol+h,  // A22
                                   B, bRow+h, bCol+h,  // B22
                                   C, cRow+h, cCol+h,  // C22
                                   h))
            });
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns
         * at a time. Adapted from Cilk demos.
         * Note that the results are added into C, not just set into C,
         * so C must hold zeros (or a partial sum) on entry.
         * {@code size} must be even.
         */
        void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow+i];
                    float[] a1 = A[aRow+i+1];

                    // Accumulators for the 2x2 tile of C at (i, j).
                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow+k];

                        s00 += a0[aCol+k]   * b0[bCol+j];
                        s10 += a1[aCol+k]   * b0[bCol+j];
                        s01 += a0[aCol+k]   * b0[bCol+j+1];
                        s11 += a1[aCol+k]   * b0[bCol+j+1];

                        float[] b1 = B[bRow+k+1];

                        s00 += a0[aCol+k+1] * b1[bCol+j];
                        s10 += a1[aCol+k+1] * b1[bCol+j];
                        s01 += a0[aCol+k+1] * b1[bCol+j+1];
                        s11 += a1[aCol+k+1] * b1[bCol+j+1];
                    }

                    C[cRow+i]  [cCol+j]   += s00;
                    C[cRow+i]  [cCol+j+1] += s01;
                    C[cRow+i+1][cCol+j]   += s10;
                    C[cRow+i+1][cCol+j+1] += s11;
                }
            }
        }
    }

    /**
     * Pairs two tasks so that they run one after the other.
     *
     * @param task1 task to run first
     * @param task2 task to run after {@code task1} completes
     * @return a task that invokes {@code task1} and then {@code task2}
     */
    static Seq2 seq(RecursiveAction task1,
                    RecursiveAction task2) {
        return new Seq2(task1, task2);
    }

    /** Task that invokes two other tasks sequentially. */
    static final class Seq2 extends RecursiveAction {
        final RecursiveAction fst;
        final RecursiveAction snd;

        public Seq2(RecursiveAction task1, RecursiveAction task2) {
            fst = task1;
            snd = task2;
        }

        @Override
        public void compute() {
            fst.invoke();
            snd.invoke();
        }
    }
}