--- jsr166/src/jsr166y/Phaser.java 2010/11/17 10:48:59 1.56 +++ jsr166/src/jsr166y/Phaser.java 2010/11/19 16:03:24 1.57 @@ -245,8 +245,7 @@ public class Phaser { private static final int PARTIES_SHIFT = 16; private static final int PHASE_SHIFT = 32; private static final int UNARRIVED_MASK = 0xffff; - private static final int PARTIES_MASK = 0xffff0000; - private static final long LPARTIES_MASK = 0xffff0000L; // long version + private static final long PARTIES_MASK = 0xffff0000L; // for masking long private static final long ONE_ARRIVAL = 1L; private static final long ONE_PARTY = 1L << PARTIES_SHIFT; private static final long TERMINATION_PHASE = -1L << PHASE_SHIFT; @@ -304,30 +303,31 @@ public class Phaser { */ private int doArrive(long adj) { for (;;) { - long s; - int phase, unarrived; - if ((phase = (int)((s = state) >>> PHASE_SHIFT)) < 0) + long s = state; + int phase = (int)(s >>> PHASE_SHIFT); + if (phase < 0) return phase; - else if ((unarrived = (int)s & UNARRIVED_MASK) == 0) + int unarrived = (int)s & UNARRIVED_MASK; + if (unarrived == 0) checkBadArrive(s); else if (UNSAFE.compareAndSwapLong(this, stateOffset, s, s-=adj)) { if (unarrived == 1) { - Phaser par; - long p = s & LPARTIES_MASK; // unshifted parties field + long p = s & PARTIES_MASK; // unshifted parties field long lu = p >>> PARTIES_SHIFT; int u = (int)lu; int nextPhase = (phase + 1) & MAX_PHASE; long next = ((long)nextPhase << PHASE_SHIFT) | p | lu; - if ((par = parent) == null) { + final Phaser parent = this.parent; + if (parent == null) { if (onAdvance(phase, u)) next |= TERMINATION_PHASE; // obliterate phase UNSAFE.compareAndSwapLong(this, stateOffset, s, next); releaseWaiters(phase); } else { - par.doArrive(u == 0? - ONE_ARRIVAL|ONE_PARTY : ONE_ARRIVAL); - if ((int)(par.state >>> PHASE_SHIFT) != nextPhase || + parent.doArrive((u == 0) ? + ONE_ARRIVAL|ONE_PARTY : ONE_ARRIVAL); + if ((int)(parent.state >>> PHASE_SHIFT) != nextPhase || ((int)(state >>> PHASE_SHIFT) != nextPhase && !UNSAFE.compareAndSwapLong(this, stateOffset, s, next))) @@ -356,18 +356,19 @@ public class Phaser { * @param registrations number to add to both parties and unarrived fields */ private int doRegister(int registrations) { - long adj = (long)registrations; // adjustment to state - adj |= adj << PARTIES_SHIFT; - Phaser par = parent; + // assert registrations > 0; + // adjustment to state + long adj = ((long)registrations << PARTIES_SHIFT) | registrations; + final Phaser parent = this.parent; for (;;) { - int phase, parties; - long s = par == null? state : reconcileState(); - if ((phase = (int)(s >>> PHASE_SHIFT)) < 0) + long s = (parent == null) ? state : reconcileState(); + int phase = (int)(s >>> PHASE_SHIFT); + if (phase < 0) return phase; - if ((parties = (int)s >>> PARTIES_SHIFT) != 0 && - ((int)s & UNARRIVED_MASK) == 0) + int parties = (int)s >>> PARTIES_SHIFT; + if (parties != 0 && ((int)s & UNARRIVED_MASK) == 0) internalAwaitAdvance(phase, null); // wait for onAdvance - else if (parties + registrations > MAX_PARTIES) + else if (registrations > MAX_PARTIES - parties) throw new IllegalStateException(badRegister(s)); else if (UNSAFE.compareAndSwapLong(this, stateOffset, s, s + adj)) return phase; @@ -387,29 +388,27 @@ public class Phaser { */ private long reconcileState() { Phaser par = parent; - if (par == null) - return state; - Phaser rt = root; - for (;;) { - long s, u; - int phase, rPhase, pPhase; - if ((phase = (int)((s = state)>>> PHASE_SHIFT)) < 0 || - (rPhase = (int)(rt.state >>> PHASE_SHIFT)) == phase) - return s; - long pState = par.parent == null? par.state : par.reconcileState(); - if (state == s) { - if ((rPhase < 0 || ((int)s & UNARRIVED_MASK) == 0) && - ((pPhase = (int)(pState >>> PHASE_SHIFT)) < 0 || - pPhase == ((phase + 1) & MAX_PHASE))) - UNSAFE.compareAndSwapLong - (this, stateOffset, s, - (((long) pPhase) << PHASE_SHIFT) | - (u = s & LPARTIES_MASK) | - (u >>> PARTIES_SHIFT)); // reset unarrived to parties - else - releaseWaiters(phase); // help release others + long s = state; + if (par != null) { + Phaser rt = root; + int phase, rPhase; + while ((phase = (int)(s >>> PHASE_SHIFT)) >= 0 && + (rPhase = (int)(rt.state >>> PHASE_SHIFT)) != phase) { + if ((int)(par.state >>> PHASE_SHIFT) != rPhase) + par.reconcileState(); + else if (rPhase < 0 || ((int)s & UNARRIVED_MASK) == 0) { + long u = s & PARTIES_MASK; // reset unarrived to parties + long next = ((((long) rPhase) << PHASE_SHIFT) | u | + (u >>> PARTIES_SHIFT)); + if (state == s && + UNSAFE.compareAndSwapLong(this, stateOffset, + s, s = next)) + break; + } + s = state; } } + return s; } /** @@ -505,8 +504,6 @@ public class Phaser { public int bulkRegister(int parties) { if (parties < 0) throw new IllegalArgumentException(); - if (parties > MAX_PARTIES) - throw new IllegalStateException(badRegister(state)); if (parties == 0) return getPhase(); return doRegister(parties); @@ -573,14 +570,11 @@ public class Phaser { * if terminated or argument is negative */ public int awaitAdvance(int phase) { - int p; if (phase < 0) return phase; - else if ((p = (int)((parent == null? state : reconcileState()) - >>> PHASE_SHIFT)) == phase) - return internalAwaitAdvance(phase, null); - else - return p; + long s = (parent == null) ? state : reconcileState(); + int p = (int)(s >>> PHASE_SHIFT); + return (p != phase) ? p : internalAwaitAdvance(phase, null); } /** @@ -599,11 +593,11 @@ public class Phaser { */ public int awaitAdvanceInterruptibly(int phase) throws InterruptedException { - int p; if (phase < 0) return phase; - if ((p = (int)((parent == null? state : reconcileState()) - >>> PHASE_SHIFT)) == phase) { + long s = (parent == null) ? state : reconcileState(); + int p = (int)(s >>> PHASE_SHIFT); + if (p == phase) { QNode node = new QNode(this, phase, true, false, 0L); p = internalAwaitAdvance(phase, node); if (node.wasInterrupted) @@ -635,12 +629,12 @@ public class Phaser { public int awaitAdvanceInterruptibly(int phase, long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { - long nanos = unit.toNanos(timeout); - int p; if (phase < 0) return phase; - if ((p = (int)((parent == null? state : reconcileState()) - >>> PHASE_SHIFT)) == phase) { + long s = (parent == null) ? state : reconcileState(); + int p = (int)(s >>> PHASE_SHIFT); + if (p == phase) { + long nanos = unit.toNanos(timeout); QNode node = new QNode(this, phase, true, true, nanos); p = internalAwaitAdvance(phase, node); if (node.wasInterrupted) @@ -682,7 +676,7 @@ public class Phaser { * @return the phase number, or a negative value if terminated */ public final int getPhase() { - return (int)((parent==null? state : reconcileState()) >>> PHASE_SHIFT); + return (int)(root.state >>> PHASE_SHIFT); } /** @@ -691,7 +685,7 @@ public class Phaser { * @return the number of parties */ public int getRegisteredParties() { - return partiesOf(parent==null? state : reconcileState()); + return partiesOf(state); } /** @@ -739,7 +733,7 @@ public class Phaser { * @return {@code true} if this barrier has been terminated */ public boolean isTerminated() { - return (parent == null? state : reconcileState()) < 0; + return root.state < 0L; } /** @@ -803,7 +797,7 @@ public class Phaser { // Waiting mechanics /** - * Removes and signals threads from queue for phase + * Removes and signals threads from queue for phase. */ private void releaseWaiters(int phase) { AtomicReference head = queueFor(phase); @@ -817,20 +811,6 @@ public class Phaser { } } - /** - * Tries to enqueue given node in the appropriate wait queue. - * - * @return true if successful - */ - private boolean tryEnqueue(int phase, QNode node) { - releaseWaiters(phase-1); // ensure old queue clean - AtomicReference head = queueFor(phase); - QNode q = head.get(); - return ((q == null || q.phase == phase) && - (int)(root.state >>> PHASE_SHIFT) == phase && - head.compareAndSet(node.next = q, node)); - } - /** The number of CPUs, for spin control */ private static final int NCPU = Runtime.getRuntime().availableProcessors(); @@ -862,26 +842,22 @@ public class Phaser { boolean queued = false; // true when node is enqueued int lastUnarrived = -1; // to increase spins upon change int spins = SPINS_PER_ARRIVAL; - for (;;) { - int p, unarrived; + long s; + int p; + while ((p = (int)((s = current.state) >>> PHASE_SHIFT)) == phase) { Phaser par; - long s = current.state; - if ((p = (int)(s >>> PHASE_SHIFT)) != phase) { - if (node != null) - node.onRelease(); - releaseWaiters(phase); - return p; + int unarrived = (int)s & UNARRIVED_MASK; + if (unarrived != lastUnarrived) { + if (lastUnarrived == -1) // ensure old queue clean + releaseWaiters(phase-1); + if ((lastUnarrived = unarrived) < NCPU) + spins += SPINS_PER_ARRIVAL; } - else if ((unarrived = (int)s & UNARRIVED_MASK) == 0 && - (par = current.parent) != null) { + else if (unarrived == 0 && (par = current.parent) != null) { current = par; // if all arrived, use parent par = par.parent; lastUnarrived = -1; } - else if (unarrived != lastUnarrived) { - if ((lastUnarrived = unarrived) < NCPU) - spins += SPINS_PER_ARRIVAL; - } else if (spins > 0) { if (--spins == (SPINS_PER_ARRIVAL >>> 1)) Thread.yield(); // yield midway through spin @@ -889,11 +865,22 @@ public class Phaser { else if (node == null) // must be noninterruptible node = new QNode(this, phase, false, false, 0L); else if (node.isReleasable()) { - if ((int)(reconcileState() >>> PHASE_SHIFT) == phase) + if ((p = (int)(root.state >>> PHASE_SHIFT)) != phase) + break; + else return phase; // aborted } - else if (!queued) - queued = tryEnqueue(phase, node); + else if (!queued) { // push onto queue + AtomicReference head = queueFor(phase); + QNode q = head.get(); + if (q == null || q.phase == phase) { + node.next = q; + if ((p = (int)(root.state >>> PHASE_SHIFT)) != phase) + break; // recheck to avoid stale enqueue + else + queued = head.compareAndSet(q, node); + } + } else { try { ForkJoinPool.managedBlock(node); @@ -902,6 +889,10 @@ public class Phaser { } } } + releaseWaiters(phase); + if (node != null) + node.onRelease(); + return p; } /**