ThreadLocal源代码解析

2019/01/23 21:19 下午 posted in  java

面试时,很多人都会被问到关于ThreadLocal的内容,如果没有看过ThreadLocal相关的代码的话,关于这方面的面试题都会回答的很混乱。今天,我们就从ThreadLocal的源代码入手,来彻底解开ThreadLocal的原理。

ThreadLocal是个啥?

我们知道,看java的源码时,一般都是先去看一个类上面的注释,类上的注释会表明该类是用来做什么的。下面节选一部分ThreadLocal的类注释:

/**
 * This class provides thread-local variables.  These variables differ from
 * their normal counterparts in that each thread that accesses one (via its
 * {@code get} or {@code set} method) has its own, independently initialized
 * copy of the variable.  {@code ThreadLocal} instances are typically private
 * static fields in classes that wish to associate state with a thread (e.g.,
 * a user ID or Transaction ID).
 */

翻译过来的意思就是:

该类提供了线程局部 (thread-local) 变量。这些变量不同于它们的普通对应物,因为访问某个变量(通过其get 或 set方法)的每个线程都有自己的局部变量,它独立于变量的初始化副本。ThreadLocal实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联。

很多人以为ThreadLocal是用来解决线程同步的问题的,其实这是非常大的误解。ThreadLocal虽然提供了一种多线程下成员变量问题的解决方式,但是它并不是用来解决多线程共享变量问题的。线程同步机制是多个线程共享同一个成员变量,而ThreadLocal是为每一个线程创建独立的变量,每一个线程都可以独立的修改自己的变量而不需要担心会修改其他线程的变量。

ThreadLocal的方法

ThreadLocal常用的一共有4种方法

  • public T get() 返回此线程局部变量的当前线程副本中的值。
  • public void set(T value) 将此线程局部变量的当前线程副本中的值设置为指定值。
  • public void remove() 移除此线程局部变量当前线程的值。
  • private T setInitialValue() 返回此线程局部变量的当前线程的“初始值”。

同时,ThreadLocal下有一个内部类叫做ThreadLocal.ThreadLocalMap,这个类才是实现线程隔离机制的核心,上面的get set remove最终操作的数据结构都是该内部类。看ThreadLocalMap的名字也能大概猜出该类是基于键值对的方式存储的,key是当前的ThreadLocal实例,value是对应线程的变量副本。

所以从上面的说明来看,我们可以得出如下两个结论

  1. ThreadLocal本身不存储数据,它只是提供了在当前线程中找到数据的key
  2. 是ThreadLocal包含在Thread中,而不是反过来。

下图是两者的关系

ThreadLocal例子

public class SeqCount {

    private static ThreadLocal<Integer> seqCount = new ThreadLocal<Integer>(){
        // 实现initialValue()
        public Integer initialValue() {
            return 0;
        }
    };

    public int nextSeq(){
        seqCount.set(seqCount.get() + 1);

        return seqCount.get();
    }

    public static void main(String[] args){
        SeqCount seqCount = new SeqCount();

        SeqThread thread1 = new SeqThread(seqCount);
        SeqThread thread2 = new SeqThread(seqCount);
        SeqThread thread3 = new SeqThread(seqCount);
        SeqThread thread4 = new SeqThread(seqCount);

        thread1.start();
        thread2.start();
        thread3.start();
        thread4.start();
    }

    private static class SeqThread extends Thread{
        private SeqCount seqCount;

        SeqThread(SeqCount seqCount){
            this.seqCount = seqCount;
        }

        public void run() {
            for(int i = 0 ; i < 3 ; i++){
                System.out.println(Thread.currentThread().getName() + " seqCount :" + seqCount.nextSeq());
            }
        }
    }
}

运行结果为

可以看到,三个线程是分别累加的自己的独立的数据,相互之间没有任何的干扰。

ThreadLocal源码解析

get方法

我们先从get方法入手

/**
 * Returns the value in the current thread's copy of this
 * thread-local variable.  If the variable has no value for the
 * current thread, it is first initialized to the value returned
 * by an invocation of the {@link #initialValue} method.
 *
 * @return the current thread's value of this thread-local
 */
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

get方法流程是首先获取当前的线程t,然后通过getMap(t)获取到ThreadLocalMap实例map,如果map不为null,那么就通过map.getEntry获取ThreadLocalMap.Entry实例e,如果e不为null,那么就返回e.value,否则调用setInitialValue获取默认值。
getMap方法源码:

/**
 * Get the map associated with a ThreadLocal. Overridden in
 * InheritableThreadLocal.
 *
 * @param  t the current thread
 * @return the map
 */
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

可以看到方法只有一行,非常的简洁,就是返回了Thread实例tthreadLocals变量。那么这个变量又是什么呢?继续跟踪到Thread类的源码中,

