天天看点

Java JUC包源码分析 - 多阶段任务Phaser

Phaser是JDK 1.7引入的,用于分阶段执行任务:多个线程先执行第一阶段的任务,等所有线程都完成第一阶段后再一起进入第二阶段,依此类推。它相当于可以重复使用多次的栅栏(类似CyclicBarrier),但更加灵活,例如参与者的数量可以动态注册(register)和注销(arriveAndDeregister)。

先看下用法(这个例子写得还不够贴切,等理解更深入后再补充):

package com.pzx.test.test002;

import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.concurrent.Phaser;

public class PhaserDemo {
    /**
     * Timestamp prefix for every log line. {@link SimpleDateFormat} is not thread-safe and the
     * worker threads log concurrently, so it must only be used through {@link #timestamp()}.
     */
    private static final SimpleDateFormat TIMESTAMP_FORMAT =
            new SimpleDateFormat("yyyy-MM-dd HH:mm:ss ");

    /**
     * Formats the current time with {@link #TIMESTAMP_FORMAT}. Access is serialized on the
     * formatter because SimpleDateFormat keeps mutable internal state while formatting.
     *
     * @return the current time as {@code "yyyy-MM-dd HH:mm:ss "} (note the trailing space)
     */
    static String timestamp() {
        synchronized (TIMESTAMP_FORMAT) {
            return TIMESTAMP_FORMAT.format(new Date());
        }
    }

    /**
     * Starts three worker threads that step through five phases together, synchronized by a
     * single {@link MyPhaser}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int workers = 3;

        // Registers all parties up front; equivalently: new MyPhaser(0) then bulkRegister(workers).
        MyPhaser myPhaser = new MyPhaser(workers);

        // Create every thread first, then start them all.
        Thread[] threads = new Thread[workers];
        for (int i = 0; i < workers; i++) {
            threads[i] = new Thread(new Worker(myPhaser), "thread" + i);
        }
        for (Thread thread : threads) {
            thread.start();
        }
    }

    /**
     * Phaser whose {@link #onAdvance} logs the end of each phase and terminates the phaser once
     * phase 4 (the last one used by {@link Worker}) completes.
     */
    static class MyPhaser extends Phaser {

        /**
         * @param parties number of parties registered up front (one per worker thread here)
         */
        MyPhaser(int parties) {
            super(parties);
        }

        /**
         * Callback run by the last party to arrive in a phase, before the phaser advances to the
         * next phase. Returning {@code true} terminates the phaser.
         *
         * @param phase the phase that just completed
         * @param registeredParties number of currently registered parties (unused here)
         * @return {@code false} for phases 0-3 (keep going), {@code true} afterwards (terminate)
         */
        @Override
        protected boolean onAdvance(int phase, int registeredParties) {
            switch (phase) {
                case 0:
                    logPhaseFinished("prepare");
                    return false;
                case 1:
                    logPhaseFinished("first");
                    return false;
                case 2:
                    logPhaseFinished("second");
                    return false;
                case 3:
                    logPhaseFinished("third");
                    return false;
                default:
                    // The final phase: stop the phaser.
                    logPhaseFinished("last");
                    return true;
            }
        }

        /** Prints {@code "<timestamp>the <name> phaser finished..."}. */
        private static void logPhaseFinished(String name) {
            System.out.println(timestamp() + "the " + name + " phaser finished...");
        }
    }

    /**
     * One participant: performs a timed unit of work, then waits at the phaser for the other
     * workers, five times in a row (phases 0 through 4).
     */
    static class Worker implements Runnable {

        private final MyPhaser phaser;

        Worker(MyPhaser phaser) {
            this.phaser = phaser;
        }

        @Override
        public void run() {
            work("prepare", 2000);
            phaser.arriveAndAwaitAdvance();
            // Comment out everything below and this is just a single barrier (one CyclicBarrier round).
            work("work", 3000);
            phaser.arriveAndAwaitAdvance();
            work("work", 3000);
            phaser.arriveAndAwaitAdvance();
            work("work", 3000);
            phaser.arriveAndAwaitAdvance();
            work("finish", 2000);
            phaser.arriveAndAwaitAdvance();
        }

        /**
         * Logs {@code "<timestamp><thread> start <action> at the <phase> phaser"}, then sleeps to
         * simulate work.
         *
         * @param action word printed after "start"
         * @param millis how long to sleep
         */
        private void work(String action, long millis) {
            System.out.println(timestamp() + Thread.currentThread().getName()
                    + " start " + action + " at the " + phaser.getPhase() + " phaser");
            try {
                Thread.sleep(millis);
            } catch (InterruptedException e) {
                // Restore the interrupt flag instead of swallowing it.
                Thread.currentThread().interrupt();
            }
        }
    }
}
           

