ThreadLocal
可以用于多线程情况下的变量保存,各个线程之间的变量值互不影响。关于它的使用场景可以看
刘欣 老师的文章一个故事讲明白线程的私家领地:ThreadLocal,这篇文章通过生动的案列引出了使用场景,令人印象深刻。
ThreadLocal threadLocalA= new ThreadLocal();
ThreadLocal threadLocalB = new ThreadLocal();
//保存
线程1: threadLocalA.set("Sbingo");
threadLocalB.set(99);
线程2: threadLocalA.set("sbingo");
threadLocalB.set(100);
//使用
线程1: threadLocalA.get() --> "Sbingo"
threadLocalB.get() --> "99"
线程2: threadLocalA.get() --> "sbingo"
threadLocalB.get() --> "100"
可以看到,ThreadLocal
的使用方法很简单,为什么这么操作后同一个变量在不同线程中取出的值不一样,并且1个线程可以保存N个ThreadLocal
呢?
原理就在Thread
类中成员变量threadLocals
,它的类型是ThreadLocal.ThreadLocalMap
,每个线程都有一个这样的Map
,所以可以保存N个ThreadLocal
键值对,并且不同线程的变量值不同。
因此,按上面这样操作后相应的数据结构如下:
线程1中map
key | value |
---|---|
threadLocalA | Sbingo |
threadLocalB | 99 |
线程2中map
key | value |
---|---|
threadLocalA | sbingo |
threadLocalB | 100 |
变量threadLocals
虽然定义在Thread
类中,但对它的访问却完全交给了类ThreadLocal
,下面我们通过源码看看ThreadLocal.ThreadLocalMap
具体是如何保存这些键值对的。
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap(ThreadLocal> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
static class Entry extends WeakReference> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal> k, Object v) {
super(k);
value = v;
}
}
private void set(ThreadLocal> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
private void rehash() {
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
从第11行可以看出,getMap()
方法就是返回了刚才提到的Thread
类中成员变量threadLocals
。
当首次调用set()
方法时,会进入第7行的createMap()
方法,第15行赋值给提到过的线程变量threadLocals
,自此它不再是null
。
第15行在构造方法中传入了this
,也就是将当前的ThreadLocal
本身作为了key
。
进入构造方法,table
是Map
中的哈希表,第19行对其赋值,初始容量为16。table
中保存的元素类型为Entry
,它是对ThreadLocal
的弱引用(第26行),这样便于GC
,防止内存泄漏。Entry
中只有一个变量value
(第28行),value
就是Map
中每个ThreadLocal
对应的值。
第20行将threadLocalHashCode
和INITIAL_CAPACITY - 1
进行按位与运算,相当于threadLocalHashCode
对INITIAL_CAPACITY
取模,但速度比用%
运算更快,这样就计算出了哈希表中索引的位置。
最后对哈希表相应索引的元素赋值,更新表内元素数量和阙值,完成了首次set
。
之后再调用set()
方法时,就会进入第5行代码。
还是和之前一样计算出索引,当发生冲突时,进入第41至55行的循环体内。
从第64行可以看出,解决冲突的nextIndex()
方法用的是线性探测法。
第46至49行,如果是同一个ThreadLocal
对象,则更新对应的值。
第51至54行,如果ThreadLocal
的引用为null
,则用当前的新值替换旧值,为简化篇幅,具体源码这里不展开了。
如果没发生冲突、发生冲突但遍历结束也没找到相同的ThreadLocal
对象或null
对象,则插入新值,此时执行第57行。
第59行表明在插入新值后,尝试删除已回收的值,如果没有删除发生并且哈希表内元素数量超过了阙值,则扩容哈希表并重新排列表内元素。
第76行表明阙值大小为哈希表长度的2/3 ,第71行表明扩容时哈希表内元素数量为阙值的3/4,相乘为哈希表大小的1/2。即当哈希表内元素数量超过哈希表大小一半时,就会发生扩容。扩容后的大小为原大小的两倍,这里也不展开分析了。
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
protected T initialValue() {
return null;
}
private Entry getEntry(ThreadLocal> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
当get()
发生在set()
之前,就会进入第12行代码。此后在第16行调用initialValue()
方法返回了null
,当然也可以复写该方法以修改默认值。第22行会初始化Map
,最后返回了初始值(null
或修改后的初始值)。
当然大多数情况下调用get()
方法会进入第5行获取Entry
对象,在第8、9行获取对应的value
值并返回。可见,关键在于getEntry()
方法。
第31行还是计算出索引,第33行判断索引处的Entry
对象是否为null
并比较两个ThreadLocal
对象是否为同一个,结果为真的话就可以直接返回Entry
对象,否则进入getEntryAfterMiss()
方法。
第43至52行的循环体中,不断遍历哈希表,当两个ThreadLocal
对象为同一个时,就可以返回哈希表当前索引上的Entry
对象。
ThreadLocal
虽然用起来很简单,但使用时要注意内存泄漏问题。
刚才分析时说过,Entry
对象是ThreadLocal
的弱引用。所以当线程消亡时,垃圾回收器会将它回收,此时不存在内存泄漏问题。
但如果使用线程池,并且线程存在复用的情形,线程不会消亡,这些Entry
对象就会发生泄漏,此时就需要我们进行手动清理。
清理的方法很简单,就是调用ThreadLocal
的remove()
方法,和上面的分析类似,它会不停探测索引,如果找到正确的ThreadLocal
对象就将其清除。
相关源码如下:
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
private void remove(ThreadLocal> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}