线程本地变量及传递ThreadLocal,InheritableThreadLocal,TransmittableThreadLocal

lsr_flyingIP属地: 上海
字数 1,604阅读 540

在实际开发中,我们经常需要传递一些上下文变量,有些是线程独立的,有些可能需要传递到子线程,甚至是线程池中,比如,分布式链路追踪中的traceId,请求会话信息等。我们将介绍threadlocal, inheritableThreadLocal的局限和原理。


1.ThreadLocal

举一个简单场景,假设我们提供一http请求服务,服务内部可能会涉及到多个方法,每个方法都可能会需要用到用户信息(比如:用户id)。那么一般我们可以定义一个用户Id上下文:

public class UserIdContext {
    private static ThreadLocal<String> userIdThreadLocal = new ThreadLocal<>();

    /**
     * 设置userId上下文
     */
    public static void setUserId(String userId) {
        userIdThreadLocal.set(userId);
    }

    /**
     * 获取userId
     */
    public static String getUserId() {
        return userIdThreadLocal.get();
    }

    /**
     * 及时清空userId上下文,消除内存泄漏
     */
    public static void clearUserId() {
        userIdThreadLocal.remove();
    }
}

于是,我们就可以在我们需要获取userId的时候,调用UserIdContext.getUserId()方法,获取用户信息。这里面就涉及到ThreadLocal这个对象了,大概原理如下图,如果了解,可以直接看下面章节了。

ThreadLocal原理图

下面介绍下ThreadLocal这个对象的原理,先看看是如何保存上下文信息的,获取的雷同,后文就不展开了。

public void set(T value) {
        // 获取当前请求的线程t
        Thread t = Thread.currentThread();
        // 从线程t中,获取ThreadLocalMap对象,这个对象保存了所有和这个线程相关的上下文值(当然也将会包括用户id上下文),注意ThreadLocalMap的get,set方法均为private,及虽然Thread保存了,但是也仅仅保存。
        ThreadLocalMap map = getMap(t);
        if (map != null)
            // 把我们要设置的用户id上下文保存,我们说过,这个ThreadLocalMap是保留了所有不同ThreadLocalMap的值,因此,针对UserIdThreadLocal,这里保存自身引用作为key。
            map.set(this, value);
        else
            // 先初始化ThreadLocalMap再保存,体现懒加载思想
            createMap(t, value);
    }

从以上代码可以看到,线程的上下文值是保存在线程自身当中,可以看到getMap方法如下:

ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

下面我们再看看此次的核心ThreadLocalMap中的数据结构:

static class ThreadLocalMap {
       /**
        * 保存某个ThreadLocal要记录的单个线程的值
        **/
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * 内部所有的Entry的值都是利用数组来保存的,当我们要根据ThreadLocal对象来查找值的时候,
        * 就是根据ThreadLocal对象的哈希操作,来计算出对象table的索引,从而获取数据的
         */
        private Entry[] table;

        /**
        * 这就是ThreadLocal#get方法的最终实现方法,
        * 可以看出根据ThreadLocal的ThreadLocalHashCode及table的长度,可以计算出当前key保存的索引,注意该索引不一定就是实际保存的位置。
        * 这里ThreadLocalHashCode其实是一个ThreadLocal里一个AtomicInteger类型的局静态变量维护的一个计数器。
        **/
        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            // 根据首次计算的索引,获得的key必须和当前key一致,如果不等,说明还要继续查找
            if (e != null && e.get() == key)
                return e;
            else
            // 内部继续顺延索引查找具有相同key的Entry对象
                return getEntryAfterMiss(key, i, e);
        }

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

        /**
         * Set the value associated with key.
         * set方法相对复杂点,也是先根据ThreadLocalHashCode先计算出一个索引,判断是否可以保存,最后可能还会涉及到对table数组进行rehash。
         */
        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            // 如果根据哈希计算出的索引已经有值,如果主键非空,且和当前要设置的key不同,那么我们就调用nextIndex方法索引顺延,继续循环(开放地址法)
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                // 主键相同,重新设置下value就好
                if (k == key) {
                    e.value = value;
                    return;
                }
                
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            // 最终到这里,说明索引i中未保存过指,新建一个entry维护进table
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

以上,就是ThreadLocal的原理。关于它的局限性我们应该也了解,就是如果在多线程的情况下,子线程是无法获取父线程设置的值的。实例代码如下:

     public class ThreadLocalLimitDemo {
    static ThreadLocal<String> threadLocal = new ThreadLocal<>();

    /**
     * 运行结果:
     * main--->[p123]
     * child--->[null]
     */
    public static void main(String[] args) {
        threadLocal.set("p123");

        new Thread(() -> {
            System.out.println("child--->[" + threadLocal.get() + "]");
        }).start();

        System.out.println("main--->[" + threadLocal.get() + "]");
    }
}

而实际情况,我们常常会借助于多线程,优化代码中涉及到io操作的逻辑。对于上面这种new一个子线程的情况,我们可以借助jdk自带的InheritableThreadLocal来解决问题。

