首頁  >  文章  >  Java  >  如何在Java中利用ConcurrentHashMap實現線程安全的映射?

如何在Java中利用ConcurrentHashMap實現線程安全的映射?

PHPz
PHPz轉載
2023-05-10 10:25:12877瀏覽

jdk1.7版本

資料結構

    /**
     * The segments, each of which is a specialized hash table.
     */
    final Segment<K,V>[] segments;

可以看到主要就是一個Segment數組,註解也寫了,每個都是一個特殊的hash table。

來看一下Segment是什麼東西。

static final class Segment<K,V> extends ReentrantLock implements Serializable {
    	......
            /**
         * The per-segment table. Elements are accessed via
         * entryAt/setEntryAt providing volatile semantics.
         */
        transient volatile HashEntry<K,V>[] table;
        transient int threshold;
        final float loadFactor;
    	// 构造函数
        Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
            this.loadFactor = lf;
            this.threshold = threshold;
            this.table = tab;
        }
  		......
    }

上面是部分程式碼,可以看到Segment繼承了ReentrantLock,所以其實每個Segment就是一個鎖。

裡面存放著HashEntry數組,該變數用volatile修飾。 HashEntry和hashmap的節點類似,也是一個鍊錶的節點。

來看看具體的程式碼,可以看到和hashmap裡面稍微不同的是,他的成員變數有用volatile修飾。

    static final class HashEntry<K,V> {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry<K,V> next;
        HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = next;
        }
        ......
    }

所以ConcurrentHashMap的資料結構差不多是下圖的樣子。

如何在Java中利用ConcurrentHashMap實現線程安全的映射?

在建構的時候,Segment 的數量由所謂的 concurrentcyLevel 決定,預設是 16,也可以在對應建構子直接指定。請注意,Java 需要它是 2 的冪數值,如果輸入是類似 15 這個非冪值,會自動調整到 16 之類 2 的冪數值。

來看看原始碼,先從簡單的get方法開始

get()

    public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        int h = hash(key);
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        // 通过unsafe获取Segment数组的元素
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            // 还是通过unsafe获取HashEntry数组的元素
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

get的邏輯很簡單,就是找到Segment對應下標的HashEntry數組,再找到HashEntry數組對應下標的鍊錶頭,再遍歷鍊錶取得資料。

這個取得陣列中的資料是使用UNSAFE.getObjectVolatile(segments, u),unsafe提供了像c語言的可以直接存取記憶體的能力。此方法可以取得物件的相應偏移量的資料。 u就是計算好的一個偏移量,所以等同於segments[i],只是效率更高。

put()

    public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }

而對於put 操作,是以Unsafe 呼叫方式,直接取得對應的Segment,然後進行線程安全的put 操作:

主要邏輯在Segment內部的put方法

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            // scanAndLockForPut会去查找是否有key相同Node
            // 无论如何,确保获取锁
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> first = entryAt(tab, index);
                for (HashEntry<K,V> e = first;;) {
                    if (e != null) {
                        K k;
                        // 更新已有value...
                    }
                    else {
                        // 放置HashEntry到特定位置,如果超过阈值,进行rehash
                        // ...
                    }
                }
            } finally {
                unlock();
            }
            return oldValue;
        }

size()

來看一下主要的程式碼,

for (;;) {
    // 如果重试次数等于默认的2,就锁住所有的segment,来计算值
    if (retries++ == RETRIES_BEFORE_LOCK) {
        for (int j = 0; j < segments.length; ++j)
            ensureSegment(j).lock(); // force creation
    }
    sum = 0L;
    size = 0;
    overflow = false;
    for (int j = 0; j < segments.length; ++j) {
        Segment<K,V> seg = segmentAt(segments, j);
        if (seg != null) {
            sum += seg.modCount;
            int c = seg.count;
            if (c < 0 || (size += c) < 0)
                overflow = true;
        }
    }
    // 如果sum不再变化,就表示得到了一个确切的值
    if (sum == last)
        break;
    last = sum;
}

這裡其實就是計算所有segment的數量和,如果數量和跟上次取得到的值相等,就表示map沒有進行操作,這個值是相對正確的。如果重試兩次之後還是沒法得到一個統一的值,就鎖住所有的segment,再來取得值。

擴容

