在多线程环境下,解决线程安全问题的方案主要分三个方案:互斥同步(如Syncrhonize)、非阻塞同步(如CAS)、线程封闭(局部变量、ThreadLocal)。本文将讨论其中的ThreadLocal的实现原理。

1. ThreadLocal是什么

ThreadLocal是一个线程级别的变量,主要用于多线程环境下,无需线程间共享的变量。当每个线程都维护属于自己线程的变量,线程间是隔离的,那么也就彻底消除了线程间的竞争,消除了线程安全问题。

ThreadLocal的使用方法很简单,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static void main(String[] args) throws InterruptedException {  
ThreadLocal<Integer> threadLocal = new ThreadLocal<>();

ExecutorService service = Executors.newSingleThreadExecutor();

service.execute(() -> {
threadLocal.set(666);
log.info("set: {}", 666);
});


Thread.sleep(1000);

service.execute(() -> {
log.info("get: {}", threadLocal.get());
});
}

2. ThreadLocal整体结构

首先来看一下整体的结构:

通过上图的结构可见,每个Thread都维护了一个ThreadLocalMap类型的变量,这个变量维护了一个Entry数组,其中存的就是ThreadLocal变量,这样就实现了不同线程之间的隔离。

以上只是对于ThreadLocal的一个初步认识,接下来将对其具体原理进行讨论。

3. 原理分析

3.1 ThreadLocalMap是什么

首先来看看ThreadLocalMap的结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
public class ThreadLocal<T> {
static class ThreadLocalMap {
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

private static final int INITIAL_CAPACITY = 16;
private Entry[] table;

private int size = 0;
private int threshold;

// ......
}

// ......
}

可见它是ThreadLocal的静态内部类,且ThreadLocalMap还有一个Entry静态内部类。
在ThreadLocalMap中,它主要维护了private Entry[] table这个Entry数组。Entry类是一个比较特殊的类,它继承了弱引用。其中,将key做为弱引用value依然是正常的强引用。

如果不了解弱引用可以查看[[JVM中的引用]]这篇文章。

至于这里为什么要这么做,后面在讨论。接下来我们直接看ThreadLocal的存取的流程。

3.2 set()方法

直接看源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

// ThreadLocal
public void set(T value) {
// 首先拿到当前的线程
Thread t = Thread.currentThread();

// 然后拿到当前线程所维护的ThradLocalMap
ThreadLocalMap map = getMap(t);

if (map != null)
map.set(this, value);
else
createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}


// ThreadLocalMap是懒加载的,在第一次set时才会创建
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

// ThreadLocalMap
private void set(ThreadLocal<?> key, Object value) {
// table就是维护的那个Entry数组
Entry[] tab = table;
int len = tab.length;
// 通过hash定位数组下标
int i = key.threadLocalHashCode & (len-1);

// for循环就是为了处理可能出现hash冲突的情况
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

// 说明之前set过,那么重新赋value即可
if (k == key) {
e.value = value;
return;
}

/**
* 思考:这里的k为什么会为null呢?
* 上面说过Entry的实现是将key作为若引用。当此key没有了强引用的话,将会在下一次GC时被回收掉。
* 因此这里的 k == null的情况其实就是这个key过期的情况。
*/
if (k == null) {
// 如果这个key过期了,那么就用新的entry替换老的entry。
//(这里暂时先走整体流程,后面再讨论具体细节)
replaceStaleEntry(key, value, i);
return;
}

// 能走到这里就说明,即没有k即不是key也没有过期,那就说明哈希冲突了,进行重哈希。
}

// 如果没有找到已有的key,那么将entry放到槽位i中
tab[i] = new Entry(key, value);
int sz = ++size;

// 如果没有清除任何过期entry 且大小超过了门槛值
if (!cleanSomeSlots(i, sz) && sz >= threshold)
// 重新哈希
rehash();
}

private void rehash() {
// 清除
expungeStaleEntries();

// 使用较低的阈值加倍,避免迟滞
if (size >= threshold - threshold / 4)
// 大小加倍
resize();
}

// 清除过期的entry
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}

通过以上代码,可以了解set的整体流程。简单来说,其实就是通过hash算法拿到响应的ThreadLocal的对应的Entry的槽位,然后将值放到对应的槽位即可。其中夹杂了一些判断是否过期、清理过期Entry、处理槽位不够的情况等细节。

接下来再继续看上面代码分析中略过的replaceStaleEntry(key, value, i)的细节:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {  
ThreadLocal.ThreadLocalMap.Entry[] tab = table;
int len = tab.length;
ThreadLocal.ThreadLocalMap.Entry e;

// slotToExpunge用于记录上一个待清楚的过期entry的槽位下标
int slotToExpunge = staleSlot;
// 向前搜索,找到过期的entry槽位,则记录到slotToExpunge中
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// 从staleSlot向后搜索
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();


if (k == key) {
// 如果k就是key对应的槽,则更新 value
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;

// 如果slotToExpunge 依然是staleSlot,那就更新到当前的i。
// (因为依然是staleSlot已经更新成有效的entry了)
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 然后清理过期entry。
// (后面详细介绍)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// 走到这里说明k != key
// 如果当前的entry是过期entry,且staleSlot之前没有过期的槽
// 那么把即将被清除的过期槽位更新为i,因为staleSlot可能后面还要用用于存储Entry
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// 走到这里说明:在staleSlot后面没有找到对应的key的槽位
// 那么就在staleSlot这里实例化一个Entry
tab[staleSlot].value = null;
tab[staleSlot] = new ThreadLocal.ThreadLocalMap.Entry(key, value);

// 防止把已经实例化的entry给清理了
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

可以看出,replaceStaleEntry(key, value, i)方法中,先向前找失效entry,然后找是否有和key对应的entry。此方法不仅仅插入了新的entry,还顺便清理了一下失效的entry。

在讨论上面代码时,其中的清理失效entry的代码当时没有详细介绍,现在我们来详细看看清理失效entry的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// 删除过期的Entry
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 将staleSlot对应槽位的value和key都置空,以供GC
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
// 翻译:重新哈希直到遇到空值
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
// 如果k时null,则说明此entry过期,直接清理。
e.value = null;
tab[i] = null;
size--;
} else {
// k不为空,则对其进行重新哈希
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}

// 返回的是staleSlot后,下一个空槽的下标
return i;
}


