ViewVC Help
View File | Revision Log | Show Annotations | Download File | Root Listing
root/jsr166/jsr166/src/test/loops/FJSums.java
Revision: 1.7
Committed: Wed Jul 4 20:07:01 2012 UTC (11 years, 10 months ago) by jsr166
Branch: MAIN
Changes since 1.6: +0 -2 lines
Log Message:
trailing newlines

File Contents

# Content
1 /*
2 * Written by Doug Lea with assistance from members of JCP JSR-166
3 * Expert Group and released to the public domain, as explained at
4 * http://creativecommons.org/publicdomain/zero/1.0/
5 */
6
7 import java.util.*;
8 import java.util.concurrent.*;
9 import java.util.concurrent.atomic.*;
10
11 // parallel sums and cumulations
12
/**
 * Timing harness that compares a sequential sum of a {@code long[]} with
 * a fork/join parallel sum ({@link Summer}) and a fork/join in-place
 * cumulative (prefix-sum) scan ({@link Cumulater}).
 *
 * <p>Usage: {@code java FJSums threads n reps}, where {@code threads == 0}
 * (the default) selects the default {@link ForkJoinPool} parallelism.
 */
public class FJSums {
    /** Nanoseconds per second; divides nanoTime deltas to yield seconds. */
    static final long NPS = (1000L * 1000 * 1000);

    /**
     * Segment length above which a {@link Cumulater} node splits in two.
     * Assigned by {@link #main} before any Cumulater is executed.
     */
    static int THRESHOLD;

    /**
     * Runs the benchmark.
     *
     * @param args optional {@code threads}, {@code n} (array length,
     *        must be at least 1) and {@code reps}; malformed or
     *        out-of-range values print the usage line and return
     * @throws Error if a computed sum differs from the expected value
     */
    public static void main(String[] args) throws Exception {
        int procs = 0;
        int n = 1 << 25;
        int reps = 10;
        try {
            if (args.length > 0)
                procs = Integer.parseInt(args[0]);
            if (args.length > 1)
                n = Integer.parseInt(args[1]);
            if (args.length > 2)
                reps = Integer.parseInt(args[2]);
        }
        catch (NumberFormatException e) {
            printUsage();
            return;
        }
        // An empty or negative-length array cannot be benchmarked:
        // the cumulate check below reads a[n - 1].
        if (n < 1) {
            printUsage();
            return;
        }
        ForkJoinPool g;
        try {
            g = (procs == 0) ? new ForkJoinPool() :
                new ForkJoinPool(procs);
        }
        catch (IllegalArgumentException e) {
            // Parallelism rejected by the ForkJoinPool constructor.
            printUsage();
            return;
        }
        try {
            System.out.println("Number of procs=" + g.getParallelism());
            // for now hardwire Cumulate threshold to 8 * #CPUs leaf tasks
            THRESHOLD = 1 + ((n + 7) >>> 3) / g.getParallelism();

            long[] a = new long[n];
            for (int i = 0; i < n; ++i)
                a[i] = i;
            long expected = ((long)n * (long)(n - 1)) / 2;
            for (int i = 0; i < 2; ++i) {
                System.out.print("Seq: ");
                long last = System.nanoTime();
                long ss = seqSum(a, 0, n);
                double elapsed = elapsedTime(last);
                System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
                if (ss != expected)
                    throw new Error("expected " + expected + " != " + ss);
            }
            for (int i = 0; i < reps; ++i) {
                System.out.print("Par: ");
                long last = System.nanoTime();
                Summer s = new Summer(a, 0, a.length, null);
                g.invoke(s);
                long ss = s.result;
                double elapsed = elapsedTime(last);
                System.out.printf("sum = %24d time: %7.3f\n", ss, elapsed);
                // Only the first repetition sums the original 0..n-1
                // contents; the cumulate pass below overwrites the array
                // in place, so later repetitions sum different values.
                if (i == 0 && ss != expected)
                    throw new Error("expected " + expected + " != " + ss);
                System.out.print("Cum: ");
                last = System.nanoTime();
                g.invoke(new Cumulater(null, a, 0, n));
                // Last element of the scan is the total of the array.
                long sc = a[n - 1];
                elapsed = elapsedTime(last);
                // Report the value the scan produced (sc), not the
                // parallel-sum result it is about to be checked against.
                System.out.printf("sum = %24d time: %7.3f\n", sc, elapsed);
                if (sc != ss)
                    throw new Error("expected " + ss + " != " + sc);
            }
            System.out.println(g);
        }
        finally {
            // Release the pool even when a verification Error propagates.
            g.shutdown();
        }
    }

