天天看點

ThreadLocal、InheritableThreadLocal和ThreadLocalMap源碼解析

1.ThreadLocal作用

  • 作用:為變量線上程中都建立副本,線程可通路自己内部的副本變量。該類提供了線程局部 (thread-local) 變量,通路這個變量(通過其 get 或 set 方法)的每個線程都有自己的局部變量,它獨立于變量的初始化副本
  • 原理:每個線程都有一個ThreadLocalMap類型變量 threadLocals。ThreadLocal的set()會在threadLocals中儲存以ThreadLocal對象為key,以儲存的變量為value的值,get()會擷取該值

2.ThreadLocal繼承關系

ThreadLocal、InheritableThreadLocal和ThreadLocalMap源碼解析

3.源碼走讀

3.1.ThreadLocal.java

public class ThreadLocal<T> {
    //**每一個執行個體都有一個唯一的threadLocalHashCode,值為上一個執行個體的值加上0x61c88647
    //**作用是為了讓哈希碼能均勻的分布在2的N次方的數組裡
    private final int threadLocalHashCode = nextHashCode();
    private static AtomicInteger nextHashCode = new AtomicInteger();
    private static final int HASH_INCREMENT = 0x61c88647;
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    //**傳回此線程局部變量的目前線程的“初始值”
    //**線程第一次使用get()方法時調用此方法,如果線程之前調用了set(T)方法,則不會對該線程再調用該方法
    //**通常,此方法對每個線程最多調用一次,但調用了remove(),則會再次調用此方法
    //**預設傳回null,如果希望傳回其它值,則須建立子類,并重寫此方法,通常将使用匿名内部類完成此操作 
    protected T initialValue() {
        return null;
    }

    //**在java8,使用函數式程式設計的方式設定并傳回目前線程變量的初始值,與上個方法功能相同
    //**示例:ThreadLocal<String> threadLocal = ThreadLocal.withInitial(() -> "test");
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
        //**傳回SuppliedThreadLocal對象,SuppliedThreadLocal是ThreadLocal的子類,
        //**重寫的initialValue方法調用supplier的get方法做為目前線程變量的初始值
        return new SuppliedThreadLocal<>(supplier);
    }

    //**傳回此線程局部變量的目前線程副本中的值,如果變量沒有用于目前線程的值,則傳回initialValue()的值
    public T get() {
        //**擷取目前線程的執行個體
        Thread t = Thread.currentThread();
        //**擷取目前線程中的ThreadLocalMap類型變量threadLocals
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //**從threadLocals中擷取以this為key的Entry對象
            ThreadLocalMap.Entry e = map.getEntry(this);
            //**如果Entry對象不為空,則傳回它的value
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //**如果threadLocals對象不為空或者Entry為空,則調用setInitialValue進行初始化
        return setInitialValue();
    }

    //**使用initialValue()的值初始化線程局部變量
    private T setInitialValue() {
        //**擷取線程局部變量的初始值,預設為null
        T value = initialValue();
        //**擷取目前線程的執行個體
        Thread t = Thread.currentThread();
        //**擷取目前線程中的ThreadLocalMap類型變量threadLocals
        ThreadLocalMap map = getMap(t);
        //**如果threadLocals不為空,設定以this為key,以value為值的Entry對象
        if (map != null)
            map.set(this, value);
        //**如果threadLocals為空,則進行初始化,并設定以this為key,以value為值的Entry對象
        else
            createMap(t, value);
        return value;
    }

    //**設定線程局部變量的值
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    //**移除此線程局部變量目前線程的值,如果随後調用get()方法,且沒有調用set()設定值,則将調用initialValue()重新初始化值
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

    //**從線程執行個體中擷取threadLocals對象
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    //**初始化線程t的threadLocals對象,并設定以this為key,以firstValue為值的Entry對象
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    //**根據主線程中的ThreadLocalMap對象建立子線程的ThreadLocalMap對象
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

    //**ThreadLocal對象不支援,在InheritableThreadLocal中實作
    T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }
}

           

3.2.SuppliedThreadLocal.java

  • SuppliedThreadLocal是ThreadLoacl的靜态内部類,ThreadLocal的withInitial方法使用Supplier對象建立SuppliedThreadLocal對象
  • 作用是為了在java8,支援使用函數式程式設計的方式設定并傳回目前線程變量的初始值
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

    private final Supplier<? extends T> supplier;

    SuppliedThreadLocal(Supplier<? extends T> supplier) {
        this.supplier = Objects.requireNonNull(supplier);
    }

    @Override
    //**重寫的initialValue方法,調用supplier的get方法做為目前線程變量的初始值
    protected T initialValue() {
        return supplier.get();
    }
}
           

