之前我们通过讨论ReentrantLock学习到了AQS的核心、公平与非公平锁的实现以及Condition的实现原理。但是之前所涉及到的都是非共享锁,也就是独占锁。今天我们来讨论基于AQS的共享模式实现的CountDownLatch组件。 本文大体上会分为两部分进行讨论。第一部分为介绍CountDownLatch的使用,第二部分将通过源码来分析CountDownLatch的实现原理。
1. CountDownLatch的使用 CountDownLatch是一个使用频率非常高的类, 是AQS共享模式的典型应用。它的名字翻译过来为:倒计时门闩。具体的使用如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 class Driver2 { void main () throws InterruptedException { CountDownLatch doneSignal = new CountDownLatch (N); Executor e = Executors.newFixedThreadPool(8 ); for (int i = 0 ; i < N; ++i) e.execute(new WorkerRunnable (doneSignal, i)); doneSignal.await(); } } class WorkerRunnable implements Runnable { private final CountDownLatch doneSignal; private final int i; WorkerRunnable(CountDownLatch doneSignal, int i) { this .doneSignal = doneSignal; this .i = i; } public void run () { try { doWork(i); doneSignal.countDown(); } catch (InterruptedException ex) { } } void doWork () { ...} }
以上代码是CountDownLatch源码中JavaDoc中的示例代码。不难理解,代码中的逻辑为创建了一个线程池用于执行任务。主线程等待,直到N任务全部都执行,再对主线程放行。
因此,我们可以得知它的使用场景可以是讲一个任务拆分为多个任务,让多个线程来并行执行,直到所有任务完成后,再向下执行。但是这个例子并没有完全展示出CountDownLatch的特性。接下来再看一个例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 class Driver { void main () throws InterruptedException { CountDownLatch startSignal = new CountDownLatch (1 ); CountDownLatch doneSignal = new CountDownLatch (N); for (int i = 0 ; i < N; ++i) new Thread (new Worker (startSignal, doneSignal)).start(); doSomethingElse(); startSignal.countDown(); doSomethingElse(); doneSignal.await(); } } class Worker implements Runnable { private final CountDownLatch startSignal; private final CountDownLatch doneSignal; Worker(CountDownLatch startSignal, CountDownLatch doneSignal) { this .startSignal = startSignal; this .doneSignal = doneSignal; } public void run () { try { startSignal.await(); doWork(); doneSignal.countDown(); } catch (InterruptedException ex) { } } void doWork () { ...} }
以上代码中,整体逻辑是先等所有的线程启动后,再开始执行任务,然后等所有任务都执行完了,main线程再继续向下执行。可以理解为,有N个线程被一个栅栏阻塞住,只有当通过条件达到了,再打开栅栏放行。注意,放行后,N个线程都被放行了。
有点类似于短跑比赛,首先等所有人准备好了再开始跑,等所有人跑完全程了才能结束比赛。
2. 源码分析 2.1 整体结构 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 public class CountDownLatch { private final Sync sync; public CountDownLatch (int count) { if (count < 0 ) throw new IllegalArgumentException ("count < 0" ); this .sync = new Sync (count); } public void await () throws InterruptedException { sync.acquireSharedInterruptibly(1 ); } public boolean await (long timeout, TimeUnit unit) throws InterruptedException { return sync.tryAcquireSharedNanos(1 , unit.toNanos(timeout)); } public void countDown () { sync.releaseShared(1 ); } public long getCount () { return sync.getCount(); } public String toString () { return super .toString() + "[Count = " + sync.getCount() + "]" ; } private static final class Sync extends AbstractQueuedSynchronizer { } }
可见CountDownLatch的源码并不多,包括Sync的源码满打满算也就300来行。不得不感叹,Doug Lea的设计能力,把设计模式的精髓发挥到极致,抽象出的AQS能简单快速的实现一个同步组件。
2.2 await() 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 public void await () throws InterruptedException { sync.acquireSharedInterruptibly(1 ); } public final void acquireSharedInterruptibly (int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException (); if (tryAcquireShared(arg) < 0 ) doAcquireSharedInterruptibly(arg); } protected int tryAcquireShared (int acquires) { return (getState() == 0 ) ? 1 : -1 ; } private void doAcquireSharedInterruptibly (int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED); boolean failed = true ; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0 ) { setHeadAndPropagate(node, r); p.next = null ; failed = false ; return ; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException (); } } finally { if (failed) cancelAcquire(node); } }
2.3 countDown() 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 public void countDown () { sync.releaseShared(1 ); } public final boolean releaseShared (int arg) { if (tryReleaseShared(arg)) { doReleaseShared(); return true ; } return false ; } protected boolean tryReleaseShared (int releases) { for (;;) { int c = getState(); if (c == 0 ) return false ; int nextc = c-1 ; if (compareAndSetState(c, nextc)) return nextc == 0 ; } } private void doReleaseShared () { for (;;) { Node h = head; if (h != null && h != tail) { int ws = h.waitStatus; if (ws == Node.SIGNAL) { if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0 )) continue ; unparkSuccessor(h); } else if (ws == 0 && !compareAndSetWaitStatus(h, 0 , Node.PROPAGATE)) continue ; } if (h == head) break ; } } private void unparkSuccessor (Node node) { int ws = node.waitStatus; if (ws < 0 ) compareAndSetWaitStatus(node, ws, 0 ); Node s = node.next; if (s == null || s.waitStatus > 0 ) { s = null ; for (Node t = tail; t != null && t != node; t = t.prev) if (t.waitStatus <= 0 ) s = t; } if (s != null ) LockSupport.unpark(s.thread); }
根据以上代码中,unparkSuccessor(h)方法唤醒了后继节点的线程。现在再返回去看代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 private void doAcquireSharedInterruptibly (int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED); boolean failed = true ; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0 ) { setHeadAndPropagate(node, r); p.next = null ; failed = false ; return ; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException (); } } finally { if (failed) cancelAcquire(node); } } private void setHeadAndPropagate (Node node, int propagate) { Node h = head; setHead(node); if (propagate > 0 || h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0 ) { Node s = node.next; if (s == null || s.isShared()) doReleaseShared(); } }
此时再看唤醒部分的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 private void doReleaseShared () { for (;;) { Node h = head; if (h != null && h != tail) { int ws = h.waitStatus; if (ws == Node.SIGNAL) { if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0 )) continue ; unparkSuccessor(h); } else if (ws == 0 && !compareAndSetWaitStatus(h, 0 , Node.PROPAGATE)) continue ; } if (h == head) break ; } }
关于最后的if (h == head)语句的理解:
h == head时:说明头节点还没有被刚刚用 unparkSuccessor 唤醒的线程占有,此时 break 退出循环。
h != head时:头节点被刚刚唤醒的线程占有,那么这里重新进入下一轮循环,唤醒下一个节点。那么有一个问题,刚才被唤醒的节点会主动唤醒它后面的节点,为什么这里还要再下一轮中循环呢?我觉得这里应该是处于吞吐量的考虑(帮忙唤醒)。