执行结果:

2018-09-25 12:56:16 thread1 start prepare at the 0 phaser

2018-09-25 12:56:16 thread2 start prepare at the 0 phaser

2018-09-25 12:56:16 thread0 start prepare at the 0 phaser

2018-09-25 12:56:18 the prepare phaser finished...

2018-09-25 12:56:18 thread0 start work at the 1 phaser

2018-09-25 12:56:18 thread1 start work at the 1 phaser

2018-09-25 12:56:18 thread2 start work at the 1 phaser

2018-09-25 12:56:21 the first phaser finished...

2018-09-25 12:56:21 thread2 start work at the 2 phaser

2018-09-25 12:56:21 thread1 start work at the 2 phaser

2018-09-25 12:56:21 thread0 start work at the 2 phaser

2018-09-25 12:56:24 the second phaser finished...

2018-09-25 12:56:24 thread0 start work at the 3 phaser

2018-09-25 12:56:24 thread2 start work at the 3 phaser

2018-09-25 12:56:24 thread1 start work at the 3 phaser

2018-09-25 12:56:27 the third phaser finished...

2018-09-25 12:56:27 thread1 start finish at the 4 phaser

2018-09-25 12:56:27 thread0 start finish at the 4 phaser

2018-09-25 12:56:27 thread2 start finish at the 4 phaser

2018-09-25 12:56:29 the last phaser finished...

接下来看下各个方法的源码与作用:

// Annotated excerpt of the JDK's java.util.concurrent.Phaser. Many members it uses
// (PARTIES_SHIFT, PHASE_SHIFT, EMPTY, evenQ/oddQ, QNode, doRegister, doArrive,
// internalAwaitAdvance, releaseWaiters, ...) are omitted here.
public class Phaser {
     /*
     * Bit layout of the state word:
     * unarrived  -- the number of parties yet to hit barrier (bits  0-15)
     * parties    -- the number of parties to wait            (bits 16-31)
     * phase      -- the generation of the barrier            (bits 32-62)
     * terminated -- set if barrier is terminated             (bit  63 / sign)
     */
    // The layout above describes how the bits of `state` are used.
    private volatile long state;

    /**
     * The parent of this phaser, or null if none
     */
    // Phasers can be tiered: a phaser may have a parent phaser.
    private final Phaser parent;

    /**
     * The root of phaser tree. Equals this if not in a tree.
     */
    // When phasers form a tree, this points at the tree's root.
    private final Phaser root;


    // Main constructor.
    // `parties` is the number of initially registered (and unarrived) parties;
    // in the demo above that is one party per worker thread.
    public Phaser(Phaser parent, int parties) {
        // Reject party counts too large for the parties/unarrived bit fields.
        if (parties >>> PARTIES_SHIFT != 0)
            throw new IllegalArgumentException("Illegal number of parties");
        int phase = 0;
        this.parent = parent;
        // With a parent, this phaser joins a tree: share the root and its wait queues,
        // and (if it has parties of its own) register itself as one party of the parent,
        // adopting the parent's phase.
        if (parent != null) {
            final Phaser root = parent.root;
            this.root = root;
            this.evenQ = root.evenQ;
            this.oddQ = root.oddQ;
            if (parties != 0)
                phase = parent.doRegister(1);
        }
        else {
            this.root = this;
            this.evenQ = new AtomicReference<QNode>();
            this.oddQ = new AtomicReference<QNode>();
        }
        // Pack phase, parties and unarrived into state; a phaser with no parties uses EMPTY.
        this.state = (parties == 0) ? (long)EMPTY :
            ((long)phase << PHASE_SHIFT) |
            ((long)parties << PARTIES_SHIFT) |
            ((long)parties);
    }

    public Phaser() {
        this(null, 0);
    }
    public Phaser(int parties) {
        this(null, parties);
    }
    public Phaser(Phaser parent) {
        this(parent, 0);
    }

    // Adds one new unarrived party to this phaser.
    public int register() {
        // NOTE(review): doRegister is not shown; presumably it increments both parties and unarrived.
        return doRegister(1);
    }
    // Adds `parties` new unarrived parties. A count of 0 registers nothing and just returns
    // the current phase; a negative count throws IllegalArgumentException.
    public int bulkRegister(int parties) {
        if (parties < 0)
            throw new IllegalArgumentException();
        if (parties == 0)
            return getPhase();
        return doRegister(parties);
    }
    
