天天看点

并发编程(八)无锁并发: CAS(Compare And Swap)

Concurrency Programming 八

  • 无锁并发: CAS(Compare And Swap)
    • 原子整数类
    • 原子引用类
    • 原子数组类
    • 原子字段更新器类
    • 原子累加器类
    • Unsafe

无锁并发: CAS(Compare And Swap)

  • 可在多核 CPU环境下无阻塞的方式来保证原子性. 它不是通过加锁的方式来保护共享变量的线程安全
  • *特点: CAS适用于多核 CPU, 同时线程数不能多于核数的环境
原理比较:
  • 悲观锁(synchronized/ReentrantLock): 当前线程抢到锁, 则其它线程会被阻塞, 因此线程上下文切换频次会相对频繁
  • 乐观锁(CAS): 没有阻塞的状态, 每当线程分到 CPU时间片, 都会不断地尝试比较, 当比较成功, 则会交换(更改)
  • CAS例子:
interface Account {
    // 获取余额
    Integer getBalance();
    // 取款
    void withdraw(Integer amount);
}

class AccountSafe implements Account {
    // 原子整数类
    private AtomicInteger balance;

    public AccountSafe(Integer balance) {
        this.balance = new AtomicInteger(balance);
    }

    @Override
    public Integer getBalance() {
        return balance.get();
    }

    @Override
    public void withdraw(Integer amount) {
        // 不断尝试, 直到成功
        while (true) {
            // 旧值
            int prev = balance.get();
            // 将要设置的值
            int next = prev - amount;
            /** boolean compareAndSet(int prev, int next)方法, 在设置(Set) next值时, 会先比较 prev值与当前的余额
             * 1. 不一致, 则失败. 返回 false; 进入下次循环重试
             * 2. 一致, 则成功. 返回 true; 结束循环尝试 * */
            if (balance.compareAndSet(prev, next)) {
                // 成功设置, 结束循环尝试
                break;
            }
        }
        // 以上方式, 可以简化为下面的方式
        // balance.addAndGet(-1 * amount);
    }
}

