在实际开发中,我们经常需要传递一些上下文变量,有些是线程独立的,有些可能需要传递到子线程,甚至是线程池中,比如,分布式链路追踪中的traceId,请求会话信息等。我们将介绍threadlocal, inheritableThreadLocal的局限和原理。
1.ThreadLocal
举一个简单场景,假设我们提供一http请求服务,服务内部可能会涉及到多个方法,每个方法都可能会需要用到用户信息(比如:用户id)。那么一般我们可以定义一个用户Id上下文:
public class UserIdContext {
private static ThreadLocal<String> userIdThreadLocal = new ThreadLocal<>();
/**
* 设置userId上下文
*/
public static void setUserId(String userId) {
userIdThreadLocal.set(userId);
}
/**
* 获取userId
*/
public static String getUserId() {
return userIdThreadLocal.get();
}
/**
* 及时清空userId上下文,消除内存泄漏
*/
public static void clearUserId() {
userIdThreadLocal.remove();
}
}
于是,我们就可以在我们需要获取userId的时候,调用UserIdContext.getUserId()方法,获取用户信息。这里面就涉及到ThreadLocal这个对象了,大概原理如下图,如果了解,可以直接看下面章节了。
下面介绍下ThreadLocal这个对象的原理,先看看是如何保存上下文信息的,获取的雷同,后文就不展开了。
public void set(T value) {
// 获取当前请求的线程t
Thread t = Thread.currentThread();
// 从线程t中,获取ThreadLocalMap对象,这个对象保存了所有和这个线程相关的上下文值(当然也将会包括用户id上下文),注意ThreadLocalMap的get,set方法均为private,及虽然Thread保存了,但是也仅仅保存。
ThreadLocalMap map = getMap(t);
if (map != null)
// 把我们要设置的用户id上下文保存,我们说过,这个ThreadLocalMap是保留了所有不同ThreadLocalMap的值,因此,针对UserIdThreadLocal,这里保存自身引用作为key。
map.set(this, value);
else
// 先初始化ThreadLocalMap再保存,体现懒加载思想
createMap(t, value);
}
从以上代码可以看到,线程的上下文值是保存在线程自身当中,可以看到getMap方法如下:
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
下面我们再看看此次的核心ThreadLocalMap中的数据结构:
static class ThreadLocalMap {
/**
* 保存某个ThreadLocal要记录的单个线程的值
**/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
/**
* 内部所有的Entry的值都是利用数组来保存的,当我们要根据ThreadLocal对象来查找值的时候,
* 就是根据ThreadLocal对象的哈希操作,来计算出对象table的索引,从而获取数据的
*/
private Entry[] table;
/**
* 这就是ThreadLocal#get方法的最终实现方法,
* 可以看出根据ThreadLocal的ThreadLocalHashCode及table的长度,可以计算出当前key保存的索引,注意该索引不一定就是实际保存的位置。
* 这里ThreadLocalHashCode其实是一个ThreadLocal里一个AtomicInteger类型的局静态变量维护的一个计数器。
**/
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 根据首次计算的索引,获得的key必须和当前key一致,如果不等,说明还要继续查找
if (e != null && e.get() == key)
return e;
else
// 内部继续顺延索引查找具有相同key的Entry对象
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
/**
* Set the value associated with key.
* set方法相对复杂点,也是先根据ThreadLocalHashCode先计算出一个索引,判断是否可以保存,最后可能还会涉及到对table数组进行rehash。
*/
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 如果根据哈希计算出的索引已经有值,如果主键非空,且和当前要设置的key不同,那么我们就调用nextIndex方法索引顺延,继续循环(开放地址法)
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 主键相同,重新设置下value就好
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 最终到这里,说明索引i中未保存过指,新建一个entry维护进table
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
以上,就是ThreadLocal的原理。关于它的局限性我们应该也了解,就是如果在多线程的情况下,子线程是无法获取父线程设置的值的。实例代码如下:
public class ThreadLocalLimitDemo {
static ThreadLocal<String> threadLocal = new ThreadLocal<>();
/**
* 运行结果:
* main--->[p123]
* child--->[null]
*/
public static void main(String[] args) {
threadLocal.set("p123");
new Thread(() -> {
System.out.println("child--->[" + threadLocal.get() + "]");
}).start();
System.out.println("main--->[" + threadLocal.get() + "]");
}
}
而实际情况,我们常常会借助于多线程,优化代码中涉及到io操作的逻辑。对于上面这种new一个子线程的情况,我们可以借助jdk自带的InheritableThreadLocal来解决问题。
2.InheritableThreadLocal
先来看看,我们把ThreadLocal替换为InheritableThreadLocal的话,上文的示例是否就解决了呢:
public class ThreadLocalLimitDemo {
static ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
/**
* 运行结果:
* main--->[p123]
* child--->[p123]
*/
public static void main(String[] args) {
threadLocal.set("p123");
new Thread(() -> {
System.out.println("child--->[" + threadLocal.get() + "]");
}).start();
System.out.println("main--->[" + threadLocal.get() + "]");
}
}
子线程中也可以取到父线程设置的上下文值了。
下面我们来看看,这是怎么实现的。
InheriatableThreadLocal本身有什么特别的吗?其实没有,我们看到InheritableThreadLocal继承了ThreadLocal,本身重写的方法也不多。整个类代码如下:
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
/**
* 通常情况,子线程的值和父线程一样,但是,如果希望有所加工,可以重载该方法
*/
protected T childValue(T parentValue) {
return parentValue;
}
/**
* Get the map associated with a ThreadLocal.
*/
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}
/**
* Create the map associated with a ThreadLocal.
*/
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}
可以看到,与ThreadLocal的主要区别就是getMap和createMap关联的是线程中的inheritableThreadLocals字段,而不是Thread#threadLocals字段了。
那么,上下文继承的关键就是在Thread类,在new一个Thread的时候进行了继承操作,new最终会调init方法。
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) {
//获得父线程,currentThread方法是个native方法
Thread parent = currentThread();
//......
// inheritThreadLocals :if {@code true}, inherit initial values for inheritable thread-locals from the constructing thread
// 如果父线程有inheritableThreadLocals,则复制创建,则继承父线程的上下文
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
}
以上可以看出来,InheritableThreadLocal是在new线程的时候才能起作用,但是如果使用了线程池,线程会复用,这样,在复用的线程中,还是无法正确获取到线程上下文的值。再来段示例:
public class InheritableThreadLocalLimitDemo{
static ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
static ExecutorService executorService
= Executors.newSingleThreadExecutor());
/**
* 执行结果:
* main--->p123
* first creat new thread--->p123
* then reuse thread--->p123
*/
public static void main(String[] args) {
threadLocal.set("p123");
System.out.println("main--->" + threadLocal.get());
executorService.submit(new Thread(() -> {
System.out.println("first creat new thread--->" + threadLocal.get());
}));
threadLocal.set("t456");
executorService.submit(new Thread(() -> {
System.out.println("then reuse thread--->" + threadLocal.get());
}));
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.exit(0);
}
}
可以看到,当复用线程时,无法获取到父线程设置的"t456"这个值,仍然是第一次创建的"p123"这个值,即:在线程池中无法继承父线程的线程上下文。
这个问题我们在下一节看看如何解决。
3.TransmittableThreadLocal
我们先自己分析下,这个问题出现在哪,以及可以怎么解决。首先,问题是,在复用线程的时候,当前线程未继承父线程的InheritableThreadLocal值(因为复用,不涉及到new线程操作了)。知道问题了,那么如果要继承上一个线程的InheritableThreadLocal值,我们是不是可以通过某个折中办法,把上一个线程的值传进来。时机呢?看来只有在定义业务逻辑的时候,把local值保存起来,在执行业务的逻辑之前,先利用保存的local值重置下threadlocal值。业务开发中,引入一个中间层,往往能解决很多问题。因此,针对原来定义一个Thread的逻辑,我们引入一个新类,保存父线程的值,并重写run方法,在该方法中先重置ThreadLocal值,然后再调用业务逻辑。
下面我们看看这样实现是否有用呢。
public abstract class InheritableTask implements Runnable {
/**
* 中间值
**/
private Object inheritableThreadLocalsObj;
/**
* 保存父线程threadlocal的值
*/
public InheritableTask(){
try{
Thread currentThread = Thread.currentThread();
Field inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
inheritableThreadLocalsField.setAccessible(true);
Object threadLocalMapObj = inheritableThreadLocalsField.get(currentThread);
if(threadLocalMapObj != null){
// 调用ThreadLocal中的createInheritedMap方法,重新复制一个新的inheritableThreadLocals值
Class threadLocalMapClazz = inheritableThreadLocalsField.getType();
Method method = ThreadLocal.class.getDeclaredMethod("createInheritedMap",threadLocalMapClazz);
method.setAccessible(true);
// 创建一个新的ThreadLocalMap类型的inheritableThreadLocals
Object newThreadLocalMap = method.invoke(ThreadLocal.class,threadLocalMapObj);
// 将这个值暂存下来
inheritableThreadLocalsObj = newThreadLocalMap;
}
}catch (Exception e){
throw new IllegalStateException(e);
}
}
/**
* 搞个代理方法,这个方法中处理业务逻辑
*/
public abstract void runTask();
@Override
public void run() {
// 此处得到的是当前处理该业务的线程,也就是线程池中的线程
Thread currentThread = Thread.currentThread();
Field field = null;
try {
field = Thread.class.getDeclaredField("inheritableThreadLocals");
field.setAccessible(true);
// 将暂存的值,赋值给currentThread
if (inheritableThreadLocalsObj != null && field != null) {
field.set(currentThread, inheritableThreadLocalsObj);
inheritableThreadLocalsObj = null;
}
// 执行任务
runTask();
}catch (Exception e){
throw new IllegalStateException(e);
}finally {
// 最后将线程中的InheritableThreadLocals设置为null
try{
field.set(currentThread,null);
}catch (Exception e){
throw new IllegalStateException(e);
}
}
}
}
接下来看下使用效果
public class TransmittableThreadLocalWrapper {
static ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
static ExecutorService executorService
= Executors.newSingleThreadExecutor();
/**
* 执行结果:
* main--->p123
* first creat new thread--->p123
* then reuse thread--->t456
*/
public static void main(String[] args) {
threadLocal.set("p123");
System.out.println("main--->" + threadLocal.get());
executorService.submit(new Thread(() -> {
System.out.println("first creat new thread--->" + threadLocal.get());
}));
threadLocal.set("t456");
executorService.submit(new InheritableTask() {
@Override
public void runTask() {
System.out.println("then reuse thread--->" + threadLocal.get());
}
});
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.exit(0);
}
}
问题确实解决了。
不过这里面还有很多问题不够优雅,一是反射耗性能,二是还要去实现类似InheritableTask这种具有侵入性的逻辑。所以,这里重点引入TransmittableThreadLocal(github地址l
)。具体用法可以详细参考官网。这里简单总结下:1)可以通过修饰Runnable和Callable; 2)可以通过修饰线程池;3)如果希望降低侵入性,还可以通过agent的方式修饰jdk线程池实现类。
public class TransmittableThreadLocalWrapper {
static ThreadLocal<String> threadLocal = new TransmittableThreadLocal<>();
static ExecutorService executorService
= TtlExecutors.getTtlExecutorService(Executors.newSingleThreadExecutor());
/**
* 执行结果:
* main--->p123
* first creat new thread--->p123
* then reuse thread--->t456
*/
public static void main(String[] args) {
threadLocal.set("p123");
System.out.println("main--->" + threadLocal.get());
executorService.submit(new Thread(() -> {
System.out.println("first creat new thread--->" + threadLocal.get());
}));
threadLocal.set("t456");
executorService.submit(new Thread(() -> {
System.out.println("then reuse thread--->" + threadLocal.get());
}));
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.exit(0);
}
}
我们下面重点介绍下其实现原理。
根据二八原理,首先尝试介绍下最重要的百分之二十的实现思想。
最核心思想其实还是类似:在new一个线程的时候,获取父线程的所有TransmittableThreadLocal的值,在执行子线程run方法的时候,将父线程的值拷贝到子线程当中。
记住以下几个重要的类(或内部类、字段):
a.TransmittableThreadLocal.Transmitter
负责父子线程TransmittableThreadLocal的录制回放工作,最核心的一个类了。
b.TransmittableThreadLocal.Transmitter.Snapshot
顾名思义,只是一个对象载体,保存父线程/子线程的所有ThreadLocal的值。
c.TransmittableThreadLocal#holder:private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>
TransmittableThreadLocal最核心的字段,正是有该字段,所有传递机制才得以存在。
记录了当前线程都包含了哪些ThreadLocal对象,不然在录制父线程的threadlocal值时,哪能知道用户都定义了哪些。
相信通过以上,已经大致知道TransmittableThreadLocal的原理了,下面再看以下几个方法片段进一步理解下:
a.
/**
* see {@link InheritableThreadLocal#set}
*/
@Override
public final void set(T value) {
if (!disableIgnoreNullValueSemantics && null == value) {
// may set null to remove value
remove();
} else {
super.set(value);
// 核心思想方法
addThisToHolder();
}
}
b.
@NonNull
public static Object capture() {
return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
ttl2Value.put(threadLocal, threadLocal.copyValue());
}
return ttl2Value;
}
c.
@NonNull
public static Object replay(@NonNull Object captured) {
final Snapshot capturedSnapshot = (Snapshot) captured;
return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}
@NonNull
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
TransmittableThreadLocal<Object> threadLocal = iterator.next();
// backup
backup.put(threadLocal, threadLocal.get());
// clear the TTL values that is not in captured
// avoid the extra TTL values after replay when run task
if (!captured.containsKey(threadLocal)) {
iterator.remove();
threadLocal.superRemove();
}
}
// set TTL values to captured
setTtlValuesTo(captured);
// call beforeExecute callback
doExecuteCallback(true);
return backup;
}
d.
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
this.capturedRef = new AtomicReference<Object>(capture());
this.runnable = runnable;
this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}
/**
* wrap method {@link Runnable#run()}.
*/
@Override
public void run() {
Object captured = capturedRef.get();
if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
throw new IllegalStateException("TTL value reference is released after run!");
}
Object backup = replay(captured);
try {
runnable.run();
} finally {
restore(backup);
}
}
通过以上分析,相信应该能大致掌握TransmittableThreadLocal的原理了吧。最后放一张时序图,再进一步加深下印象。
-
TransmittableThreadLocal时序图
image.png
- 参考
Runnable改造