    // Arrives at this phaser WITHOUT waiting for the other parties (contrast with
    // arriveAndAwaitAdvance, which is awaitAdvance(arrive())). doArrive is not shown;
    // per the JDK Javadoc the return value is the arrival phase number.
    public int arrive() {
        return doArrive(ONE_ARRIVAL);
    }
    // Arrives and deregisters one party, without waiting for the others to arrive.
    public int arriveAndDeregister() {
        return doArrive(ONE_DEREGISTER);
    }
    // Arrives and waits for the other parties to arrive; equivalent in effect to
    // awaitAdvance(arrive()). Returns the next phase number, or a negative value if terminated.
    public int arriveAndAwaitAdvance() {
        // Specialization of doArrive+awaitAdvance eliminating some reads/paths
        final Phaser root = this.root;
        for (;;) {
            long s = (root == this) ? state : reconcileState();
            int phase = (int)(s >>> PHASE_SHIFT);
            // A negative phase means the termination (sign) bit is set.
            if (phase < 0)
                return phase;
            int counts = (int)s;
            int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
            // Arriving when no party is still unarrived is a usage error.
            if (unarrived <= 0)
                throw new IllegalStateException(badArrive(s));
            // Atomically record this arrival; on CAS failure the loop re-reads state and retries.
            if (UNSAFE.compareAndSwapLong(this, stateOffset, s,
                                          s -= ONE_ARRIVAL)) {
                // Other parties are still outstanding: block until the phase advances.
                if (unarrived > 1)
                    return root.internalAwaitAdvance(phase, null);
                // Last arrival at a non-root phaser: this phaser arrives at its parent and waits there.
                if (root != this)
                    return parent.arriveAndAwaitAdvance();
                long n = s & PARTIES_MASK;  // base of next state
                int nextUnarrived = (int)n >>> PARTIES_SHIFT;
                // Last arrival at the root: every party has arrived, so run the onAdvance callback.
                if (onAdvance(phase, nextUnarrived))
                    // onAdvance returned true: mark the phaser as terminated.
                    n |= TERMINATION_BIT;
                else if (nextUnarrived == 0)
                    n |= EMPTY;
                else
                    n |= nextUnarrived;
                // Compute the next phase number (masked by MAX_PHASE).
                int nextPhase = (phase + 1) & MAX_PHASE;
                n |= (long)nextPhase << PHASE_SHIFT;
                if (!UNSAFE.compareAndSwapLong(this, stateOffset, s, n))
                    return (int)(state >>> PHASE_SHIFT); // terminated
                // Wake the threads that were waiting for this phase to advance.
                releaseWaiters(phase);
                // Return the new phase number.
                return nextPhase;
            }
        }
    }
    
    // Waits for the phaser to advance past `phase`: if the current phase equals `phase`,
    // blocks until it advances; otherwise returns the current phase immediately.
    // A negative argument is returned unchanged. Does not throw InterruptedException.
    public int awaitAdvance(int phase) {
        final Phaser root = this.root;
        long s = (root == this) ? state : reconcileState();
        int p = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            return phase;
        if (p == phase)
            return root.internalAwaitAdvance(phase, null);
        return p;
    }
    // Like awaitAdvance(int), but throws InterruptedException if the waiting thread is
    // interrupted. (Phaser also has a timed variant, not shown here.)
    public int awaitAdvanceInterruptibly(int phase)
        throws InterruptedException {
        final Phaser root = this.root;
        long s = (root == this) ? state : reconcileState();
        int p = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            return phase;
        if (p == phase) {
            // NOTE(review): QNode is not shown; the flags presumably mean interruptible=true, timed=false.
            QNode node = new QNode(this, phase, true, false, 0L);
            p = root.internalAwaitAdvance(phase, node);
            if (node.wasInterrupted)
                throw new InterruptedException();
        }
        return p;
    }
    // Forces termination: sets TERMINATION_BIT on the root's state (unless already set)
    // and wakes every waiter on both wait queues.
    public void forceTermination() {
        // Only need to change root state
        final Phaser root = this.root;
        long s;
        while ((s = root.state) >= 0) {
            if (UNSAFE.compareAndSwapLong(root, stateOffset,
                                          s, s | TERMINATION_BIT)) {
                // signal all threads
                releaseWaiters(0); // Waiters on evenQ
                releaseWaiters(1); // Waiters on oddQ
                return;
            }
        }
    }
    // Callback run when a phase completes; returning true terminates the phaser.
    // Override it to decide when to stop. Default: terminate once no parties remain registered.
    protected boolean onAdvance(int phase, int registeredParties) {
        return registeredParties == 0;
    }
    
}
           

考虑用Phaser代替CountDownLatch和CyclicBarrier

CountDownLatch的await()和countDown()分别对应Phaser的awaitAdvance(int phase)和arrive();

CyclicBarrier的await()对应着Phaser的arriveAndAwaitAdvance();