// 返回是否清除entry
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];

// 如果发现了过期的Entry
if (e != null && e.get() == null) {
// n置为table的长度
n = len;
removed = true;
// 清除此过期entry
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

3.3 get()方法

以上我们讨论了set方法的整体流程以及一些清除过期entry的细节。接下来我们讨论一下get()的流程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
public T get() {  
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
// 从TThreadLocalMap中拿到当前的TheadLocal的Entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}

// 走到这里,有两种可能:
// 1、map == null: 因为ThreadLocalMap是懒加载的,如果从来没有set过,那么就没有
// 2、map不为空但是map中没有相应的entry
// 那么久返回默认的初始化value
return setInitialValue();
}

ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 如果没有哈希冲突,顺利找到,直接返回
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

// 向后遍历,找到符合条件的entry
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
// 顺便把过期的entry清除了
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}

// 能走到这里说明没找到。(有可能根本就没有set过,也有可能过期后被清除了)
return null;
}

现在我们再回到主线:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public T get() {  
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}

// 现在来分析这部分
return setInitialValue();
}


private T setInitialValue() {
// 初始化一个默认值
T value = initialValue();

Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);

if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

protected T initialValue() {
return null;
}

以上代码因为比较简单,没有写太详细的注释。在setInitialValue()方法中,只需要重点关注initialValue()即可。因为其他部分都在前面的讨论中见过。

initialValue()这个方法非常简单,直接返回一个null。但是我们需要注意,它的权限修饰符是protected,因此我们可以根据自己的业务需求通过重写这个方法。如:

1
2
3
4
5
6
ThreadLocal<Object> tl = new ThreadLocal<Object>(){  
@Override
protected Object initialValue() {
return super.initialValue();
}
};

以上就是整个get的流程。

3.4 remove()方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public void remove() {  
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}


private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// 找到相应的entry,清除引用
e.clear();
// 删除对应的entry
expungeStaleEntry(i);
return;
}
}
}

这里有一个问题:找到相应的entry后,为什么调用了e.clear()后还要调用expungeStaleEntry(i)来删除呢?

这是因ThreadLocalMap中的table是用的线性探测法插入的,如果简单的置空可能损坏hash表的完整性,get的时候发现没前面的entry为null,就理解为没有hash冲突,直接返回null了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {  
Entry[] tab = table;
int len = tab.length;

// 如果e为空,那么就理解为没有相应的entry了
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

3.5 总结

通过对于ThreadLocal的几个重要的方法的分析。可见整体的逻辑不难,只是在其中穿插了很多清除过期Entry的逻辑。

那是因为Entry是弱引用,当没有强引用指向的时候就会在下一次GC时将其回收。但是Entry中的Value并不是强引用,如果回收了entry中的key,但是value依然被强引用指向,就会有内存泄漏的风险。因此,需要尽可能地及时清理掉。

那么都在什么时候可能会清理过期entry呢?

  • 在设置新值时(set 方法)
  • 在获取值时(get 方法)
  • 在删除值时(remove 方法)
  • 在重新哈希时(rehash 方法)

注意:虽然尽可能地在回收过期Entry,但还是有一定可能性出现内存泄漏的问题,因此要养成一个随手remove()用完的ThreadLocal的习惯。

4. ThreadLocal的使用场景

ThreadLocal的使用场景比较多,比如:

4. 1 数据库连接管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class ConnectionManager {
private static ThreadLocal<Connection> connectionHolder = ThreadLocal.withInitial(() -> {
// Create and return a new database connection
return createNewConnection();
});

public static Connection getConnection() {
return connectionHolder.get();
}

public static void closeConnection() {
Connection conn = connectionHolder.get();
if (conn != null) {
conn.close();
connectionHolder.remove();
}
}
}

4.2 事务管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
public class TransactionManager {
private static ThreadLocal<Transaction> transactionHolder = new ThreadLocal<>();

public static void beginTransaction() {
Transaction tx = new Transaction();
transactionHolder.set(tx);
tx.begin();
}

public static void commitTransaction() {
Transaction tx = transactionHolder.get();
if (tx != null) {
tx.commit();
transactionHolder.remove();
}
}

public static void rollbackTransaction() {
Transaction tx = transactionHolder.get();
if (tx != null) {
tx.rollback();
transactionHolder.remove();
}
}

public static Transaction getTransaction() {
return transactionHolder.get();
}
}

4.3 用户会话管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class UserContext {
private static ThreadLocal<User> userHolder = new ThreadLocal<>();

public static void setUser(User user) {
userHolder.set(user);
}

public static User getUser() {
return userHolder.get();
}

public static void clear() {
userHolder.remove();
}
}

4.4 日志跟踪链路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class RequestContext {
private static ThreadLocal<String> requestIdHolder = new ThreadLocal<>();

public static void setRequestId(String requestId) {
requestIdHolder.set(requestId);
}

public static String getRequestId() {
return requestIdHolder.get();
}

public static void clear() {
requestIdHolder.remove();
}
}

……