java并发编程之Fork/Join 框架

ForkJoin是Java7提供的原生多线程并行处理框架,其基本思想是将大人物分割成小任务,最后将小任务聚合起来得到结果。它非常类似于HADOOP提供的MapReduce框架,只是MapReduce的任务可以针对集群内的所有计算节点,可以充分利用集群的能力完成计算任务。ForkJoin更加类似于单机版的MapReduce。

我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join的操作机制,通常我们不直接继承ForkjoinTask类,只需要直接继承其子类。

1. RecursiveAction,用于没有返回结果的任务

2. RecursiveTask,用于有返回值的任务

· ForkJoinPool:task要通过ForkJoinPool来执行,分割的子任务也会添加到当前工作线程的双端队列中,进入队列的头部。当一个工作线程中没有任务时,会从其他工作线程的队列尾部获取一个任务。

ForkJoin框架使用了工作窃取的思想(work-stealing),算法从其他队列中窃取任务来执行。

下面代码,使我模拟的一个场景,即进行数组求和,假如我们的数组大小为400,计算机每次求一百个数字之和需要花费一秒钟(使用线程睡眠模拟)。那么单线程情况下大概需要4秒钟,而我们使用4个线程分别进行100个数字的求和,其并行计算,只需要1秒钟。

代码如下:

package com.zt.thread;

import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

/**
 * 
 * 利用Fork/join框架进行400个100以内的数字的加法
 * 
 * JDK用来执行Fork/Join任务的工作线程池大小默认等于CPU核心数。默认在一个4核CPU上,最多可以同时执行4个子任务。
 * 
 * @author zhaotong
 *
 */
class CountTask extends RecursiveTask<Long> {

    private static final long serialVersionUID = 1L;
    // 100个数组进行一次计算
    private static final int SIZE = 100;
    // 存储本次计算的数字
    private int[] array;
    // 计算起始位置
    private int start;
    // 计算结束位置
    private int end;

    CountTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {

        // 如果任务小于我们规定的阀值,则直接进行计算
        if (end - start <= SIZE) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
            return sum;

        }
        // 任务太大,将任务一分为二
        int middle = (end + start) / 2;
        CountTask ct1 = new CountTask(array, start, middle);
        CountTask ct2 = new CountTask(array, middle, end);
        invokeAll(ct1, ct2);
        long res1 = ct1.join();
        long res2 = ct2.join();
        return res1 + res2;
    }

}

public class ForkAndJoin {

    private static void fillRandomArray(int[] array) {
        Random rd=new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = rd.nextInt(100);
        }
    }

    // 单线程进行数组求和计算
    private static long computeArray(int[] array) {
        long sum = 0;
        for (int i = 0; i < array.length; i++) {
            sum += array[i];
            if ((i + 1) % 100 == 0) {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                }
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        // 获取cpu核心数
        // Runtime.getRuntime().availableProcessors()

        int[] test = new int[400];
        fillRandomArray(test);

        // 模拟单个线程进行计算,假设,我们每计算100个数字,需要消耗一秒钟(这里我们使用线程睡眠1秒钟)
        long startTime1 = System.currentTimeMillis();
        Long result1 = computeArray(test);
        long endTime1 = System.currentTimeMillis();

        System.out.println("single Thread sum: " + result1 + " in "
                + (endTime1 - startTime1) + " ms.");
        // 使用fork/join 进行并行计算,我的电脑四核 可以使用
        // Runtime.getRuntime().availableProcessors()

        /**
         * ForkJoinPool提供了一系列的submit方法,计算任务。ForkJoinPool默认的线程数通过Runtime.
         * availableProcessors()获得, 因为在计算密集型的任务中,获得多于处理性核心数的线程并不能获得更多性能提升
         * ,该方法也可以传以前的Runnable, Callback的接口实现(底层会将其封装成ForkJoinTask对象)。
         */
        ForkJoinPool fp = new ForkJoinPool(Runtime.getRuntime()
                .availableProcessors());
        ForkJoinTask<Long> ft = new CountTask(test, 0, test.length);
        long startTime = System.currentTimeMillis();
        Long result = fp.invoke(ft);
        long endTime = System.currentTimeMillis();
        System.out.println("Fork/join sum: " + result + " in "
                + (endTime - startTime) + " ms.");

    }

}

执行结果:

image.png

注意,这里有个特别需要注意的地方,在划分任务时,很多人会这样写。

        int middle = (end + start) / 2;
        CountTask ct1 = new CountTask(array, start, middle);
        CountTask ct2 = new CountTask(array, middle, end);
        ct1.fork();
        ct2.fork();
        long res1 = ct1.join();
        long res2 = ct2.join();

因为compute()方法其实本身就是一个线程,这样写,就让compute()所在的线程闲置了。因此应当使用invokeAll()方法,invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。(廖学峰大大的教程真的不错,强烈推荐)。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容