3.3.InheritableThreadLocal.java

  • 作用:把父線程變量的值傳遞到子線程中。可通過重寫childValue方法,改變從父線程中擷取的值
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    //**父線程inheritableThreadLocals的值傳給子線程的inheritableThreadLocals的處理邏輯
    protected T childValue(T parentValue) {
        return parentValue;
    }
    //**重寫父類的方法,維護inheritableThreadLocals變量
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    //**重寫父類的方法,維護inheritableThreadLocals變量
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
           
  • 原理和解析:
    • 每個線程都還有另外一個ThreadLocalMap類型變量inheritableThreadLocals
    • InheritableThreadLocal重寫了getMap和createMap方法,維護的不在是threadLocals,而是inheritableThreadLocals
    • 當主線程建立一個子線程的時候,會判斷主線程的inheritableThreadLocals是否為空
    • 如果不為空,則會把inheritableThreadLocals的值傳給子線程的inheritableThreadLocals,傳送的邏輯是childValue實作的
    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc) {
        ...
        //**擷取主線程的執行個體
        Thread parent = currentThread();
        ...
        //**如果主線的inheritableThreadLocals不為空
        if (parent.inheritableThreadLocals != null)
            //**根據主線程的inheritableThreadLocals建立子線程的inheritableThreadLocals
            this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        ...
    }
               
  • 注意:
    • 因為傳送邏輯是在建立子線程的時候完成的,子線程建立後,主線程在修改InheritableThreadLocal變量的值,是無法傳給子線程的
    • 建立子線程完成後,原則上子線程和父線程中InheritableThreadLocal變量的值在沒有關聯,各自調用set/get/remove都隻影響本線程中的值
    • 如果InheritableThreadLocal變量的值是引用類型,通過get方法擷取到對象後,直接修改了該對象的屬性,則父線程和子線程都會受影響
    public class Test {
    
        private static List getList(String param) {
            List rst = new ArrayList<>();
            rst.add(param);
    
            return rst;
        }
    
        private static final InheritableThreadLocal<List> threadLocal = new InheritableThreadLocal<>();
    
        public static void test(Consumer<InheritableThreadLocal<List>> consumer) throws InterruptedException {
        threadLocal.set(getList("test"));
    
        Thread child = new Thread(() -> {
            while (!Thread.currentThread().isInterrupted()){
                try {
                    TimeUnit.MILLISECONDS.sleep(1);
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                System.out.println("子線程中threadLocal的值:" + threadLocal.get());
            }
        });
    
        System.out.println("主線程中threadLocal的值:" + threadLocal.get());
        child.start();
        TimeUnit.MILLISECONDS.sleep(1);
        consumer.accept(threadLocal);
        System.out.println("主線程中threadLocal的值:" + threadLocal.get());
        TimeUnit.MILLISECONDS.sleep(3);
        child.interrupt();
    }
    
        public static void main(String[] args) throws InterruptedException {
            //**建立子線程完成後,主線程調用set方法修改值,不會影響到子線程
            test(local -> local.set(getList("test1")));
            System.out.println("===========================");
            //**儲存list對象時,通過get方法擷取,然後修改list的值,則會影響到子線程
            test(local -> local.get().set(0, "test2"));
        }
    }
    
    //**執行結果
    主線程中threadLocal的值:[test]
    子線程中threadLocal的值:[test]
    主線程中threadLocal的值:[test1]
    子線程中threadLocal的值:[test]
    ===========================
    主線程中threadLocal的值:[test]
    子線程中threadLocal的值:[test]
    主線程中threadLocal的值:[test2]
    子線程中threadLocal的值:[test2]
               

3.4.ThreadLocalMap.java

  • 作用:ThreadLocalMap是ThreadLocal的内部類。存放以ThreadLocal變量為key,以儲存的變量為value的鍵值對
  • 原理:
    • ThreadLocalMap内部以Entry[]做為存儲,原始長度預設為16,當元素個數達到擴容閥值(數組長度的3/4)-擴容閥值/4,則自動擴容,擴容到上次長度的2倍。Entry[]的長度必須是2的倍數
    • Entry[]存儲元素并不是按索引順序存儲,而是根據ThreadLocal進行計算存儲位置,這樣能實作根據ThreadLocal都能快速定位鍵值對,而不用周遊數組的每個元素
    • 計算方法:ThreadLocal.threadLocalHashCode & (Entry[].length - 1)計算,ThreadLocal每一個執行個體都有一個唯一的threadLocalHashCode,值為上一個執行個體的值加上0x61c88647,該算法可以生成均勻的分布在2的N次方數組裡的下标
    • 如果計算的存儲位置已經有元素,則會存放到下一個索引的位置,ThreadLocalMap會清理過期資料,并重新根據計算的存儲位置重置,以保證盡可能減少和糾正此類問題
static class ThreadLocalMap {
    //**存放單個鍵值對的對象
    //**弱引用: 如果某個對象隻有弱引用,那麼gc會立即回收
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    //**預設初始化的大小,必須是2的倍數
    private static final int INITIAL_CAPACITY = 16;

    //**真正存儲資料的數組,長度必須是2的倍數
    private Entry[] table;

    //**ThreadLocalMap的大小,即上述Entry[]中存放元素的個數
    private int size = 0;

    //**自動擴容的閥值
    private int threshold; // Default to 0

    //**設定自動擴容的閥值,為設定長度的2/3
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

    //**下一個索引
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

    //**上一個索引
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
    }

    //**建立ThreadLocalMap,并設定第一個鍵值對
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        //**根據預設初始化的大小初始化Entry[]
        table = new Entry[INITIAL_CAPACITY];
        //**根據threadlocal對象的threadLocalHashCode和Entry[]數組的長度計算存放的位置
        //**該算法可以生成均勻的分布在2的N次方數組裡的下标
        //**每個鍵值對并不是按順序存放Entry[]裡面
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        //**把Entry對象放到指定位置
        table[i] = new Entry(firstKey, firstValue);
        //**設定ThreadLocalMap的大小,即Entry[]中存放元素的個數
        size = 1;
        //**設定自動擴容的閥值
        setThreshold(INITIAL_CAPACITY);
    }

    //**根據parentMap建立另一個parentMap,使用InheritableThreadLocal時,建立子線程時會調用
    private ThreadLocalMap(ThreadLocalMap parentMap) {
        Entry[] parentTable = parentMap.table;
        int len = parentTable.length;
        setThreshold(len);
        table = new Entry[len];

        for (int j = 0; j < len; j++) {
            Entry e = parentTable[j];
            if (e != null) {
                @SuppressWarnings("unchecked")
                ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                if (key != null) {
                    //**調用InheritableThreadLocal的childValue方法處理儲存的對象
                    Object value = key.childValue(e.value);
                    Entry c = new Entry(key, value);
                    int h = key.threadLocalHashCode & (len - 1);
                    while (table[h] != null)
                        h = nextIndex(h, len);
                    table[h] = c;
                    size++;
                }
            }
        }
    }

    //**根據threadlocal擷取Entry對象
    private Entry getEntry(ThreadLocal<?> key) {
        //**計算下标
        int i = key.threadLocalHashCode & (table.length - 1);
        Entry e = table[i];
        //**如果對象存在,且key一樣,則傳回
        if (e != null && e.get() == key)
            return e;
        else  //**否則從指定索引的下一個索引開始查找
            return getEntryAfterMiss(key, i, e);
    }

    //**沒有直接命中,則指定索引的下一個索引開始查找
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;

        //**從指定索引開始周遊,直到資料為null
        while (e != null) {
            ThreadLocal<?> k = e.get();
            //**如果資料存在則傳回
            if (k == key)
                return e;
            //**threadlocal對象為空,删除過期資料
            if (k == null)
                //**删除過期資料
                expungeStaleEntry(i);
            //**i為下一個索引
            else
                i = nextIndex(i, len);
            //**e為下一個索引的值
            e = tab[i];
        }
        //**沒有資料不存在則傳回null
        return null;
    }

    //**根據threadlocal對象設定value
    private void set(ThreadLocal<?> key, Object value) {
        Entry[] tab = table;
        int len = tab.length;
        //**計算存放的索引
        int i = key.threadLocalHashCode & (len-1);

        //**從指定索引開始周遊Entry[],直到資料為null
        for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();
            //**如果資料存在,則直接傳回
            if (k == key) {
                e.value = value;
                return;
            }
            //**如果key為空,則替換目前索引的資料,并傳回
            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }
        
        //**設定指定索引的資料
        tab[i] = new Entry(key, value);
        int sz = ++size;
        //**如果沒有資料需要清理并且數組長度大于了擴容閥值,則擴容
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }

    //**根據key删除資料
    private void remove(ThreadLocal<?> key) {
        Entry[] tab = table;
        int len = tab.length;
        //**計算存放的索引
        int i = key.threadLocalHashCode & (len-1);
        //**從指定的索引開始周遊Entry[],直到資料為null
        for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
            //**如果指定key存在,則删除指定資料
            if (e.get() == key) {
                e.clear();
                expungeStaleEntry(i);
                return;
            }
        }
    }

    //**替換指定索引的過期資料的
    private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        Entry e;

        //**從指定索引往前找,找到過期資料的索引
        int slotToExpunge = staleSlot;
        for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
            if (e.get() == null)
                slotToExpunge = i;

        //**從指定索引往後找
        for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();

            //**如果是資料的key等于指定的key
            if (k == key) {
                //**替換它的value
                e.value = value;
                
                //**把它的位置和指定索引的位置互換(把資料替換到計算索引的位置)
                tab[i] = tab[staleSlot];
                tab[staleSlot] = e;

                //**如果過期資料的的索引等于指定索引,則過期資料的索引為互換後的新索引
                if (slotToExpunge == staleSlot)
                    slotToExpunge = i;
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                return;
            }

            //**過期資料的索引
            if (k == null && slotToExpunge == staleSlot)
                slotToExpunge = i;
        }

        //**如果指定資料不存在,則建立新的資料
        tab[staleSlot].value = null;
        tab[staleSlot] = new Entry(key, value);

        //**如果有過時的條目,則清理
        if (slotToExpunge != staleSlot)
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }

    //**删除指定索引的過期資料,并傳回資料為null的索引
    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;

        //**指定索引的資料置為null,資料減一(删除指定資料)
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;

        Entry e;
        int i;
        //**從指定的索引的下一個資料開始循環周遊Entry[]數組,直到遇到null值
        for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            //**如果key為空,Entry置為空,資料減一(删除指定資料)
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                //**重新計算存放的索引
                int h = k.threadLocalHashCode & (len - 1);
                //**如果新索引不等于原索引,則原索引資料置為null
                if (h != i) {
                    tab[i] = null;

                   //**如果新的存放的索引有資料,則存放到新索引的下一個索引,直到沒有資料為止
                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    tab[h] = e;
                }
            }
        }
        //**傳回資料為null的索引
        return i;
    }

    //**從指定索引開始清理資料
    private boolean cleanSomeSlots(int i, int n) {
        boolean removed = false;
        Entry[] tab = table;
        int len = tab.length;
        do {
            i = nextIndex(i, len);
            Entry e = tab[i];
            if (e != null && e.get() == null) {
                n = len;
                removed = true;
                i = expungeStaleEntry(i);
            }
        } while ( (n >>>= 1) != 0);
        return removed;
    }

    //**删除過期資料并擴容
    private void rehash() {
        //**删除所有的過期資料
        expungeStaleEntries();

        //**資料量 >= 擴容閥值 - 擴容閥值 / 4,則擴容
        if (size >= threshold - threshold / 4)
            resize();
    }

    //**擴容
    private void resize() {
        Entry[] oldTab = table;
        int oldLen = oldTab.length;
        //**擴容為原來的2倍
        int newLen = oldLen * 2;
        Entry[] newTab = new Entry[newLen];
        int count = 0;
        
        //**把舊資料存放在新的Entry[]中
        for (int j = 0; j < oldLen; ++j) {
            Entry e = oldTab[j];
            if (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null; // Help the GC
                } else {
                    int h = k.threadLocalHashCode & (newLen - 1);
                    while (newTab[h] != null)
                        h = nextIndex(h, newLen);
                    newTab[h] = e;
                    count++;
                }
            }
        }
        
        //**計算新的擴容閥值
        setThreshold(newLen);
        size = count;
        table = newTab;
    }

    //**删除所有的過期資料
    private void expungeStaleEntries() {
        Entry[] tab = table;
        int len = tab.length;
        for (int j = 0; j < len; j++) {
            Entry e = tab[j];
            //**如果Entry不為空并且key為空(threadlocal對象為null)則為過期資料
            if (e != null && e.get() == null)
                expungeStaleEntry(j);
        }
    }
}