/* ThreadLocal values pertaining to this thread. This map is maintained
 * by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;

就是一个ThreadLocal.ThreadLocalMap的实例,默认为null

ThreadLocalMap

ThreadLocalMap内部利用Entry来实现键值对的存储,Entry的源码如下:

/**
 * The entries in this hash map extend WeakReference, using
 * its main ref field as the key (which is always a
 * ThreadLocal object).  Note that null keys (i.e. entry.get()
 * == null) mean that the key is no longer referenced, so the
 * entry can be expunged from table.  Such entries are referred to
 * as "stale entries" in the code that follows.
 */
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

可以看到Entry是一个WeakReference弱引用,其key是ThreadLocal实例,关于弱引用这里不再赘述,可以参考其他文章。
ThreadLocalMap的方法比较多,我们着重介绍两个方法getEntryset方法

getEntry

/**
 * Get the entry associated with key.  This method
 * itself handles only the fast path: a direct hit of existing
 * key. It otherwise relays to getEntryAfterMiss.  This is
 * designed to maximize performance for direct hits, in part
 * by making this method readily inlinable.
 *
 * @param  key the thread local object
 * @return the entry associated with key, or null if no such
 */
private Entry getEntry(ThreadLocal<?> key) {
    //获取key的hash值,用于在table中查找对应的Entry
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    //当当前位置一次找到了对应的Entry,直接返回
    if (e != null && e.get() == key)
        return e;
    else
    //当当前位置为null或者当前位置存储的并不是要找的Entry时,进入此方法查找
        return getEntryAfterMiss(key, i, e);
}

getEntry没有太大的难度,与HashMap.get的初始思路比较一致,都是先计算hash,然后去对应的位置查找。但是ThreadLocalMapHashMap不一致的地方在于,HashMap针对hash碰撞所采用的方式是链表法(即,将所有hash冲突的元素保存在一个链表中),而ThreadLocalMap所采用的方式是开放定址法(即,当发现冲突时,遍历table到接下来的一个空位,将其存储在这里。)。读者可以思考一下为什么同是散列表的实现,为什么这两者要使用不同的hash冲突解决方式。

由于ThreadLocalMap使用的开放定址法,因此当查找不到时会调用getEntryAfterMiss方法,源码如下:

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

整体也没有难度,需要注意的一点是当遍历到k == null时,会调用expungeStaleEntry方法会rehash当前节点到下一个null节点之间的键值对,辅助gc。

set方法

/**
 * Sets the current thread's copy of this thread-local variable
 * to the specified value.  Most subclasses will have no need to
 * override this method, relying solely on the {@link #initialValue}
 * method to set the values of thread-locals.
 *
 * @param value the value to be stored in the current thread's copy of
 *        this thread-local.
 */
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

get方法类似,也是先获取当前线程的ThreadLocalMap,当ThreadLocalMap不为null时,调用ThreadLocalMap.set方法保存;当ThreadLocalMapnull时,调用createMap方法初始化ThreadLocalMap实例,并写入一个键值对。ThreadLoalMap.set方法不多赘述,只要了解了开放定址法就很简单了。

remove方法

/**
 * Removes the current thread's value for this thread-local
 * variable.  If this thread-local variable is subsequently
 * {@linkplain #get read} by the current thread, its value will be
 * reinitialized by invoking its {@link #initialValue} method,
 * unless its value is {@linkplain #set set} by the current thread
 * in the interim.  This may result in multiple invocations of the
 * {@code initialValue} method in the current thread.
 *
 * @since 1.5
 */
 public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread());
     if (m != null)
         m.remove(this);
 }

跟上面的setget方法类似,只不过最后不需要考虑ThreadLocalMap为空的情况。

initialValue方法

    protected T initialValue() {
        return null;
    }

默认是返回null,我们可以根据业务不同设置不同的返回值即可。

ThreadLocal与内存泄漏

前面提到每个Thread都有一个ThreadLocal.ThreadLocalMap的map,该map的key为ThreadLocal实例,它为一个弱引用,我们知道弱引用有利于GC回收。当ThreadLocal的key == null时,GC就会回收这部分空间,但是value却不一定能够被回收,因为他还与Current Thread存在一个强引用关系。

由于存在这个强引用关系,会导致value无法回收。如果这个线程对象不会销毁那么这个强引用关系则会一直存在,就会出现内存泄漏情况。所以说只要这个线程对象能够及时被GC回收,就不会出现内存泄漏。如果碰到线程池,那就更坑了。

那么要怎么避免这个问题呢?

在前面提过,在ThreadLocalMap中的setEntry()、getEntry(),如果遇到key == null的情况,会对value设置为null。当然我们也可以显示调用ThreadLocal的remove()方法进行处理。

参考:http://www.iocoder.cn/JUC/sike/ThreadLocal/?vip