【Java并发】 - CountDownLatch使用以及原理
概述
CountDownLatch是一个用来控制并发的很常见的工具,它允许一个或者多个线程等待其他的线程执行到某一操作,比如说需要去解析一个excel的数据,为了更快的解析则每个sheet都使用一个线程去进行解析,但是最后的汇总数据的工作则需要等待每个sheet的解析工作完成之后才能进行,这就可以使用CountDownLatch。
使用
例子:
这里有三个线程(main,thread1,thread2),其中main线程将调用countDownLatch的await方法去等待另外两个线程的某个操作的结束(调用countDownLatch的countDown方法)。
public class CountDownLatchDemo {
public static void main(String[] args) throws InterruptedException{
CountDownLatch countDownLatch = new CountDownLatch(2){
@Override
public void await() throws InterruptedException {
super.await();
System.out.println(Thread.currentThread().getName() + " count down is ok");
}
};
Thread thread1 = new Thread(new Runnable() {
@Override
public void run() {
//do something
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " is done");
countDownLatch.countDown();
}
}, "thread1");
Thread thread2 = new Thread(new Runnable() {
@Override
public void run() {
//do something
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " is done");
countDownLatch.countDown();
}
}, "thread2");
thread1.start();
thread2.start();
countDownLatch.await();
}
}
输出:
test thread1 is done
test thread2 is done
test thread3 is done
test thread3 count down is ok
main count down is ok
实现原理
CountDownLatch类实际上是使用计数器的方式去控制的,不难想象当我们初始化CountDownLatch的时候传入了一个int变量这个时候在类的内部初始化一个int的变量,每当我们调用countDownt()方法的时候就使得这个变量的值减1,而对于await()方法则去判断这个int的变量的值是否为0,是则表示所有的操作都已经完成,否则继续等待。
实际上如果了解AQS的话应该很容易想到可以使用AQS的共享式获取同步状态的方式来完成这个功能。而CountDownLatch实际上也就是这么做的。
从结构上来看CountDownLatch的实现还是很简单的,通过很常见的继承AQS的方式来完成自己的同步器。
CountDownLatch的同步器实现:
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
//初始化state
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
//尝试获取同步状态
//只有当同步状态为0的时候返回大于0的数1
//同步状态不为0则返回-1
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
//自旋+CAS的方式释放同步状态
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
比较关键的地方是tryAquireShared()方法的实现,因为在父类的AQS中aquireShared()方法在调用tryAquireShared()方法的时候的判断依据是返回值是否大于零。
public final void acquireShared(int arg) {
if (tryAcquireShared(arg) < 0)
//失败则进入等待队列
doAcquireShared(arg);
}
同步器的实现相对都比较简单,主要思路和上面基本一致。
CountDownLatch的主要方法(本身代码量就很少就直接贴了)
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
//初始化一个同步器
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
//调用同步器的acquireSharedInterruptibly方法
//并且是响应中断的
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
//调用同步器的releaseShared方法去让state减1
public void countDown() {
sync.releaseShared(1);
}
//获取剩余的count
public long getCount() {
return sync.getCount();
}
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
最后:由于CountDownLatch需要开发人员很明确需要等待的条件,否则很容易造成await()方法一直阻塞的情况。