天天看点

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

    • 相关基本概念
      • ConcurrentHashMap的数据结构存放原理
      • 理解UNSAFE操作
      • 源码中用UNSAFE操作取值
      • put方法的流程
    • 源码分析
      • ConcurrentHashMap的构造方法
      • ConcurrentHashMap的put方法
      • ConcurrentHashMap的ensureSegment方法
      • Segment中的scanAndLockForPut方法
      • Segment中的put方法
      • Segment中的rehash扩容方法

JDK1.7中的HashMap是采用数组加链表的结构进行存储,并且是线程不安全的,而HashTable为了线程安全就在方法上加了synchronized锁,把整个HashTable锁了起来,多个线程竞争同一把锁,效率不高。

综合这两点两个问题,ConcurrentHashMap就来了。

相关基本概念

ConcurrentHashMap的数据结构存放原理

JDK1.7中的ConcurrentHashMap的实质也是采用数组加链表的方式实现。只是它进行了一个分段。

在最外层用一个Segment类的数组,每个Segment数组的位置内部又放了HashEntry链表数组,而数组就存放在内部的HashEntry中。

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

这里其实就可以理解成ConcurrentHashMap把HashMap进行了分段。这样的话每个线程只需要跟和它在一个分段上的其他线程竞争锁,而其他分段则互不影响,就提高了效率,也保证了线程的安全。

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

理解UNSAFE操作

CPU的执行速度要比内存读取数据速度高,所以将需要运算的数据复制一份到CPU的高速缓存中,也就是给当前运行线程的运行内存的高速缓冲中放入副本。每个线程第一次会从主存中将变量拿到高速缓冲中,以后每次线程运行所需变量都从该线程自己的高速缓存中取,而不是从主存中取,运算结束后再将高速缓冲中的数据刷新到主存中。这样就会导致每个线程取到的有可能不是主存中最新的值,那么计算的结果也就是错误的。

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

ConcurrentHashMap源码中,为了解决上面的这种数据的不一致性,用了

UNSAFE操作

方法操作。举个例子【UNSAFE不能直接使用,需要用反射获取,这里只是写个伪代码方便理解】:

//有一个数组
private String table[] = {"1", "2", "3", "4"};
// 数组的类型,用于下面arrayIndexScale和arrayBaseOffset的参数
Class arrayClass = String.class;
//获取数组存储对象的对象头大小
int ns = UNSAFE.arrayIndexScale(arrayClass);
//数组中第一个元素的起始位置
int base = UNSAFE.arrayBaseOffset(arrayClass);
//获取table[1]在内存中的最新值
String value = UNSAFE.getObject(table,base+1*ns);
//获取table[2]在内存中的最新值
String value = UNSAFE.getObject(table,base+2*ns);
//获取table第i个位置在内存中的最新值
String value = UNSAFE.getObject(table,base+i*ns);
           

这个UNSAFE方法其实就和java中的volatile修饰符有着同样的作用。这上面具体的参数base,ns这些的原理就不作深入了,大概知道这个UNSAFE有什么用,怎么用的就行了。

源码中用UNSAFE操作取值

那么这个UNSAFE操作在源码中的用法和上面说的差不多,稍微有些变化。这里举例取外层数组Segment[j]位置的值,源码中是这么写的

s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)

前面的(Segment<K,V>)是类型强转,方法中的第一个参数就是要从中取数的数组对象。主要看后面的UNSAFE.getObject中的最后一个参数,这些参数都会在源码中的一个静态代码块中赋好值。【sc,ss,SSHIFT就是取外层Segment的相关参数。tc,ts,TSHIFT就是取取内层HashEntry数组的相关参数】

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

所以UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)中的SBASE就是调用arrayBaseOffset方法获取的数组中第一个元素的起始位置base,SSHIFT调用的方法中的参数ss就是调用arrayIndexScale获取的获取数组存储对象的对象头大小ns,这里翻译一下就是

UNSAFE.getObject(segments, base +(j<<(31-Integer.numberOfLeadingZeros(ns))))

唯一不同的就是原本我们是用i*ns取第i个位置上的元素

源码中是用 j<<(31-Integer.numberOfLeadingZeros(ns))

这里先解释一下31-Integer.numberOfLeadingZeros(ns)的作用

Integer.numberOfLeadingZeros会返回某个数字最高位1前面0的个数。以16举例子:

16的二进制是 0000 0000 0000 0000 0000 0000 0001 0000 它前面共有27个0,以就会返回27

因为二进制的第一位表示2的0次方,而实际计算左移的时候,左移i位就是乘2的i次方,而不会是乘2的(i-1)次方,所以用整形的32位去掉一位后再减去最高位的1后面的0,剩下的就是小于原来数字最大的二次幂数。

