使用示例

CountDownLatch 是一个同步工具类,它允许一个或多个线程一直等待,直到其他线程的操作执行完后再执行。在使用线程池的情况下提交任务的情况下,我们无法使用线程的join()方法,这就需要选择使用CountDowConLatch了。示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import java.util.concurrent.CountDowConLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CountDownLatchTest {

// 创建一个CountDownLath实例
private static CountDownLatch countDownLatch = new CountDownLatch(2);

public static void main(String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
// 将线程A添加到线程池
executorService.submit(()->{
try {
Thread.sleep(1000);
System.out.println("child threadOne over!");
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
countDownLatch.countDown();
}
});
// 将线程B添加到线程池
executorService.submit(()->{
try {
Thread.sleep(1000);
System.out.println("child threadTwo over!");
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
countDownLatch.countDown();
}
});
System.out.println("wait all child thread over!");
// 等待子线程执行完毕,返回
countDownLatch.await();
System.out.println("all child thread over!");
executorService.shutdown();
}
}

原理分析

我们先看CountDowConLatch的类图:

从类图可以看出,CountDowConLatch是使用AQS实现的。通过下面这个构造函数,可以发现实际上是把计数器的值赋给了AQS的状态变量state,也就是使用AQS的状态值来表示计数器的值。

1
2
3
4
5
6
7
8
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

Sync(int count) {
setState(count);
}

下面来看看CountDownLatch是如何调用AQS来实现功能的。

  1. void await() 方法
    当线程调用CountDowLnLatch 对象的await方法后,当前线程会被阻塞,知道下面的情况之一发生才会返回:
  • 当所有的线程调用CountDowLnLatch对象的coutDown方法后,也就是计数器的值为0;
  • 其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程就会抛出 InterruptedException异常,然后返回。

看await()方法如何调用AQS的方法的。

1
2
3
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

从上面代码可以看出,await方法委托sync调用AQS的acquireSharedInterruptibly方法,后者如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// AQS获取共享资源时可被中断的方法
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如果线程被中断则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
// 查看当前计数器是否为0,为0则直接返回,否则进入AQS的等待队列
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
// 上一章讲过,AQS的tryAcquireShared由具体的子类实现,CountDownLatch子类Sync实现如下
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

由以上代码可知,该方法的特点是线程获取资源时可被中断,并且获取的资源是共享资源。acquireSharedInterruptibly首先判断当前线程是否已被中断,若是则抛出异常,否则调用Sync实现的tryAcquireShared方法查看当前状态值(计数器值)是否为0,是则当前线程的await方法直接返回,否则调用AQS的doAcquireSharedInterruptibly方法让当前线程阻塞。另外可以看到,这里的tryAcquireShare传参没有被用到,调用tryAcquireShared的方法仅仅是为了检查当前状态是不是为0,并没有调用CAS让当前状态值减1。

  1. await(long timeout, TimeUnit unit)方法
1
2
3
4
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

当线程调用CountDowLnLatch 对象的await方法后,当前线程会被阻塞,知道下面的情况之一发生才会返回:

  • 当所有的线程调用CountDowLnLatch对象的coutDown方法后,也就是计数器的值为0;
  • 其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程就会抛出 InterruptedException异常,然后返回;
  • 设置的timeout时间到了,因为超时而返回false。
  1. void countDown()方法
    线程调用该方法后,计数器的值递减,递减后如果计数器为0则唤醒所有因调用await方法而被阻塞的线程,否则什么也不做。
1
2
3
4
public void countDown() {
// 委托sync调用AQS的方法
sync.releaseShared(1);
}

由上可知,CountDowLnLatch的countDown方法委托sync调用AQS的releaseShared方法,后者代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// AQS 方法
public final boolean releaseShared(int arg) {
// 调用sync实现的tryReleaseShared
if (tryReleaseShared(arg)) {
// AQS的释放资源方法
doReleaseShared();
return true;
}
return false;
}

// sync的方法
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 循环进行CAS,知道当前线程成功完成CAS使计数器值(状态值state)减1并更新到sate
for (;;) {
int c = getState();
// 如果当前状态值为0则直接返回(1)
if (c == 0)
return false;
// 使用CAS让计数器值减1(2)
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

如上代码首先获取当前状态值,代码(1)判断如果当前状态值为0则直接返回false,从而coutDown()方法直接返回;否则执行代码(2)使用CAS将计数器值减1,CAS失败则循环重试,否则如果当前计数器值为0则返回true,返回true说明是最后一个线程调用的countdown方法,那么该线程除了让计数器值减1外,还需唤醒因调用CountDownLatch的await方法而被阻塞的线程,具体是调用AQS的doReleaseShared方法来激活阻塞的线程。这里代码(1)貌似是多余的,其实不然,之所有添加代码(1)是为了防止当前计数器值为0后,其他线程又调用了countDown方法,如果没有代码(1),状态值就可能变成负数。

  1. long getCount()方法
    获取当前计数器值,一般测试时使用
1
2
3
4
5
6
public long getCount() {
return sync.getCount();
}
int getCount() {
return getState();
}

可以看出还是调用AQS的getState方法来获取state的值

总结

CountDownLatch 允许一个或多个线程一直等待,直到其他线程的操作执行完后再执行。通过构造函数设置计数器state的值,也就是需要等待的线程数量,然后调用await方法阻塞自己,其他线程完成任务通过调用coutDown()方法来通知,每调用一次countDown()方法,count值就减1,当count的值等于0时,等待线程被唤醒继续执行任务。

CountDownLatch 是一次性的,计数器的值只能在构造方法中初始化一次,之后没有任何机制再次对其设置值,当 CountDownLatch 使用完毕后,它不能再次被使用。

下回分析CyclicBarrier。