在多线程环境下,解决线程安全问题的方案主要分三个方案:互斥同步(如Syncrhonize)、非阻塞同步(如CAS)、线程封闭(局部变量、ThreadLocal)。本文将讨论其中的ThreadLocal的实现原理。
1. ThreadLocal是什么 ThreadLocal是一个线程级别的变量,主要用于多线程环境下,无需线程间共享的变量。当每个线程都维护属于自己线程的变量,线程间是隔离的,那么也就彻底消除了线程间的竞争,消除了线程安全问题。
ThreadLocal的使用方法很简单,如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 public static void main (String[] args) throws InterruptedException { ThreadLocal<Integer> threadLocal = new ThreadLocal <>(); ExecutorService service = Executors.newSingleThreadExecutor(); service.execute(() -> { threadLocal.set(666 ); log.info("set: {}" , 666 ); }); Thread.sleep(1000 ); service.execute(() -> { log.info("get: {}" , threadLocal.get()); }); }
2. ThreadLocal整体结构 首先来看一下整体的结构:
通过上图的结构可见,每个Thread都维护了一个ThreadLocalMap类型的变量,这个变量维护了一个Entry数组,其中存的就是ThreadLocal变量,这样就实现了不同线程之间的隔离。
以上只是对于ThreadLocal的一个初步认识,接下来将对其具体原理进行讨论。
3. 原理分析 3.1 ThreadLocalMap是什么 首先来看看ThreadLocalMap的结构:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 public class ThreadLocal <T> { static class ThreadLocalMap { static class Entry extends WeakReference <ThreadLocal<?>> { Object value; Entry(ThreadLocal<?> k, Object v) { super (k); value = v; } } private static final int INITIAL_CAPACITY = 16 ; private Entry[] table; private int size = 0 ; private int threshold; } }
可见它是ThreadLocal的静态内部类,且ThreadLocalMap还有一个Entry静态内部类。 在ThreadLocalMap中,它主要维护了private Entry[] table这个Entry数组。Entry类是一个比较特殊的类,它继承了弱引用。其中,将key做为弱引用value依然是正常的强引用。
如果不了解弱引用可以查看[[JVM中的引用]]这篇文章。
至于这里为什么要这么做,后面在讨论。接下来我们直接看ThreadLocal的存取的流程。
3.2 set()方法 直接看源码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 public void set (T value) { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null ) map.set(this , value); else createMap(t, value); } ThreadLocalMap getMap (Thread t) { return t.threadLocals; } void createMap (Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap (this , firstValue); } ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { table = new Entry [INITIAL_CAPACITY]; int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1 ); table[i] = new Entry (firstKey, firstValue); size = 1 ; setThreshold(INITIAL_CAPACITY); } private void set (ThreadLocal<?> key, Object value) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1 ); for (Entry e = tab[i]; e != null ; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); if (k == key) { e.value = value; return ; } if (k == null ) { replaceStaleEntry(key, value, i); return ; } } tab[i] = new Entry (key, value); int sz = ++size; if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); } private void rehash () { expungeStaleEntries(); if (size >= threshold - threshold / 4 ) resize(); } private void expungeStaleEntries () { Entry[] tab = table; int len = tab.length; for (int j = 0 ; j < len; j++) { Entry e = tab[j]; if (e != null && e.get() == null ) expungeStaleEntry(j); } }
通过以上代码,可以了解set的整体流程。简单来说,其实就是通过hash算法拿到响应的ThreadLocal的对应的Entry的槽位,然后将值放到对应的槽位即可。其中夹杂了一些判断是否过期、清理过期Entry、处理槽位不够的情况等细节。
接下来再继续看上面代码分析中略过的replaceStaleEntry(key, value, i)的细节:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 private void replaceStaleEntry (ThreadLocal<?> key, Object value, int staleSlot) { ThreadLocal.ThreadLocalMap.Entry[] tab = table; int len = tab.length; ThreadLocal.ThreadLocalMap.Entry e; int slotToExpunge = staleSlot; for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null ; i = prevIndex(i, len)) if (e.get() == null ) slotToExpunge = i; for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null ; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); if (k == key) { e.value = value; tab[i] = tab[staleSlot]; tab[staleSlot] = e; if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return ; } if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } tab[staleSlot].value = null ; tab[staleSlot] = new ThreadLocal .ThreadLocalMap.Entry(key, value); if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); }
可以看出,replaceStaleEntry(key, value, i)方法中,先向前找失效entry,然后找是否有和key对应的entry。此方法不仅仅插入了新的entry,还顺便清理了一下失效的entry。
在讨论上面代码时,其中的清理失效entry的代码当时没有详细介绍,现在我们来详细看看清理失效entry的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 private int expungeStaleEntry (int staleSlot) { Entry[] tab = table; int len = tab.length; tab[staleSlot].value = null ; tab[staleSlot] = null ; size--; Entry e; int i; for (i = nextIndex(staleSlot, len); (e = tab[i]) != null ; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); if (k == null ) { e.value = null ; tab[i] = null ; size--; } else { int h = k.threadLocalHashCode & (len - 1 ); if (h != i) { tab[i] = null ; while (tab[h] != null ) h = nextIndex(h, len); tab[h] = e; } } } return i; } private boolean cleanSomeSlots (int i, int n) { boolean removed = false ; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null ) { n = len; removed = true ; i = expungeStaleEntry(i); } } while ( (n >>>= 1 ) != 0 ); return removed; }
3.3 get()方法 以上我们讨论了set方法的整体流程以及一些清除过期entry的细节。接下来我们讨论一下get()的流程。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 public T get () { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null ) { ThreadLocalMap.Entry e = map.getEntry(this ); if (e != null ) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } return setInitialValue(); } ThreadLocalMap getMap (Thread t) { return t.threadLocals; } private Entry getEntry (ThreadLocal<?> key) { int i = key.threadLocalHashCode & (table.length - 1 ); Entry e = table[i]; if (e != null && e.get() == key) return e; else return getEntryAfterMiss(key, i, e); } private Entry getEntryAfterMiss (ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null ) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null ) expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } return null ; }
现在我们再回到主线:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 public T get () { Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null ) { ThreadLocalMap.Entry e = map.getEntry(this ); if (e != null ) { @SuppressWarnings("unchecked") T result = (T)e.value; return result; } } return setInitialValue(); } private T setInitialValue () { T value = initialValue(); Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null ) map.set(this , value); else createMap(t, value); return value; } protected T initialValue () { return null ; }
以上代码因为比较简单,没有写太详细的注释。在setInitialValue()方法中,只需要重点关注initialValue()即可。因为其他部分都在前面的讨论中见过。
initialValue()这个方法非常简单,直接返回一个null。但是我们需要注意,它的权限修饰符是protected,因此我们可以根据自己的业务需求通过重写这个方法。如:
1 2 3 4 5 6 ThreadLocal<Object> tl = new ThreadLocal <Object>(){ @Override protected Object initialValue () { return super .initialValue(); } };
以上就是整个get的流程。
3.4 remove()方法 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 public void remove () { ThreadLocalMap m = getMap(Thread.currentThread()); if (m != null ) m.remove(this ); } private void remove (ThreadLocal<?> key) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1 ); for (Entry e = tab[i]; e != null ; e = tab[i = nextIndex(i, len)]) { if (e.get() == key) { e.clear(); expungeStaleEntry(i); return ; } } }
这里有一个问题:找到相应的entry后,为什么调用了e.clear()后还要调用expungeStaleEntry(i)来删除呢?
这是因ThreadLocalMap中的table是用的线性探测法插入的,如果简单的置空可能损坏hash表的完整性,get的时候发现没前面的entry为null,就理解为没有hash冲突,直接返回null了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 private Entry getEntryAfterMiss (ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null ) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null ) expungeStaleEntry(i); else i = nextIndex(i, len); e = tab[i]; } return null ; }
3.5 总结 通过对于ThreadLocal的几个重要的方法的分析。可见整体的逻辑不难,只是在其中穿插了很多清除过期Entry的逻辑。
那是因为Entry是弱引用,当没有强引用指向的时候就会在下一次GC时将其回收。但是Entry中的Value并不是强引用,如果回收了entry中的key,但是value依然被强引用指向,就会有内存泄漏的风险。因此,需要尽可能地及时清理掉。
那么都在什么时候可能会清理过期entry呢?
在设置新值时(set 方法)
在获取值时(get 方法)
在删除值时(remove 方法)
在重新哈希时(rehash 方法)
注意:虽然尽可能地在回收过期Entry,但还是有一定可能性出现内存泄漏的问题,因此要养成一个随手remove()用完的ThreadLocal的习惯。
4. ThreadLocal的使用场景 ThreadLocal的使用场景比较多,比如:
4. 1 数据库连接管理 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 public class ConnectionManager { private static ThreadLocal<Connection> connectionHolder = ThreadLocal.withInitial(() -> { return createNewConnection(); }); public static Connection getConnection () { return connectionHolder.get(); } public static void closeConnection () { Connection conn = connectionHolder.get(); if (conn != null ) { conn.close(); connectionHolder.remove(); } } }
4.2 事务管理 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 public class TransactionManager { private static ThreadLocal<Transaction> transactionHolder = new ThreadLocal <>(); public static void beginTransaction () { Transaction tx = new Transaction (); transactionHolder.set(tx); tx.begin(); } public static void commitTransaction () { Transaction tx = transactionHolder.get(); if (tx != null ) { tx.commit(); transactionHolder.remove(); } } public static void rollbackTransaction () { Transaction tx = transactionHolder.get(); if (tx != null ) { tx.rollback(); transactionHolder.remove(); } } public static Transaction getTransaction () { return transactionHolder.get(); } }
4.3 用户会话管理 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 public class UserContext { private static ThreadLocal<User> userHolder = new ThreadLocal <>(); public static void setUser (User user) { userHolder.set(user); } public static User getUser () { return userHolder.get(); } public static void clear () { userHolder.remove(); } }
4.4 日志跟踪链路 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 public class RequestContext { private static ThreadLocal<String> requestIdHolder = new ThreadLocal <>(); public static void setRequestId (String requestId) { requestIdHolder.set(requestId); } public static String getRequestId () { return requestIdHolder.get(); } public static void clear () { requestIdHolder.remove(); } }
……