31-Integer.numberOfLeadingZeros(ns) <=> log2(ns)向下取

最后再进行左移 j<<(31-Integer.numberOfLeadingZeros(ns)) <=> j*2[log2(ns)向下取整] <=> j*(小于ns的最大二次幂数)

这个方法分析出来和我们原来的i*ns好像有点差距,这个原因不太了解为啥,最后算出来的ns要往下取小于ns的最大二次幂数

这里就暂且能大概理解一下这个UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)在源码中的写法就是取segments数组中第j位就行了,并且能保证一定是从主拿最新的值,而不是重缓存拿。可能具体的原理还要深入反射,JVM,CAS相关的东西。

put方法的流程

既然ConcurrentHashMap分为了内外两层,所以肯定先要在外层的Segment数组进行hash算出个值,然后进到Segment这个值下标得位置里面,进入Segment内部后,先对该分段加锁,然后再算一次hash,找到内部得Entry数组的下标位置,然后在进行元素的插入。所以大致流程可以分为下面几步:

  1. 外层计算hash值,找到外层Segment数组的对应下标位置
  2. 进入Segment数组中,对该分段加锁
  3. 再次计算hash值,确定内部的Entry数组的下标
  4. 放入值

大概的流程就这几步,但肯定不会这么笼统简单的完成了,下面看源码。

源码分析

ConcurrentHashMap的构造方法

构造方法需要三个参数,initialCapacity可存放元素容量,loadFactor加载因子,concurrencyLevel表示外层数组segment的大小,并规定不能超过定义好的外层数组容量最大值MAX_SEGMENTS。

public ConcurrentHashMap(int initialCapacity,float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // create segments and segments[0]
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }
           

中间过程先不看,看最后它先把segment的第一个位置上的对象new了出来,这里的作用就是,在后面每次扩容或相关操作的时候,就可以直接从s0里面直接取一些需要的数据,比如内层数据大小等,就不用每次再计算一遍。

这里可以看到,cap作为了内层数组HashEntry的大小,而ssize作为最终外层数组的容量大小,而没有取concurrencyLevel。那么也就是说中间的一段代码都在计算内外层数组的额容量大小。

那就来分析一下中间这段代码到底做了什么事情

ConcurrentHashMap和HashMap一样,都要求数组的长度为2的幂,不管内层还是外层数组容量。所以先要对我们传入的concurrencyLevel进行处理,while中用ssize与concurrencyLevel比较,小于的话,每次左移动一位,也就是每次都乘以2,这样就能找到大于等于concurrencyLevel的最小的2次幂值sszie,所以最终外层segment的容量取的是ssize,而不是concurrencyLevel。

而sshift同样在while中,它每次+1,记录的是ssize 到底左移了几次。

中间两个segmentShift和segmentMask的赋值先不看,这里先明确一下segmentShift是32-ssize左移次数,也就是32-log2(外层数组容量),segmentMask是ssize-1,也就是外层数组容量-1

接着判断存放元素的容量是否大于的规定的最大容量

int c = initialCapacity / ssize 就是计算了每个segment的内部平均要放几个元素,由于/除法会舍掉小数,所以还用个if判断是否需要+1才能保证每个分段中的存放的元素总和大于等于传入的元素大小。但这里还不是内部HashEntry的数组容量大小!

因为前面说了,内部的数组容量同样要为2的n次幂,所以下面又对cap进行了与c比较的左移循环,目的同样是为了找到大于等于c的最小2次幂作为内部数组的容量。

ConcurrentHashMap的put方法

这里的ConcurrentHashMap是不允许存放null值的,所以这里判断了value==null的话,会抛出个NullPointerException()异常。

接着看到下面的if中先用UNSAFE取了segments中第j个位置的Segment对象值赋给s,并判断s是否为null。为null的话,调用ensureSegment方法在这个位置上创建一个Segment对象。不为null的话,再调用segments中的put方法进行键值对的放入。

既然j是数组下标的话,那么hash方法和

(hash >>> segmentShift) & segmentMask;

这一行就是用来计算定位外层数组下标位置的方法。

public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }
           

hash方法不做过多的研究,之前分析HashMap的源码说过,就是为了让结果更散列,得到一个hash值。

主要分析

(hash >>> segmentShift) & segmentMask;

这一行代码

先把前面的hash >>> segmentShift看成某个值x,而segmentMask是2的次幂-1,转化为二进制后就是低位全是连续的1,比如8-1=7的二进制就是0000 0000 0000 0000 0000 0000 0000 0111

那么它和任意值进行与&操作后,所得结果的范围都会在0-7之间,这样也就保证了计算出的j这个外层数组下标不会越界。