private void rehash(HashEntry<K,V> node) {
            HashEntry<K,V>[] oldTable = table;
            int oldCapacity = oldTable.length;
    		// 新表的大小是原来的两倍
            int newCapacity = oldCapacity << 1;
            threshold = (int)(newCapacity * loadFactor);
            HashEntry<K,V>[] newTable =
                (HashEntry<K,V>[]) new HashEntry[newCapacity];
            int sizeMask = newCapacity - 1;
            for (int i = 0; i < oldCapacity ; i++) {
                HashEntry<K,V> e = oldTable[i];
                if (e != null) {
                    HashEntry<K,V> next = e.next;
                    int idx = e.hash & sizeMask;
                    if (next == null)   //  Single node on list
                        newTable[idx] = e;
                    else { // Reuse consecutive sequence at same slot
                        // 如果有多个节点
                        HashEntry<K,V> lastRun = e;
                        int lastIdx = idx;
                        // 这里操作就是找到末尾的一段索引值都相同的链表节点,这段的头结点是lastRun.
                        for (HashEntry<K,V> last = next;
                             last != null;
                             last = last.next) {
                            int k = last.hash & sizeMask;
                            if (k != lastIdx) {
                                lastIdx = k;
                                lastRun = last;
                            }
                        }
                        // 然后将lastRun结点赋值给数组位置,这样lastRun后面的节点也跟着过去了。
                        newTable[lastIdx] = lastRun;
                        // 之后就是复制开头到lastRun之间的节点
                        // Clone remaining nodes
                        for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                            V v = p.value;
                            int h = p.hash;
                            int k = h & sizeMask;
                            HashEntry<K,V> n = newTable[k];
                            newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                        }
                    }
                }
            }
            int nodeIndex = node.hash & sizeMask; // add the new node
            node.setNext(newTable[nodeIndex]);
            newTable[nodeIndex] = node;
            table = newTable;
        }

jdk1.8版本

資料結構

1.8的版本的ConcurrentHashmap整體上和Hashmap有點像,但是去除了segment,而是使用node的數組。

transient volatile Node<K,V>[] table;

1.8中還是有Segment這個內部類,但是存在的意義只是為了序列化相容,實際上已經不使用了。

來看node節點

    static class Node<K,V> implements Map.Entry<K,V> {
        final int hash;
        final K key;
        volatile V val;
        volatile Node<K,V> next;
        Node(int hash, K key, V val, Node<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.val = val;
            this.next = next;
        }
        ......
    }

和HashMap中的node節點類似,也是實作Map.Entry,不同的是val和next加上了volatile修飾來確保可見度。

put()

    final V putVal(K key, V value, boolean onlyIfAbsent) {
        if (key == null || value == null) throw new NullPointerException();
        int hash = spread(key.hashCode());
        int binCount = 0;
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            if (tab == null || (n = tab.length) == 0)
                // 初始化
                tab = initTable();
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                // 利用CAS去进行无锁线程安全操作,如果bin是空的
                if (casTabAt(tab, i, null,
                             new Node<K,V>(hash, key, value, null)))
                    break;                   // no lock when adding to empty bin
            }
            else if ((fh = f.hash) == MOVED)
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
                synchronized (f) {
                     // 细粒度的同步修改操作... 
                    if (tabAt(tab, i) == f) {
                        if (fh >= 0) {
                            binCount = 1;
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                                // 找到相同key就更新
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K,V> pred = e;
                                // 没有相同的就新增
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key,
                                                              value, null);
                                    break;
                                }
                            }
                        }
                        // 如果是树节点,进行树的操作
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                           value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                // Bin超过阈值,进行树化
                if (binCount != 0) {
                    if (binCount >= TREEIFY_THRESHOLD)
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        addCount(1L, binCount);
        return null;
    }

可以看到,在同步邏輯上,它使用的是 synchronized,而不是通常建議的 ReentrantLock 之類,這是為什麼呢?現在 JDK1.8 中,synchronized 已經被不斷優化,可以不再過分擔心性能差異,另外,相比於 ReentrantLock,它可以減少內存消耗,這是一個非常大的優勢。

同時,更多細節實作透過使用 Unsafe 進行了最佳化,例如 tabAt 是直接利用 getObjectAcquire,避免間接呼叫的開銷。