public class App {
    public static void main(String[] args) {
        // 起始余额 10000
        Account account = new AccountSafe(10000);
        // 开始时间
        long start = System.nanoTime();
        List<Thread> ts = new ArrayList<>();
        // 创建 1000个线程
        for (int i = 0; i < 1000; i++) {
            ts.add(new Thread(() -> {
                // 每个线程做 -10
                account.withdraw(10);
            }));
        }
        // 启动所有(1000个)线程
        ts.forEach(Thread::start);
        ts.forEach(t -> {
            try {
                // 同步等待所有线程运行结束
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        // 最后正确的余额为 0
        System.out.println("当前余额: " + account.getBalance() +
                ", 耗费时间: " + (System.nanoTime()-start) / 1000_000 + " ms");
        // 当前余额: 0, 耗费时间: 556 ms
    }
}

           

原子整数类

  • AtomicBoolean, AtomicInteger, AtomicLong
  • AtomicInteger为例:
AtomicInteger i = new AtomicInteger(0);
// 获取并自增(i = 0, 结果 i = 1, 返回 0),类似于 i++
System.out.println(i.getAndIncrement());
// 自增并获取(i = 1, 结果 i = 2, 返回 2),类似于 ++i
System.out.println(i.incrementAndGet());
// 自减并获取(i = 2, 结果 i = 1, 返回 1),类似于 --i
System.out.println(i.decrementAndGet());
// 获取并自减(i = 1, 结果 i = 0, 返回 1),类似于 i--
System.out.println(i.getAndDecrement());
// 获取并加值(i = 0, 结果 i = 5, 返回 0)
System.out.println(i.getAndAdd(5));
// 加值并获取(i = 5, 结果 i = 0, 返回 0)
System.out.println(i.addAndGet(-5));
// 获取并更新(i = 0, p 为 i 的当前值, 结果 i = -2, 返回 0)
// 其中函数中的操作能保证原子,但函数需要无副作用
System.out.println(i.getAndUpdate(p -> p - 2));
// 更新并获取(i = -2, p 为 i 的当前值, 结果 i = 0, 返回 0)
// 其中函数中的操作能保证原子,但函数需要无副作用
System.out.println(i.updateAndGet(p -> p + 2));
// 获取并计算(i = 0, p 为 i 的当前值, x 为参数1, 结果 i = 10, 返回 0)
// 其中函数中的操作能保证原子,但函数需要无副作用
// getAndUpdate 如果在 lambda 中引用了外部的局部变量,要保证该局部变量是 final 的
// getAndAccumulate 可以通过 参数1 来引用外部的局部变量,但因为其不在 lambda 中因此不必是 final
System.out.println(i.getAndAccumulate(10, (p, x) -> p + x));
// 计算并获取(i = 10, p 为 i 的当前值, x 为参数1, 结果 i = 0, 返回 0)
// 其中函数中的操作能保证原子,但函数需要无副作用
System.out.println(i.accumulateAndGet(-10, (p, x) -> p + x));

           

原子引用类

  • AtomicReference, AtomicStampedReference, AtomicMarkableReference
  • AtomicReference为例:
interface Account {
    // 获取余额
    BigDecimal getBalance();
    // 取款
    void withdraw(BigDecimal amount);
}

class AccountSafe implements Account {
    AtomicReference<BigDecimal> ref;

    public AccountSafe(BigDecimal balance) {
        ref = new AtomicReference<>(balance);
    }

    @Override
    public BigDecimal getBalance() {
        return ref.get();
    }

    @Override
    public void withdraw(BigDecimal amount) {
        while (true) {
            BigDecimal prev = ref.get();
            BigDecimal next = prev.subtract(amount);
            if (ref.compareAndSet(prev, next)) {
                break;
            }
        }
    }
}

public class App {
    public static void main(String[] args) {
        // 起始余额 10000
        Account account = new AccountSafe(new BigDecimal("10000"));
        // 开始时间
        long start = System.nanoTime();
        List<Thread> ts = new ArrayList<>();
        // 创建 1000个线程
        for (int i = 0; i < 1000; i++) {
            ts.add(new Thread(() -> {
                // 每个线程做 -10
                account.withdraw(BigDecimal.TEN);
            }));
        }
        // 启动所有(1000个)线程
        ts.forEach(Thread::start);
        ts.forEach(t -> {
            try {
                // 同步等待所有线程运行结束
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        // 最后正确的余额为 0
        System.out.println("当前余额: " + account.getBalance() +
                ", 耗费时间: " + (System.nanoTime()-start) / 1000_000 + " ms");
        // 当前余额: 0, 耗费时间: 395 ms
    }
}

           
  • ABA问题: 比如 AtomicReference共享变量最初值 A, 在 CAS比较时, 无法感知到从 A改为 B, 再改回 A的情况.

    *为了追踪此种变化过程就有了附带版本号的原子引用 AtomicStampedReference, 一旦出现以上情况, 则会更改失败以及可以查询当前版本号

public class App {
    // 原子引用(附加版本)
    static AtomicStampedReference<String> ref = new AtomicStampedReference<>("A", 0);
    public static void main(String[] args) throws InterruptedException {
        // 获取版本号
        int stamp = ref.getStamp();
        System.out.println("起始版本 " + stamp);
        // 通过子线程修改
        other();
        TimeUnit.MILLISECONDS.sleep(1000);
        // 在主线程, 延迟1秒后尝试修改, 请求版本号为0, 改后1
        System.out.print("A尝试改为 C " +
                ref.compareAndSet(ref.getReference(), "C", stamp, stamp + 1));
        System.out.println(", 当前版本为 " + ref.getStamp());
    }

    private static void other() throws InterruptedException {
        // 线程1
        new Thread(() -> {
            System.out.print("A尝试改为 B " +
                    ref.compareAndSet(ref.getReference(), "B", ref.getStamp(), ref.getStamp() + 1));
            System.out.println(", 版本更新为 " + ref.getStamp());
        }, "t1").start();
        // 睡眠, 延迟0.5秒
        TimeUnit.MILLISECONDS.sleep(500);
        // 线程2
        new Thread(() -> {
            System.out.print("B尝试改为 A " +
                    ref.compareAndSet(ref.getReference(), "A", ref.getStamp(), ref.getStamp() + 1));
            System.out.println(", 版本更新为 " + ref.getStamp());
        }, "t2").start();
    }
}
起始版本 0
A尝试改为 B true, 版本更新为 1 # t1
B尝试改为 A true, 版本更新为 2 # t2
A尝试改为 C false, 当前版本为 2 # 主线程

           

原子数组类

  • AtomicIntegerArray, AtomicLongArray, AtomicReferenceArray
  • 线程[安全& 不安全]的数组演示例子: 保护数组内元素的线程安全
public class App {
    private static <T> void demo(
            Supplier<T> arraySupplier, // 提供数组(线程不安全数组/线程安全数组)
            Function<T, Integer> lengthFun, // 获取数组长度的方法
            BiConsumer<T, Integer> putConsumer, // 自增方法, 回传 array, index
            Consumer<T> printConsumer ) { // 打印数组的方法
        List<Thread> ts = new ArrayList<>();
        T array = arraySupplier.get();
        int length = lengthFun.apply(array);
        for (int i = 0; i < length; i++) {
            // 创建线程, 同时对提供的数组, 做10000次操作
            ts.add(new Thread(() -> {
                for (int j = 0; j < 10000; j++) {
                    putConsumer.accept(array, j % length);
                }
            }));
        }
        // 启动所有线程
        ts.forEach(t -> t.start());
        ts.forEach(t -> {
            try {
                // 同步等待所有线程运行结束
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        printConsumer.accept(array);
    }
}

           
  • 线程不安全的数组演示例子:
public static void main(String[] args) {
        demo(
            () -> new int[10],
            (array) -> array.length,
            (array, index) -> array[index]++,
            array-> System.out.println(Arrays.toString(array))
        );
    }

> [9902, 9905, 9907, 9906, 9918, 9918, 9916, 9913, 9902, 9903]

           
  • 线程安全的数组演示例子:
public static void main(String[] args) {
        demo(
            () -> new AtomicIntegerArray(10),
            (array) -> array.length(),
            (array, index) -> array.getAndIncrement(index),
            array -> System.out.println(array)
        );
    }

> [10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000, 10000]

           

原子字段更新器类

  • AtomicReferenceFieldUpdater, AtomicIntegerFieldUpdater, AtomicLongFieldUpdater
  • 可以对对象内的域(Field)进行原子操作, 注: 属性必须 volatile修饰, 否则会抛异常
public class App {
    private volatile int field;
    public static void main(String[] args) {
        AtomicIntegerFieldUpdater fieldUpdater = AtomicIntegerFieldUpdater.newUpdater(App.class, "field");
        App app = new App();
        fieldUpdater.compareAndSet(app, 0, 10);
        // 修改成功 field = 10
        System.out.println(app.field);
        // 修改成功 field = 20
        fieldUpdater.compareAndSet(app, 10, 20);
        System.out.println(app.field);
        // 修改失败 field = 20
        fieldUpdater.compareAndSet(app, 10, 30);
        System.out.println(app.field);
        // 修改成功 field = 40
        fieldUpdater.compareAndSet(app, 20, 40);
        System.out.println(app.field);
    }
}
10
20
20
40

           

原子累加器类

  • LongAdder, DoubleAdder
  • LongAdder使用例子:
public class App33 {
    private static <T> void demo(Supplier<T> adderSupplier, Consumer<T> action) {
        T adder = adderSupplier.get();
        long start = System.nanoTime();
        List<Thread> ts = new ArrayList<>();
        // 创建 10个线程, 每个线程累加100万次
        for (int i = 0; i < 10; i++) {
            ts.add(new Thread(() -> {
                for (int j = 0; j < 1000000; j++) {
                    action.accept(adder);
                }
            }));
        }
        // 启动所有(10个)线程
        ts.forEach(t -> t.start());
        ts.forEach(t -> {
            try {
                // 同步等待所有线程运行结束
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        System.out.println("总计: " + adder +
                ", 耗费时间: " + (System.nanoTime()-start) / 1000_000 + " ms");
    }

    public static void main(String[] args) {
        System.out.println("LongAdder:");
        for (int i = 0; i < 5; i++) {
            demo(() -> new LongAdder(), adder -> adder.increment());
        }
    }
}
LongAdder:
总计: 10000000, 耗费时间: 229 ms
总计: 10000000, 耗费时间: 94 ms
总计: 10000000, 耗费时间: 134 ms
总计: 10000000, 耗费时间: 65 ms
总计: 10000000, 耗费时间: 92 ms

           
  • 与累加器 AtomicLong性能比较:
public static void main(String[] args) {
        System.out.println("AtomicLong:");
        for (int i = 0; i < 5; i++) {
            demo(() -> new AtomicLong(), adder -> adder.getAndIncrement());
        }
    }
AtomicLong:
总计: 10000000, 耗费时间: 192 ms
总计: 10000000, 耗费时间: 180 ms
总计: 10000000, 耗费时间: 209 ms
总计: 10000000, 耗费时间: 209 ms
总计: 10000000, 耗费时间: 247 ms

           

Unsafe

  • Unsafe对象无法直接调用, 而只能通过反射获得
  • 使用例子:
class UnsafeAccessor {
    static Unsafe unsafe;
    static {
        try {
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            // 无法读取 public以外的, 此项默认 false. 设置为 true来取消封装
            theUnsafe.setAccessible(true);
            unsafe = (Unsafe) theUnsafe.get(null);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new Error(e);
        }
    }
    static Unsafe getUnsafe() {
        return unsafe;
    }
}

@Data
class Student {
    volatile int id;
    volatile String name;
}

public class App34 {
    public static void main(String[] args) throws NoSuchFieldException {
        Unsafe unsafe = UnsafeAccessor.getUnsafe();
        Field id = Student.class.getDeclaredField("id");
        Field name = Student.class.getDeclaredField("name");
        // 获得成员变量的偏移量
        long idOffset = unsafe.objectFieldOffset(id);
        long nameOffset = unsafe.objectFieldOffset(name);
        Student student = new Student();
        // 使用 CAS方法替换成员变量的值
        System.out.println("设置 id: " + unsafe.compareAndSwapInt(student, idOffset, 0, 20));
        System.out.println("设置 name: " + unsafe.compareAndSwapObject(student, nameOffset, null, "张三"));
        System.out.println(student);
    }
}
设置 id: true
设置 name: true
Student(id=20, name=张三)

           
  • 自定义原子整数例子:
class AtomicData {
    private volatile int data;
    static final Unsafe unsafe;
    static final long DATA_OFFSET;
    static {
        unsafe = UnsafeAccessor.getUnsafe();
        try {
            // data属性在对象中的偏移量, 用于 Unsafe直接访问该属性
            DATA_OFFSET = unsafe.objectFieldOffset(AtomicData.class.getDeclaredField("data"));
        } catch (NoSuchFieldException e) {
            throw new Error(e);
        }
    }
    
    public AtomicData(int data) {
        this.data = data;
    }

    public void withdraw(int amount) {
        int oldValue;
        while(true) {
            oldValue = data;
            // CAS尝试修改 data为(旧值 - amount)
            if (unsafe.compareAndSwapInt(this, DATA_OFFSET, oldValue, oldValue - amount)) {
                return;
            }
        }
    }

    public int getData() {
        return data;
    }
}

interface Account {
    // 获取余额
    Integer getBalance();
    // 取款
    void withdraw(Integer amount);
}

class AccountSafe implements Account {
    private AtomicData balance;

    public AccountSafe(Integer balance) {
        this.balance = new AtomicData(balance);
    }

    @Override
    public Integer getBalance() {
        return balance.getData();
    }

    @Override
    public void withdraw(Integer amount) {
        balance.withdraw(amount);
    }
}

public class App {
    public static void main(String[] args) {
        // 起始余额 10000
        Account account = new AccountSafe(10000);
        // 开始时间
        long start = System.nanoTime();
        List<Thread> ts = new ArrayList<>();
        // 创建 1000个线程
        for (int i = 0; i < 1000; i++) {
            ts.add(new Thread(() -> {
                // 每个线程做 -10
                account.withdraw(10);
            }));
        }
        // 启动所有(1000个)线程
        ts.forEach(Thread::start);
        ts.forEach(t -> {
            try {
                // 同步等待所有线程运行结束
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        // 最后正确的余额为 0
        System.out.println("当前余额: " + account.getBalance() +
                ", 耗费时间: " + (System.nanoTime()-start) / 1000_000 + " ms");
        // 当前余额: 0, 耗费时间: 421 ms
    }
}

           
如果您觉得有帮助,欢迎点赞哦 ~ 谢谢!!