hash是根据key算出来的值,segmentShift是32-log2(外层数组容量),segmentMask是外层数组容量-1

再说的直白一点segmentShift就是数组容量转为二进制后最高位左面0的个数 + 1。比如8=23,转为二进制就是0000 0000 0000 0000 0000 0000 0000 1000 ,那么segmentShift = 28 + 1 = 29

hash右移了segmentShift位之后,就能让hash的高位转都到低位,且转到低位的数量与外层数组容量有关。

这里为什么要进行hash >>> segmentShift这样的操作,也没太懂,就暂且理解为都是为了让所计算出的j数组下标更散列,减少哈希冲突。

后面的ensureSegment就是把创建出来的Segment放入指定下标,最后再调用Segment中的put方法把元素放入。

ConcurrentHashMap的ensureSegment方法

这个方法用来在segement的指定下标位置创建一个Segment对象

这里的u就是前面说到的UNSAFE操作中的数组下标的意思。每次进来都先用UNSAFE操作取最新值去判断这个u位置上的Segment对象是否为null。

因为钱买你构造方法中说过了在构造ConcurrentHashMap的时候就会把Segment数组中的第一个位置的对象创建出来,方便后面再创建Segment对象的时候,可以减少计算,直接去第一个位置上的对象中取数据。

这里的proto就是发挥了这个作用

从proto中取出内层HashEntry数组的容量cap,取出加载因子ld,两者计算得到扩容阈值thresolad。

接着就开始创建该Segment对象内部的HashEntry数组对象,new出一个tab后,再次确认一下,在上诉过程中,是否有其他线程对该位置上的Segment对象先一步new好,所以再去UNSAFE取一下,如果还为null的话,才是真正的执行

new Segment<K,V>(lf, threshold, tab);

。到这一步,也仅仅是吧这个对象new出来了,并没有放入Segment数组中指定的位置。

所以new完后的下一步还是通过UNSAFE操作再去获取一次最新的对象,如果还为null,这里就执行cas操作

UNSAFE.compareAndSwapObject(ss, u, null, seg = s)

这个表示的意思是,判断ss中第u个位置是否为null,如果是的话,就把这个位置上seg的值更新为s。

这样就完成了Segment数组上指定下标位置的Segment对象的创建,并且保证了原子性。

private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }
           

Segment中的scanAndLockForPut方法

在看Segment中的put方法之前,会看到它里面的第一步就用到了scanAndLockForPut()方法,先看这个方法有什么用。

这里的entryForHash方法就直接说了,是取内层HashEntry数组中这个hash值算出来对应下标位置的元素

这里面最主要的就是在一个while循环里面一直去获取锁tryLock(),只有当获取到了才返回。

再尝试获取锁的过程中,他还干了一些事情,每尝试获取一次锁,它就会往下去遍历该位置上的链表,如果enull并且nodenull的话,这里就表示遍历到尾部了也没有key相等的元素或者这个位置上根本就没有元素,那么这个时候我们就把当前要放入的key-value,new成一个对象node,然后修改retries为0,让下一次的while尝试获取锁时,进入下面的else if中。

在下面的else if中,他会先判断当前重试获取锁的次数是否达到上限,达到了的话,就会调用lock()阻塞的方法一直等待获取锁,不做别的事情。没达到的话它会再一次的调用entryForHash去获取该链表最新的头节点,因为获取不到锁,并且在每次尝试获取和遍历链表的过程中,它以头节点的方式新插入了一个节点,那么有可能新插入的节点和我当前要插入的节点的key相同,所以又会修改retries回到-1,重新走上面的if中的逻辑。

最终返回的node不为null的话就是以已经把key和value打包好的Entry对象,如果为null的话,说明遍历的途中获取到锁了,提前返回出来了。

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            HashEntry<K,V> node = null;
            int retries = -1; // negative while locating node
            while (!tryLock()) {
                HashEntry<K,V> f; // to recheck first below
                if (retries < 0) {
                    if (e == null) {
                        if (node == null) // speculatively create node
                            node = new HashEntry<K,V>(hash, key, value, null);
                        retries = 0;
                    }
                    else if (key.equals(e.key))
                        retries = 0;
                    else
                        e = e.next;
                }
                else if (++retries > MAX_SCAN_RETRIES) {
                    lock();
                    break;
                }
                else if ((retries & 1) == 0 &&
                         (f = entryForHash(this, hash)) != first) {
                    e = first = f; // re-traverse if entry changed
                    retries = -1;
                }
            }
            return node;
        }
           

Segment中的put方法

ConcurrentHashMap是线程安全的。该方法进入首先就会尝试获取锁tryLock(),在锁里会返回一个node对象,上面的scanAndLockForPut方法已经分析过,这里获取的node是null或者是已经把key和value打包好的Entry对象。如果如果获取不到的话调用scanAndLockForPut()方法。