    /** Prints the command-line usage line. */
    private static void printUsage() {
        System.out.println("Usage: java FJSums threads n reps");
    }

    /**
     * Returns the seconds elapsed since {@code startTime}.
     *
     * @param startTime an earlier {@link System#nanoTime} reading
     */
    static double elapsedTime(long startTime) {
        return (double)(System.nanoTime() - startTime) / NPS;
    }

    /**
     * Sequentially sums {@code array[l..h)}.
     *
     * @return the sum of the elements from index {@code l} (inclusive)
     *         to {@code h} (exclusive); 0 when the range is empty
     */
    static long seqSum(long[] array, int l, int h) {
        long sum = 0;
        for (int i = l; i < h; ++i)
            sum += array[i];
        return sum;
    }

    /**
     * Sequentially replaces each element of {@code array[lo..hi)} with
     * the running total of {@code base} plus all elements of the range
     * up to and including it.
     *
     * @return the final running total ({@code base} when the range is
     *         empty)
     */
    static long seqCumulate(long[] array, int lo, int hi, long base) {
        long sum = base;
        for (int i = lo; i < hi; ++i)
            array[i] = sum += array[i];
        return sum;
    }

    /**
     * Parallel sum of {@code array[lo..hi)}; the total is left in
     * {@link #result} once the task has completed.
     * Adapted from Applyer demo in RecursiveAction docs
     */
    static final class Summer extends RecursiveAction {
        final long[] array;
        final int lo, hi;
        long result;
        Summer next; // keeps track of right-hand-side tasks

        Summer(long[] array, int lo, int hi, Summer next) {
            this.array = array; this.lo = lo; this.hi = hi;
            this.next = next;
        }

        @Override
        protected void compute() {
            int l = lo;
            int h = hi;
            Summer right = null;
            // Keep halving: fork the right half and retain the left,
            // while more than one element remains and the surplus
            // queued task count is at most 3. Forked tasks are chained
            // through 'next', most recently forked first.
            while (h - l > 1 && getSurplusQueuedTaskCount() <= 3) {
                int mid = (l + h) >>> 1;
                right = new Summer(array, mid, h, right);
                right.fork();
                h = mid;
            }
            long sum = seqSum(array, l, h);
            // Collect the forked halves, newest first.
            while (right != null) {
                if (right.tryUnfork()) // directly calculate if not stolen
                    sum += seqSum(array, right.lo, right.hi);
                else {
                    right.join();
                    sum += right.result;
                }
                right = right.next;
            }
            result = sum;
        }
    }

    /**
     * Cumulative scan, adapted from ParallelArray code
     *
     * A basic version of scan is straightforward.
     *  Keep dividing by two to threshold segment size, and then:
     *   Pass 1: Create tree of partial sums for each segment
     *   Pass 2: For each segment, cumulate with offset of left sibling
     * See G. Blelloch's http://www.cs.cmu.edu/~scandal/alg/scan.html
     *
     * This version improves performance within FJ framework mainly by
     * allowing second pass of ready left-hand sides to proceed even
     * if some right-hand side first passes are still executing.  It
     * also combines first and second pass for leftmost segment, and
     * for cumulate (not precumulate) also skips first pass for
     * rightmost segment (whose result is not needed for second pass).
     *
     * To manage this, it relies on the "phase" state control field
     * maintaining bits CUMULATE, SUMMED, and FINISHED. CUMULATE is
     * main phase bit. When false, segments compute only their sum.
     * When true, they cumulate array elements. CUMULATE is set at
     * root at beginning of second pass and then propagated down. But
     * it may also be set earlier for subtrees with lo==0 (the
     * left spine of tree). SUMMED is a one bit join count. For leafs,
     * set when summed. For internal nodes, becomes true when one
     * child is summed.  When second child finishes summing, it then
     * moves up tree to trigger cumulate phase. FINISHED is also a one
     * bit join count. For leafs, it is set when cumulated. For
     * internal nodes, it becomes true when one child is cumulated.
     * When second child finishes cumulating, it then moves up tree,
     * executing complete() at the root.
     */
    static final class Cumulater extends ForkJoinTask<Void> {
        static final short CUMULATE = (short)1;
        static final short SUMMED   = (short)2;
        static final short FINISHED = (short)4;