那麼,再來看看size是怎麼操作的呢?

    final long sumCount() {
        CounterCell[] as = counterCells; CounterCell a;
        long sum = baseCount;
        if (as != null) {
            for (int i = 0; i < as.length; ++i) {
                if ((a = as[i]) != null)
                    sum += a.value;
            }
        }
        return sum;
    }

這裡就是取得成員變數counterCells,遍歷取得總數。

其實,對於 CounterCell 的操作,是基於 java.util.concurrent.atomic.LongAdder 進行的,是一種 JVM 利用空間換取更高效率的方法,利用了Striped64內部的複雜邏輯。這個東西非常小眾,大多數情況下,建議還是使用 AtomicLong,足以滿足絕大部分應用的效能需求。

擴容

 private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
		......
        // 初始化
        if (nextTab == null) {            // initiating
            try {
                @SuppressWarnings("unchecked")
                Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
                nextTab = nt;
            } catch (Throwable ex) {      // try to cope with OOME
                sizeCtl = Integer.MAX_VALUE;
                return;
            }
            nextTable = nextTab;
            transferIndex = n;
        }
        int nextn = nextTab.length;
        ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
     	// 是否继续处理下一个
        boolean advance = true;
     	// 是否完成
        boolean finishing = false; // to ensure sweep before committing nextTab
        for (int i = 0, bound = 0;;) {
            Node<K,V> f; int fh;
            while (advance) {
                int nextIndex, nextBound;
                if (--i >= bound || finishing)
                    advance = false;
                else if ((nextIndex = transferIndex) <= 0) {
                    i = -1;
                    advance = false;
                }
                // 首次循环才会进来这里
                else if (U.compareAndSwapInt
                         (this, TRANSFERINDEX, nextIndex,
                          nextBound = (nextIndex > stride ?
                                       nextIndex - stride : 0))) {
                    bound = nextBound;
                    i = nextIndex - 1;
                    advance = false;
                }
            }
            if (i < 0 || i >= n || i + n >= nextn) {
                int sc;
                //扩容结束后做后续工作
                if (finishing) {
                    nextTable = null;
                    table = nextTab;
                    sizeCtl = (n << 1) - (n >>> 1);
                    return;
                }
                //每当一条线程扩容结束就会更新一次 sizeCtl 的值,进行减 1 操作
                if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                    if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                        return;
                    finishing = advance = true;
                    i = n; // recheck before commit
                }
            }
            // 如果是null,设置fwd
            else if ((f = tabAt(tab, i)) == null)
                advance = casTabAt(tab, i, null, fwd);
            // 说明该位置已经被处理过了,不需要再处理
            else if ((fh = f.hash) == MOVED)
                advance = true; // already processed
            else {
                // 真正的处理逻辑
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
                        Node<K,V> ln, hn;
                        if (fh >= 0) {
                            int runBit = fh & n;
                            Node<K,V> lastRun = f;
                            for (Node<K,V> p = f.next; p != null; p = p.next) {
                                int b = p.hash & n;
                                if (b != runBit) {
                                    runBit = b;
                                    lastRun = p;
                                }
                            }
                            if (runBit == 0) {
                                ln = lastRun;
                                hn = null;
                            }
                            else {
                                hn = lastRun;
                                ln = null;
                            }
                            for (Node<K,V> p = f; p != lastRun; p = p.next) {
                                int ph = p.hash; K pk = p.key; V pv = p.val;
                                if ((ph & n) == 0)
                                    ln = new Node<K,V>(ph, pk, pv, ln);
                                else
                                    hn = new Node<K,V>(ph, pk, pv, hn);
                            }
                            setTabAt(nextTab, i, ln);
                            setTabAt(nextTab, i + n, hn);
                            setTabAt(tab, i, fwd);
                            advance = true;
                        }
                        // 树节点操作
                        else if (f instanceof TreeBin) {
                            ......
                        }
                    }
                }
            }
        }
    }
     }
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                    // 树节点操作
                    else if (f instanceof TreeBin) {
                        ......
                    }
                }
            }
        }
    }
}

核心邏輯和HashMap一樣也是建立兩個鍊錶,只是多了取得lastRun的操作。

以上是如何在Java中利用ConcurrentHashMap實現線程安全的映射?的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:yisu.com。如有侵權,請聯絡admin@php.cn刪除