ThreadLocal原理分析

前言

熟悉Android开发的同学知道Handler异步通信机制,其中如何做到一个线程有且只有一个Looper,从而也就只有一个MessageQueue,这是如何做到的呢?这就要涉及ThreadLocal了。

static final ThreadLocal sThreadLocal = new ThreadLocal();

private static void prepare(boolean quitAllowed) {
        if (sThreadLocal.get() != null) {
            throw new RuntimeException("Only one Looper may be created per thread");
        }
        sThreadLocal.set(new Looper(quitAllowed));
}

public static @Nullable Looper myLooper() {
        return sThreadLocal.get();
}

以上就是Android中经典的ThreadLocal使用场景了。

原理分析

ThreadLocal#set()

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

哦呦,在Thread类中具有一个ThreadLocal的成员变量。

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

看下Thread中的ThreadLocalMap

ThreadLocal.ThreadLocalMap threadLocals = null;
public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

不难理解了,先去当前Thread中获取ThreadLocalMap,若ThreadLocalMap为空,则创建一个ThreadLocalMap并赋值给成员变量。

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

以上我们看到每个线程都维护了一个ThreadLocalMap这个对象,所以每次调用ThreadLocal#set()方法,就是向当前线程中的map中添加对应的数据。
ok,接下来我们重点看下这个ThreadLocalMap这个类。

ThreadLocal#ThreadLocalMap

构造方法

// 维护了一个Entry数组
private Entry[] table;
ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
    // 初始容量为16 也就意味着一个线程初始可以存放16个ThreadLocal对象
    table = new Entry[INITIAL_CAPACITY];
    // 计算index
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);
    // 标记目前指针指在什么位置
    size = 1;
    // 计算扩容因子 threshold = len * 2 / 3;
    setThreshold(INITIAL_CAPACITY);
}

ThreadLocal#Entry

static class Entry extends WeakReference> {
    /** The value associated with this ThreadLocal. */
    Object value;
    Entry(ThreadLocal k, Object v) {
        super(k);
        value = v;
    }
}

这个构造方法第一个参数是ThreadLocal变量,被标记成弱引用,我们知道变量被标记为弱引用时,当GC从GcRoot开始可达性分析时,只要检查经过该变量(变量没有强引用)都会被回收掉。因此这里的ThreadLocal变量被标记为弱引用,是因为在当前线程的ThreadLocalMap对ThreadLocal变量进行引用,但是其他线程也对该ThreadLocal进行了引用,这样在当当前线程执行完毕,该线程中的ThreadLocal就不会被释放回收的导致内存泄漏,因此这里加了弱引用,但是与之对应的value则没有加弱引用所以若使用完毕以后没有调用remove()方法,则导致value回收不掉导致内存泄漏。

真正执行变量保存的set()方法

private void set(ThreadLocal key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 计算数组的下标
    int i = key.threadLocalHashCode & (len-1);
    // 遍历一遍Entry数组 开放地址法 一旦entry为空就跳出
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal k = e.get();

        if (k == key) {
            // 同一个ThreadLocal变量 则直接赋值value
            e.value = value;
            return;
        }
        // key为空 说明该ThreadLocal变量已经被回收了,把新值放在这里
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 没找到的话就新创建一下
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 先去清理下若清理不了 并且 超过了扩容阈值就进行扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        // 扩容
        rehash();
}

遇到了哈希冲突,ThreadLocalMap是采用开放地址法解决冲突的。开放地址法原理就是根据除留余数法进行对数据的散列,若遇到哈希冲突时,将index指针后移直到找到一块空间为空的地方进行存储。但是如何尽量让hash值不重复而又尽可能的散列呢?ThreadLocalMap中有个魔法数字0x61c88647,这个数字是Integer有符号的黄金分割的0.618倍,可以根据key生成的Index更加的分布均匀。

// 容量是2的整数倍
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
int i = firstKey.threadLocalHashCode & (len - 1);

之所以,INITIAL_CAPACITY - 1 是因为 16 的2进制为