        final Cumulater parent;
        final long[] array;
        Cumulater left, right;
        final int lo;
        final int hi;
        volatile int phase;  // phase/state
        long in, out;        // initially zero

        static final AtomicIntegerFieldUpdater<Cumulater> phaseUpdater =
            AtomicIntegerFieldUpdater.newUpdater(Cumulater.class, "phase");

        Cumulater(Cumulater parent, long[] array, int lo, int hi) {
            this.parent = parent;
            this.array = array;
            this.lo = lo;
            this.hi = hi;
        }

        @Override
        public final Void getRawResult() { return null; }

        @Override
        protected final void setRawResult(Void mustBeNull) { }

        /** Returns true if can CAS CUMULATE bit true */
        final boolean transitionToCumulate() {
            int c;
            while (((c = phase) & CUMULATE) == 0)
                if (phaseUpdater.compareAndSet(this, c, c | CUMULATE))
                    return true;
            return false;
        }

        /**
         * Runs one step for this node. Always returns false: completion
         * is signalled explicitly through complete() on the root rather
         * than through the return value.
         */
        @Override
        public final boolean exec() {
            if (hi - lo > THRESHOLD) {
                if (left == null) { // first pass
                    int mid = (lo + hi) >>> 1;
                    left =  new Cumulater(this, array, lo, mid);
                    right = new Cumulater(this, array, mid, hi);
                }

                boolean cumulate = (phase & CUMULATE) != 0;
                if (cumulate) {
                    // Hand each child its starting offset: the right
                    // child begins after everything summed on the left.
                    long pin = in;
                    left.in = pin;
                    right.in = pin + left.out;
                }

                // In the cumulate pass a child is run only by whoever
                // wins the CAS that sets its CUMULATE bit.
                if (!cumulate || right.transitionToCumulate())
                    right.fork();
                if (!cumulate || left.transitionToCumulate())
                    left.exec();
            }
            else {
                int cb;
                for (;;) { // Establish action: sum, cumulate, or both
                    int b = phase;
                    if ((b & FINISHED) != 0) // already done
                        return false;
                    if ((b & CUMULATE) != 0)
                        cb = FINISHED;
                    else if (lo == 0) // combine leftmost
                        cb = (SUMMED|FINISHED);
                    else
                        cb = SUMMED;
                    if (phaseUpdater.compareAndSet(this, b, b|cb))
                        break;
                }

                if (cb == SUMMED)
                    out = seqSum(array, lo, hi);
                else if (cb == FINISHED)
                    seqCumulate(array, lo, hi, in);
                else if (cb == (SUMMED|FINISHED))
                    out = seqCumulate(array, lo, hi, 0L);

                // propagate up
                Cumulater ch = this;
                Cumulater par = parent;
                for (;;) {
                    if (par == null) {
                        if ((cb & FINISHED) != 0)
                            ch.complete(null);
                        break;
                    }
                    int pb = par.phase;
                    if ((pb & cb & FINISHED) != 0) { // both finished
                        ch = par;
                        par = par.parent;
                    }
                    else if ((pb & cb & SUMMED) != 0) { // both summed
                        par.out = par.left.out + par.right.out;
                        // A left-spine parent not yet cumulating is
                        // switched to CUMULATE and re-forked.
                        int refork =
                            ((pb & CUMULATE) == 0 &&
                             par.lo == 0) ? CUMULATE : 0;
                        int nextPhase = pb|cb|refork;
                        if (pb == nextPhase ||
                            phaseUpdater.compareAndSet(par, pb, nextPhase)) {
                            if (refork != 0)
                                par.fork();
                            cb = SUMMED; // drop finished bit
                            ch = par;
                            par = par.parent;
                        }
                        // else: CAS lost; loop re-reads par.phase.
                    }
                    else if (phaseUpdater.compareAndSet(par, pb, pb|cb))
                        break;
                }
            }
            return false;
        }

    }

}