--- jsr166/src/test/tck/ThreadLocalTest.java 2003/12/27 19:26:44 1.4 +++ jsr166/src/test/tck/ThreadLocalTest.java 2004/01/10 20:37:20 1.5 @@ -18,24 +18,89 @@ public class ThreadLocalTest extends JSR return new TestSuite(ThreadLocalTest.class); } - static ThreadLocal tl = new ThreadLocal() { - public Object initialValue() { - return new Integer(1); + static ThreadLocal tl = new ThreadLocal() { + public Integer initialValue() { + return one; + } + }; + + static InheritableThreadLocal itl = + new InheritableThreadLocal() { + protected Integer initialValue() { + return zero; + } + + protected Integer childValue(Integer parentValue) { + return new Integer(parentValue.intValue() + 1); } }; - /** * remove causes next access to return initial value */ public void testRemove() { - Integer one = new Integer(1); - Integer two = new Integer(2); assertEquals(tl.get(), one); tl.set(two); assertEquals(tl.get(), two); tl.remove(); assertEquals(tl.get(), one); } + + /** + * remove in InheritableThreadLocal causes next access to return + * initial value + */ + public void testRemoveITL() { + assertEquals(itl.get(), zero); + itl.set(two); + assertEquals(itl.get(), two); + itl.remove(); + assertEquals(itl.get(), zero); + } + + private class ITLThread extends Thread { + final int[] x; + ITLThread(int[] array) { x = array; } + public void run() { + Thread child = null; + if (itl.get().intValue() < x.length - 1) { + child = new ITLThread(x); + child.start(); + } + Thread.currentThread().yield(); + + int threadId = itl.get().intValue(); + for (int j = 0; j < threadId; j++) { + x[threadId]++; + Thread.currentThread().yield(); + } + + if (child != null) { // Wait for child (if any) + try { + child.join(); + } catch(InterruptedException e) { + threadUnexpectedException(); + } + } + } + } + + /** + * InheritableThreadLocal propagates generic values. + */ + public void testGenericITL() { + final int threadCount = 10; + final int x[] = new int[threadCount]; + Thread progenitor = new ITLThread(x); + try { + progenitor.start(); + progenitor.join(); + for(int i = 0; i < threadCount; i++) { + assertEquals(i, x[i]); + } + } catch(InterruptedException e) { + unexpectedException(); + } + } }