进来先计算一下要放在内部数组的下标index,(tab.length - 1) & hash,这种某个值与数组长度-1的与操作都是为了控制计算范围不越界,并且尽可能的让计算到每个位置的概率相近。

接着调用entryAt()方法去获取内部数组index位置的内容,同样的是用UNSAFE操作去获取内存中的最新值。这里获取到的first 就是内层数组指定下标index位置链表上的第一个节点。

接着从这个节点开始遍历,每次遍历到的节点赋给e,有下面几种情况的时候break:

1.e不为null,这个位置上已经有元素了,那么就判断当前要放入的key是否和当前该链表上遍历到的元素的key相同,如果相同,在根据传入的onlyIfAbsent判断,是否需要更新value的值。

2.e为null,这里表示遍历完了链表没有相同的key,那么如果前面的node返回不为null的话,那么就直接修改node的next属性为first,否者就new一个Entry对象,同样的是把next指向first。

接着判断count+1,即元素的个数是否超过一定值【这里的thresold扩容阈值在HashMap源码中已经解释了】,如果超过了就rehash方法进行扩容。注意这里的扩容只会对内层HashEntry数组进行扩容,对外层的Segment数组大小不会改变,最后用setEntryAt方法,把node放在该数组下标位置上,即完成头插法的最后一步。最终unlock释放锁就完成了整个ConcurrentHashMap的put操作。

static final <K,V> HashEntry<K,V> entryAt(HashEntry<K,V>[] tab, int i) {
        return (tab == null) ? null :
            (HashEntry<K,V>) UNSAFE.getObjectVolatile
            (tab, ((long)i << TSHIFT) + TBASE);
}
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> first = entryAt(tab, index);
                for (HashEntry<K,V> e = first;;) {
                    if (e != null) {
                        K k;
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    else {
                        if (node != null)
                            node.setNext(first);
                        else
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            rehash(node);
                        else
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                unlock();
            }
            return oldValue;
        }
           

Segment中的rehash扩容方法

扩容只对内层数组进行扩容,不是对外层数组!

private void rehash(HashEntry<K,V> node) {
            HashEntry<K,V>[] oldTable = table;
            int oldCapacity = oldTable.length;
            int newCapacity = oldCapacity << 1;
            threshold = (int)(newCapacity * loadFactor);
            HashEntry<K,V>[] newTable =
                (HashEntry<K,V>[]) new HashEntry[newCapacity];
            int sizeMask = newCapacity - 1;
            for (int i = 0; i < oldCapacity ; i++) {
                HashEntry<K,V> e = oldTable[i];
                if (e != null) {
                    HashEntry<K,V> next = e.next;
                    int idx = e.hash & sizeMask;
                    if (next == null)   //  Single node on list
                        newTable[idx] = e;
                    else { // Reuse consecutive sequence at same slot
                        HashEntry<K,V> lastRun = e;
                        int lastIdx = idx;
                        for (HashEntry<K,V> last = next;
                             last != null;
                             last = last.next) {
                            int k = last.hash & sizeMask;
                            if (k != lastIdx) {
                                lastIdx = k;
                                lastRun = last;
                            }
                        }
                        newTable[lastIdx] = lastRun;
                        // Clone remaining nodes
                        for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                            V v = p.value;
                            int h = p.hash;
                            int k = h & sizeMask;
                            HashEntry<K,V> n = newTable[k];
                            newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                        }
                    }
                }
            }
            int nodeIndex = node.hash & sizeMask; // add the new node
            node.setNext(newTable[nodeIndex]);
            newTable[nodeIndex] = node;
            table = newTable;
        }
           

这里的扩容代码不细分析,大体看一下,里面有两个for循环,并且扩容和HashMap一样,原本第i个位置上的元素可能放到新数组的第i个位置,也可能放到第i+扩容大小上的位置,就是下面这种情况。

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

那么就有一种情况,该位置上从某个节点开始到链表结尾的所有节点,都放到的是同一个位置,那么就只需要把这一段内的第一个节点放过去即可。类似下面这种情况,那么我在移动的时候就不需要每个元素都一个个的移动,我只需要遍历到2元素的时候,把2直接放到0的下面即可。

JDK1.7的ConcurrentHashMap的put源码到底是怎么回事

所以上面的源码中两个循环,第一个循环就是在找某一段连续到最后的链表区间最后放到的是新数组的同一个位置的开始节点。

第二个循环的时候,就会从0开始遍历,到2结束,把其中的遍历的元素组个放入新数组的对应位置即可。这就通过两个循环达到了扩容的目的。