ForkJoin是Java7提供的原生多线程并行处理框架,其基本思想是将大人物分割成小任务,最后将小任务聚合起来得到结果。它非常类似于HADOOP提供的MapReduce框架,只是MapReduce的任务可以针对集群内的所有计算节点,可以充分利用集群的能力完成计算任务。ForkJoin更加类似于单机版的MapReduce。
我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join的操作机制,通常我们不直接继承ForkjoinTask类,只需要直接继承其子类。
1. RecursiveAction,用于没有返回结果的任务
2. RecursiveTask,用于有返回值的任务
· ForkJoinPool:task要通过ForkJoinPool来执行,分割的子任务也会添加到当前工作线程的双端队列中,进入队列的头部。当一个工作线程中没有任务时,会从其他工作线程的队列尾部获取一个任务。
ForkJoin框架使用了工作窃取的思想(work-stealing),算法从其他队列中窃取任务来执行。
下面代码,使我模拟的一个场景,即进行数组求和,假如我们的数组大小为400,计算机每次求一百个数字之和需要花费一秒钟(使用线程睡眠模拟)。那么单线程情况下大概需要4秒钟,而我们使用4个线程分别进行100个数字的求和,其并行计算,只需要1秒钟。
代码如下:
package com.zt.thread;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
/**
*
* 利用Fork/join框架进行400个100以内的数字的加法
*
* JDK用来执行Fork/Join任务的工作线程池大小默认等于CPU核心数。默认在一个4核CPU上,最多可以同时执行4个子任务。
*
* @author zhaotong
*
*/
class CountTask extends RecursiveTask<Long> {
private static final long serialVersionUID = 1L;
// 100个数组进行一次计算
private static final int SIZE = 100;
// 存储本次计算的数字
private int[] array;
// 计算起始位置
private int start;
// 计算结束位置
private int end;
CountTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
// 如果任务小于我们规定的阀值,则直接进行计算
if (end - start <= SIZE) {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
return sum;
}
// 任务太大,将任务一分为二
int middle = (end + start) / 2;
CountTask ct1 = new CountTask(array, start, middle);
CountTask ct2 = new CountTask(array, middle, end);
invokeAll(ct1, ct2);
long res1 = ct1.join();
long res2 = ct2.join();
return res1 + res2;
}
}
public class ForkAndJoin {
private static void fillRandomArray(int[] array) {
Random rd=new Random();
for (int i = 0; i < array.length; i++) {
array[i] = rd.nextInt(100);
}
}
// 单线程进行数组求和计算
private static long computeArray(int[] array) {
long sum = 0;
for (int i = 0; i < array.length; i++) {
sum += array[i];
if ((i + 1) % 100 == 0) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
}
}
return sum;
}
public static void main(String[] args) {
// 获取cpu核心数
// Runtime.getRuntime().availableProcessors()
int[] test = new int[400];
fillRandomArray(test);
// 模拟单个线程进行计算,假设,我们每计算100个数字,需要消耗一秒钟(这里我们使用线程睡眠1秒钟)
long startTime1 = System.currentTimeMillis();
Long result1 = computeArray(test);
long endTime1 = System.currentTimeMillis();
System.out.println("single Thread sum: " + result1 + " in "
+ (endTime1 - startTime1) + " ms.");
// 使用fork/join 进行并行计算,我的电脑四核 可以使用
// Runtime.getRuntime().availableProcessors()
/**
* ForkJoinPool提供了一系列的submit方法,计算任务。ForkJoinPool默认的线程数通过Runtime.
* availableProcessors()获得, 因为在计算密集型的任务中,获得多于处理性核心数的线程并不能获得更多性能提升
* ,该方法也可以传以前的Runnable, Callback的接口实现(底层会将其封装成ForkJoinTask对象)。
*/
ForkJoinPool fp = new ForkJoinPool(Runtime.getRuntime()
.availableProcessors());
ForkJoinTask<Long> ft = new CountTask(test, 0, test.length);
long startTime = System.currentTimeMillis();
Long result = fp.invoke(ft);
long endTime = System.currentTimeMillis();
System.out.println("Fork/join sum: " + result + " in "
+ (endTime - startTime) + " ms.");
}
}
执行结果:
image.png
注意,这里有个特别需要注意的地方,在划分任务时,很多人会这样写。
int middle = (end + start) / 2;
CountTask ct1 = new CountTask(array, start, middle);
CountTask ct2 = new CountTask(array, middle, end);
ct1.fork();
ct2.fork();
long res1 = ct1.join();
long res2 = ct2.join();
因为compute()方法其实本身就是一个线程,这样写,就让compute()所在的线程闲置了。因此应当使用invokeAll()方法,invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。(廖学峰大大的教程真的不错,强烈推荐)。