2.InheritableThreadLocal

先来看看,我们把ThreadLocal替换为InheritableThreadLocal的话,上文的示例是否就解决了呢:

public class ThreadLocalLimitDemo {
    static ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();

    /**
     * 运行结果:
     * main--->[p123]
     * child--->[p123]
     */
    public static void main(String[] args) {
        threadLocal.set("p123");

        new Thread(() -> {
            System.out.println("child--->[" + threadLocal.get() + "]");
        }).start();

        System.out.println("main--->[" + threadLocal.get() + "]");
    }
}

子线程中也可以取到父线程设置的上下文值了。
下面我们来看看,这是怎么实现的。
InheriatableThreadLocal本身有什么特别的吗?其实没有,我们看到InheritableThreadLocal继承了ThreadLocal,本身重写的方法也不多。整个类代码如下:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * 通常情况,子线程的值和父线程一样,但是,如果希望有所加工,可以重载该方法
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * Get the map associated with a ThreadLocal.
     */
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    /**
     * Create the map associated with a ThreadLocal.
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

可以看到,与ThreadLocal的主要区别就是getMap和createMap关联的是线程中的inheritableThreadLocals字段,而不是Thread#threadLocals字段了。
那么,上下文继承的关键就是在Thread类,在new一个Thread的时候进行了继承操作,new最终会调init方法。

private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
         //获得父线程,currentThread方法是个native方法
        Thread parent = currentThread();
        //......
        // inheritThreadLocals :if {@code true}, inherit initial values for inheritable thread-locals from the constructing thread
        // 如果父线程有inheritableThreadLocals,则复制创建,则继承父线程的上下文
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    }

以上可以看出来,InheritableThreadLocal是在new线程的时候才能起作用,但是如果使用了线程池,线程会复用,这样,在复用的线程中,还是无法正确获取到线程上下文的值。再来段示例:

public class InheritableThreadLocalLimitDemo{

    static ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();

    static ExecutorService executorService
            = Executors.newSingleThreadExecutor());

    /**
     * 执行结果:
     * main--->p123
     * first creat new thread--->p123
     * then reuse thread--->p123
     */
    public static void main(String[] args) {
        threadLocal.set("p123");
        System.out.println("main--->" + threadLocal.get());

        executorService.submit(new Thread(() -> {
            System.out.println("first creat new thread--->" + threadLocal.get());
        }));

        threadLocal.set("t456");
        executorService.submit(new Thread(() -> {
            System.out.println("then reuse thread--->" + threadLocal.get());
        }));

        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.exit(0);
    }
}

可以看到,当复用线程时,无法获取到父线程设置的"t456"这个值,仍然是第一次创建的"p123"这个值,即:在线程池中无法继承父线程的线程上下文。
这个问题我们在下一节看看如何解决。

3.TransmittableThreadLocal

我们先自己分析下,这个问题出现在哪,以及可以怎么解决。首先,问题是,在复用线程的时候,当前线程未继承父线程的InheritableThreadLocal值(因为复用,不涉及到new线程操作了)。知道问题了,那么如果要继承上一个线程的InheritableThreadLocal值,我们是不是可以通过某个折中办法,把上一个线程的值传进来。时机呢?看来只有在定义业务逻辑的时候,把local值保存起来,在执行业务的逻辑之前,先利用保存的local值重置下threadlocal值。业务开发中,引入一个中间层,往往能解决很多问题。因此,针对原来定义一个Thread的逻辑,我们引入一个新类,保存父线程的值,并重写run方法,在该方法中先重置ThreadLocal值,然后再调用业务逻辑。
下面我们看看这样实现是否有用呢。

public abstract class InheritableTask implements Runnable {
    /**
    * 中间值
    **/
    private Object inheritableThreadLocalsObj;

    /**
     * 保存父线程threadlocal的值
     */
    public InheritableTask(){
        try{
            Thread currentThread = Thread.currentThread();
            Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);
            Object threadLocalMapObj = inheritableThreadLocalsField.get(currentThread);
            if(threadLocalMapObj != null){
                // 调用ThreadLocal中的createInheritedMap方法,重新复制一个新的inheritableThreadLocals值
                Class threadLocalMapClazz =   inheritableThreadLocalsField.getType();
                Method method =  ThreadLocal.class.getDeclaredMethod("createInheritedMap",threadLocalMapClazz);
                method.setAccessible(true);
                // 创建一个新的ThreadLocalMap类型的inheritableThreadLocals
                Object newThreadLocalMap = method.invoke(ThreadLocal.class,threadLocalMapObj);

                // 将这个值暂存下来
                inheritableThreadLocalsObj = newThreadLocalMap;
            }
        }catch (Exception e){
            throw new IllegalStateException(e);
        }
    }

    /**
     * 搞个代理方法,这个方法中处理业务逻辑
     */
    public abstract void runTask();

    @Override
    public void run() {
        // 此处得到的是当前处理该业务的线程,也就是线程池中的线程
        Thread currentThread = Thread.currentThread();
        Field field = null;
        try {
            field  = Thread.class.getDeclaredField("inheritableThreadLocals");
            field.setAccessible(true);
            // 将暂存的值,赋值给currentThread
            if (inheritableThreadLocalsObj != null && field != null) {
                field.set(currentThread, inheritableThreadLocalsObj);
                inheritableThreadLocalsObj = null;
            }
            // 执行任务
            runTask();
        }catch (Exception e){
            throw new IllegalStateException(e);
        }finally {
            // 最后将线程中的InheritableThreadLocals设置为null
            try{
                field.set(currentThread,null);
            }catch (Exception e){
                throw new IllegalStateException(e);
            }
        }
    }
}