000000000000000000000000000010000

减1之后为

000000000000000000000000000011111

高位被消除,剩下低位与被除数进行与运算,就相当于普通的取模运算。HashMap中也有这样的操作。

replaceStaleEntry()

        private void replaceStaleEntry(ThreadLocal key, Object value, int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;
            int slotToExpunge = staleSlot;
            // 1
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i; 
            // 2
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();
                if (k == key) {
                    e.value = value;
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

如果新来了一个数据,key根据计算index,发现这个index对应的Entity的key 为null,则会进入上面的方法。代码1处,向前遍历找空的Entity到staleSlot之间key过期最小的index。方便之后进行再次清理,代码2处,向数组的后边找一直遍历到Entity 为null,若找到了key,要和当前的staleSlot位置进行交换。若不进行交换直接赋值,则会导致有两个相同的key。若往前找的index也就是slotToExpunge不等于当前的staleSlot,说明需要进行清理最终还要调用一下cleanSomeSlots()方法。

 private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

后边这两个方法主要是由于将Entity不为空但是key为空的Entity清除掉的时候,导致两个hash冲突的ThreadLocal中间存在空闲空间导致最终在找后面冲突值得时候发现Entity为空直接跳出循环了。

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

总结

如果我们没有调用get,set,remove方法,虽然说ThreadLocalMap中的Entity中的ThreadLocal变量是弱引用的,当线程退出的时候,ThreadLocal变量没有强引用引用,不会导致内存泄漏。但是有一种情况,比如是线程池这种,在线程执行完成以后,由于池化作用,只是将当前线程归还线程池进行重新调度并没有退出,所以在这种情况下就产生了ThreadLocal引用,导致了内存泄漏,因此我们在使用ThreadLocal的时候尽量养成好的习惯不再使用的时候,去调用一下remove方法。

补充

当主线程想将数据副本也传递到子线程中,我们将要怎么处理呢?我们知道,ThreadLocal是通过ThreadLocalMap保存ThreadLocal变量,而ThreadLocalMap是在Thread类中进行赋值的,并且是该线程的单一副本,是无法传递到子线程中的。因此,可以使用 InheritableThreadLocal 来代替 ThreadLocal,ThreadLocal 和 InheritableThreadLocal 都是线程的属性。

private static final ThreadLocal threadLocal = new ThreadLocal<>();
private static final InheritableThreadLocal threadLocal1 = new InheritableThreadLocal<>()
public static void main(String[] args) throws InterruptedException {
    Thread parentThread = new Thread(new Runnable() {
        @Override
        public void run() {
            threadLocal1.set(1);
            Thread childThread = new Thread(new Runnable() {
                @Override
                public void run() {
                    System.out.println(threadLocal1.get());
                    System.out.println();
                }
            });
            childThread.start();
            System.out.println(threadLocal1.get());
        }
    });
    Thread parentThread1 = new Thread(() -> {
        threadLocal1.set(12);
        Thread childThread = new Thread(new Runnable() {
            @Override
            public void run() {
                System.out.println(threadLocal1.get());
                System.out.println();
            }
        });
        childThread.start();
        System.out.println(threadLocal1.get());
    });
    parentThread.start();
    parentThread.join();
    parentThread1.start();
}

输出

1
null
12
12

InheritableThreadLocal

public class InheritableThreadLocal extends ThreadLocal {
    protected T childValue(T parentValue) {
        return parentValue;
    }
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    // 重写了createMap逻辑
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

Thread#init()

 if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

当Thread的父线程的inheritableThreadLocals的成员变量不为空,就给当前的inheritableThreadLocals变量也赋值。

ThreadLocal#createInheritedMap()

private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];
            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal key = (ThreadLocal) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }
 
 

很简单,将父线程的ThreadLocalMap的链表完全插入到子线程的ThreadLocalMap容器中,从而让子线程拿到父线程的ThreadLocal。

你可能感兴趣的:(ThreadLocal原理分析)