接下来看下使用效果

public class TransmittableThreadLocalWrapper {

    static ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();

    static ExecutorService executorService
            = Executors.newSingleThreadExecutor();

    /**
     * 执行结果:
     * main--->p123
     * first creat new thread--->p123
     * then reuse thread--->t456
     */
    public static void main(String[] args) {
        threadLocal.set("p123");
        System.out.println("main--->" + threadLocal.get());

        executorService.submit(new Thread(() -> {
            System.out.println("first creat new thread--->" + threadLocal.get());
        }));

        threadLocal.set("t456");
        executorService.submit(new InheritableTask() {
            @Override
            public void runTask() {
                System.out.println("then reuse thread--->" + threadLocal.get());
            }
        });

        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.exit(0);
    }
}

问题确实解决了。

不过这里面还有很多问题不够优雅,一是反射耗性能,二是还要去实现类似InheritableTask这种具有侵入性的逻辑。所以,这里重点引入TransmittableThreadLocal(github地址l
)。具体用法可以详细参考官网。这里简单总结下:1)可以通过修饰Runnable和Callable; 2)可以通过修饰线程池;3)如果希望降低侵入性,还可以通过agent的方式修饰jdk线程池实现类。

public class TransmittableThreadLocalWrapper {

    static ThreadLocal<String> threadLocal = new TransmittableThreadLocal<>();

    static ExecutorService executorService
            = TtlExecutors.getTtlExecutorService(Executors.newSingleThreadExecutor());

    /**
     * 执行结果:
     * main--->p123
     * first creat new thread--->p123
     * then reuse thread--->t456
     */
    public static void main(String[] args) {
        threadLocal.set("p123");
        System.out.println("main--->" + threadLocal.get());

        executorService.submit(new Thread(() -> {
            System.out.println("first creat new thread--->" + threadLocal.get());
        }));

        threadLocal.set("t456");
        executorService.submit(new Thread(() -> {
            System.out.println("then reuse thread--->" + threadLocal.get());
        }));

        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.exit(0);
    }
}

我们下面重点介绍下其实现原理。
根据二八原理,首先尝试介绍下最重要的百分之二十的实现思想。
最核心思想其实还是类似:在new一个线程的时候,获取父线程的所有TransmittableThreadLocal的值,在执行子线程run方法的时候,将父线程的值拷贝到子线程当中。

记住以下几个重要的类(或内部类、字段):
a.TransmittableThreadLocal.Transmitter

负责父子线程TransmittableThreadLocal的录制回放工作,最核心的一个类了。

b.TransmittableThreadLocal.Transmitter.Snapshot

顾名思义,只是一个对象载体,保存父线程/子线程的所有ThreadLocal的值。

c.TransmittableThreadLocal#holder:private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>

TransmittableThreadLocal最核心的字段,正是有该字段,所有传递机制才得以存在。
记录了当前线程都包含了哪些ThreadLocal对象,不然在录制父线程的threadlocal值时,哪能知道用户都定义了哪些。

相信通过以上,已经大致知道TransmittableThreadLocal的原理了,下面再看以下几个方法片段进一步理解下:
a.

/**
     * see {@link InheritableThreadLocal#set}
     */
    @Override
    public final void set(T value) {
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            super.set(value);
            // 核心思想方法
            addThisToHolder();
        }
    }

b.

@NonNull
        public static Object capture() {
            return new Snapshot(captureTtlValues(), captureThreadLocalValues());
        }

        private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
            for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        } 

c.

@NonNull
        public static Object replay(@NonNull Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;
            return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }

        @NonNull
        private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
            WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // backup
                backup.put(threadLocal, threadLocal.get());

                // clear the TTL values that is not in captured
                // avoid the extra TTL values after replay when run task
                if (!captured.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // set TTL values to captured
            setTtlValuesTo(captured);

            // call beforeExecute callback
            doExecuteCallback(true);

            return backup;
        }

d.

private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

    /**
     * wrap method {@link Runnable#run()}.
     */
    @Override
    public void run() {
        Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }

        Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            restore(backup);
        }
    }

通过以上分析,相信应该能大致掌握TransmittableThreadLocal的原理了吧。最后放一张时序图,再进一步加深下印象。

  • TransmittableThreadLocal时序图


    image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
0人点赞
更多精彩内容,就在简书APP
"小礼物走一走,来简书关注我"
还没有人赞赏,支持一下
总资产0共写了8744字获得2个赞共1个粉丝