天天看点

手写RPC-具体实现细节详解(近4w字详情)

这里面是远程通讯的核心,包含了网络通信、编解码协议、远程调用、注册中心、负载均衡等核心代码都在这里面,下面就详细分析下;

Tips:这里的文章有点滞后性,实际代码有一点修改,所以在看这里的内容的时候最好是跟着代码一起看,但是核心思路和绝大部分代码都是一样的,有一些修改:

  • 将实际调用改成了异步;
  • 新增了优雅上下线;
  • 新增了filter和快速失败等;

不过这些看下代码就懂了,很简单;所以这个文章影响不大,只是为了让新手能更好理解;

项目目录

.
└── core
    ├── compress
    │   ├── DefaultCompressor.java
    │   └── GzipCompressor.java
    ├── config
    │   ├── ConfigManager.java
    │   └── loader
    │       ├── PropertiesConfigLoader.java
    │       └── SystemPropertyLoader.java
    ├── loadbalance
    │   ├── AbstractLoadBalance.java
    │   ├── LoadBalanceFactory.java
    │   └── rule
    │       ├── random
    │       │   ├── RandomRule.java
    │       │   └── RandomWeightRule.java
    │       └── round
    │           ├── NormalWeightRoundRule.java
    │           ├── RoundRule.java
    │           └── SmoothWeightRoundRule.java
    ├── network
    │   ├── cache
    │   │   ├── ConnectCache.java
    │   │   ├── RegisterInfoCache.java
    │   │   └── SimpleRpcServiceCache.java
    │   ├── client
    │   │   ├── ClientSocketHandler.java
    │   │   └── RpcClientSocket.java
    │   ├── codec
    │   │   ├── RpcDecoder.java
    │   │   ├── RpcEncoder.java
    │   │   ├── RpcMessageDecoder.java
    │   │   └── RpcMessageEncoder.java
    │   ├── message
    │   │   ├── Request.java
    │   │   ├── Response.java
    │   │   └── RpcMessage.java
    │   ├── send
    │   │   ├── SyncWrite.java
    │   │   ├── SyncWriteFuture.java
    │   │   ├── SyncWriteMap.java
    │   │   └── WriteFuture.java
    │   └── server
    │       ├── RpcServerSocket.java
    │       ├── ServerSocketHandler.java
    │       └── hook
    │           └── ServerExitHook.java
    ├── reflect
    │   ├── RpcInvocationHandler.java
    │   └── RpcProxy.java
    ├── register
    │   ├── AbstractRegisterCenter.java
    │   ├── RegisterCenterFactory.java
    │   └── strategy
    │       ├── LocalRegisterCenter.java
    │       └── RedisRegisterCenter.java
    └── serializer
        └── ProtostuffSerializer.java      

简单介绍下项目:

  • ​compress​

    ​:这个下面是压缩类的实现,当然其对应的接口是放在了common项目里面,然后也提供了SPI机制,可以使用自己提供的实现;
  • ​config​

    ​:这个包下面就是配置加载器相关的实现
  1. ​ConfigLoader​

    ​:这个是对应的文件加载方式,其实就是从key=value的格式总通过key获取对应的value值;
  2. ​ConfigManager​

    ​:这里面就是包含了文件扫描,设置值对象等;
  • ​loadbalance​

    ​:这里面就是对应的负载均衡算法;
  • ​network​

    ​:这里面就是对应netty相关的东西,编解码、nettyserver、nettyclient等;
  • ​reflect​

    ​:反射相关的东西,其实应该叫动态代理相关的东西;远程调用的核心也在这里;
  • ​register​

    ​:这个是注册中心在的地方,本来是打算单独拿出去做个小模块的,后面觉得注册中心单独使用的概率不大,不过后续可以考虑,但是纬度没有想好,是以中间件作维度,还是以制定数据格式作维度需要思考一下;
  • ​serializer​

    ​:序列化相关的东西;

SPI机制

这里讲讲RPC的​

​SPI​

​​机制,java里面自带的是​

​Service Provider Interface​

​;其实就是提供接口类,但是实现由三方提供;不太知道的同学可以随便找个文章看看;

接口解析

simple-rpc里面的话,也是参考这种规则,其实实现也比较简单,一个用于方便获取数据的holder :​

​ExtensionHolder、​

​​和一个具体实现的加载器:​

​ExtensionLoader​

​;跟着代码看:

/** 扩展类实例缓存 {name: 扩展类实例} */
private final Map<String, T> extensionsCache = new ConcurrentHashMap<>(8);

/** 扩展加载器实例缓存 {类型:加载器实例} */
private static final Map<Class<?>, ExtensionLoader<?>> extensionLoaderCache = new ConcurrentHashMap<>(8);

/** 扩展类配置列表缓存 {type: {name, 扩展类}} */
private final ExtensionHolder<Map<String, Class<?>>> extensionClassesCache = new ExtensionHolder<>();

/** 创建扩展实例类的锁缓存 {name: synchronized 持有的锁} */
private final Map<String, Object> createExtensionLockMap = new ConcurrentHashMap<>(8);

/** 扩展类加载器的类型 */
private final Class<T> type;

/** 扩展类存放的目录地址 */
private static final String EXTENSION_PATH = "META-INF/simple-rpc/";

/** 默认扩展名缓存 */
private final String defaultNameCache;

private ExtensionLoader(Class<T> type) {
    this.type = type;
    SimpleRpcSPI annotation = type.getAnnotation(SimpleRpcSPI.class);
    defaultNameCache = annotation.value();
}      

这里就是各种缓存,具体的就不说了,注释里面很详细,有点就是默认值是由注解里面的value值去获取的,这里都是有默认初始值的;

能使用SPI机制的接口都放在了​

​simple-rpc-common​

​模块里面:

@SimpleRpcSPI(value = "default")
public interface Compressor {...}

@SimpleRpcSPI(value = "simple-rpc-property")
public interface ConfigLoader {...}

@SimpleRpcSPI(value = "redis")
public interface RegisterCenter {...}

@SimpleRpcSPI(value = "protostuff")
public interface Serializer {...}

@SimpleRpcSPI("round")
public interface SimpleRpcLoadBalance {...}      

接下来就是尝试去获取扩展加载器:

/**
 * 获取对应类型的扩展加载器实例
 *
 * @param type 扩展类加载器的类型
 * @return 扩展类加载器实例
 */
public static <S> ExtensionLoader<S> getLoader(Class<S> type) {
    // 扩展类型必须是接口
    if (!type.isInterface()) {
        throw new IllegalStateException(type.getName() + " is not interface");
    }
    SimpleRpcSPI annotation = type.getAnnotation(SimpleRpcSPI.class);
    if (annotation == null) {
        throw new IllegalStateException(type.getName() + " has not @SimpleRpcSPI annotation.");
    }
    ExtensionLoader<?> extensionLoader = extensionLoaderCache.get(type);
    if (extensionLoader != null) {
        //noinspection unchecked
        return (ExtensionLoader<S>) extensionLoader;
    }
    extensionLoader = new ExtensionLoader<>(type);
    extensionLoaderCache.putIfAbsent(type, extensionLoader);
    //noinspection unchecked
    return (ExtensionLoader<S>) extensionLoader;
}      

扩展类加载器这里定义的是static缓存,全局共用,先做一些前置的拦截校验,判断是否是​

​@SimpleRpcSPI​

​​标记的类,是的话继续往下走;尝试从缓存中拿数据,如果不存在则返回;缓存可以防止频繁创建加载器;比如加载​

​Serializer​

​的SPI,会尝试从缓存中拿出来:

ExtensionLoader.getLoader(Serializer.class).getExtension("json");      

有了对应类型的SPI加载器,接下来就是加载对应的SPI的实现类,简单看下逻辑:

/**
 * 根据名字获取扩展类实例(单例)
 *
 * @param name 扩展类在配置文件中配置的名字. 如果名字是空的或者空白的,则返回默认扩展
 * @return 单例扩展类实例,如果找不到,则抛出异常
 */
public T getExtension(String name) {
    if (StrUtil.isBlank(name)) {
        return getDefaultExtension();
    }
    // 从缓存中获取单例
    T extension = extensionsCache.get(name);
    if (extension == null) {
        Object lock = createExtensionLockMap.computeIfAbsent(name, k -> new Object());
        //noinspection SynchronizationOnLocalVariableOrMethodParameter
        synchronized (lock) {
            extension = extensionsCache.get(name);
            if (extension == null) {
                extension = createExtension(name);
                extensionsCache.put(name, extension);
            }
        }
    }
    return extension;
}      

这里有几点:

  • 如果是未指定扩展名,那么选用默认的扩展类,这里是有simple-rpc-core自己提供实现;
  • 这里同样使用缓存,尝试从指定类型的扩展类缓存中拿到对应类型的扩展类,如果不存在的话加锁尝试再次获取;获取不到则创建新的实例;此处返回的是单例;

然后就是创建实例的过程,也比较简单:

private T createExtension(String name) {
    // 获取当前类型所有扩展类
    Map<String, Class<?>> extensionClasses = getAllExtensionClasses();
    // 再根据名字找到对应的扩展类
    Class<?> clazz = extensionClasses.get(name);
    if (clazz == null) { throw new ...}
    ...
    return (T) clazz.newInstance();
    ...
}

private Map<String, Class<?>> getAllExtensionClasses() {
    Map<String, Class<?>> extensionClasses = extensionClassesCache.get();
    if (extensionClasses != null) {
        return extensionClasses;
    }
    synchronized (extensionClassesCache) {
        extensionClasses = extensionClassesCache.get();
        if (extensionClasses == null) {
            extensionClasses = loadClassesFromResources();
            extensionClassesCache.set(extensionClasses);
        }
    }
    return extensionClasses;
}

private Map<String, Class<?>> loadClassesFromResources() {
    Map<String, Class<?>> extensionClasses = new ConcurrentHashMap<>();
    // 扩展配置文件名
    String fileName = EXTENSION_PATH + type.getName();
    // 拿到资源文件夹
    ClassLoader classLoader = ExtensionLoader.class.getClassLoader();
    try {
        Enumeration<URL> resources = classLoader.getResources(fileName);
        while (resources.hasMoreElements()) {
            URL url = resources.nextElement();
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream(), StandardCharsets.UTF_8))) {
                // 开始读文件
                while (true) {
                    String line = reader.readLine();
                    if (line == null) {
                        break;
                    }
                    parseLine(line, extensionClasses);
                }
            }
        }
    ...
    return extensionClasses;
}

private void parseLine(String line, Map<String, Class<?>> extensionClasses) throws ClassNotFoundException {
    line = line.trim();
    // 忽略#号开头的注释
    if (line.startsWith("#")) {
        return;
    }
    String[] kv = line.split("=");
    if (kv.length != 2 || kv[0].length() == 0 || kv[1].length() == 0) {
        throw new IllegalStateException("Extension file parsing error. Invalid format!");
    }
    if (extensionClasses.containsKey(kv[0])) {
        throw new IllegalStateException(kv[0] + " is already exists!");
    }
    Class<?> clazz = ExtensionLoader.class.getClassLoader().loadClass(kv[1]);
    extensionClasses.put(kv[0], clazz);
}      

调用链路:​

​createExtension​

​​ → ​

​getAllExtensionClasses​

​​ → ​

​loadClassesFromResources​

​​ → ​

​parseLine​

​;

四个方法简单分析下:

  • ​createExtension​

    ​​:这里面会尝试从缓存​

    ​extensionClassesCache​

    ​里面拿到Class,拿到之后创建对应的实例;
  • ​getAllExtensionClasses​

    ​:就是将获取的数据塞到缓存里面;
  • ​loadClassesFromResources​

    ​​:这里的话是从资源文件里面去获取对应的数据,然后逐行解析;这里的文件名就是对应的SPI接口名,比如是序列化的,那么文件路径就是:​

    ​META-INF/simple-rpc/com.simple.rpc.common.interfaces.Serializer​

    ​;
  • ​parseLine​

    ​​:这个就是​

    ​json=com.simple.rpc.test.core.common.test.spi.JSONSerializer​

    ​这段解析出来,然后value存到缓存里面即可;

SPI测试

测试代码在com.simple.rpc.test.ExtensionLoaderTest中:

@Test
public void testSerialize() {
    System.out.println(ExtensionLoader.getLoader(Serializer.class).getExtension("json"));
    System.out.println(ExtensionLoader.getLoader(RegisterCenter.class).getExtension("mysql"));
    System.out.println(ExtensionLoader.getLoader(ConfigLoader.class).getExtension("spi-property"));
}      
手写RPC-具体实现细节详解(近4w字详情)

一般SPI接口是需要放在公共的地方,全系统是保持一致的;

注册中心

上面也说了,本来是打算单独一个服务:simple-rpc-register,但是时间受限和考虑维度暂时没想好问题,所以目前是搁置了;之后思考一下在考虑单独拿出来,这里先讲讲注册相关的逻辑;

基本方法

注册中心不复杂,就是存生产者提供的服务的元信息:

@SimpleRpcSPI(value = "redis")
public interface RegisterCenter {
    void init(SimpleRpcUrl url);

    Boolean register(RegisterInfo request);

    String get(RegisterInfo request);
    
    Boolean unregister(HookEntity hookEntity);
}      
  • ​init​

    ​:初始化接口,就想初始化jedis连接、mysql连接等;提供其连接后的存储能力;
  • ​register​

    ​:注册信息,这里面就是一些接口的基本信息和服务端的连接信息,可以具体看看代码里面的实体字段;
  • ​get​

    ​:获取对应的信息,等会后面一起写下数据格式;其实就只在设计数据格式,然后存取的操作;
  • ​unregister​

    ​:这个就是注销所有的注册信息,就是删除对应的注册元数据信息;

数据格式

存储分为三个部分:

  • ​key​

    ​​:这里是接口权限定名+ “_” + 别名,​

    ​com.simple.rpc.HelloService_helloService​

    ​;
  • ​host_port​

    ​​:类似于子key,存储的是服务端url的信息,​

    ​127.0.0.1_41200​

    ​;
  • ​registerInfo​

    ​:注册的元数据信息,出上面两个key信息之外,还会有一些属性配置;
/**
* 将三个字段构建然后存入:
*                                          --- 127.0.0.1_41201 --- {"alias":"xxx","host":"127.0.0.1","port":41201,"serializer":"serializer","weights":20}
*                                          -
* com.simple.rpc.HelloService_helloService --- 127.0.0.1_41202 --- {"alias":"xxx","host":"127.0.0.1","port":41202,"serializer":"serializer","weights":30}
*                                          -
*                                          --- 127.0.0.1_41203 --- {"alias":"xxx","host":"127.0.0.1","port":41203,"serializer":"serializer","weights":50}
*
*/      

接口实现

一些基本的前置构建进行了抽象:​

​com.simple.rpc.core.register.AbstractRegisterCenter​

注册接口

@Override
public Boolean register(RegisterInfo request) {
    String key = request.getInterfaceName() + SymbolConstant.UNDERLINE + request.getAlias();
    String fieldKey = request.getHost() + SymbolConstant.UNDERLINE + request.getPort();
    String value = JSON.toJSONString(request);
    return buildDataAndSave(key, fieldKey, value);
}

// 这里由子类取实现存操作
protected abstract Boolean buildDataAndSave(String key, String hostPort, String request);      

获取服务接口

@Override
public String get(RegisterInfo request) {
    String key = request.getInterfaceName() + SymbolConstant.UNDERLINE + request.getAlias();
    Map<String, String> stringStringMap = getLoadBalanceData(key);
    // 负载均衡策略
    String rule = Objects.isNull(request.getLoadBalanceRule()) ? LoadBalanceRule.ROUND.getName() : request.getLoadBalanceRule();
    return ExtensionLoader.getLoader(SimpleRpcLoadBalance.class).getExtension(rule).loadBalance(stringStringMap);
}

/**
 * 数据格式:{"127.0.0.1_41200" : "{"requestId: 1"}"}
 * 描述:通过 key(com.simple.rpc.AService_aService)获取的到 下面的map格式
 * map格式:前面的key以 host + "_" + port 组成;后面是对应的request信息的json格式
 *
 * @param key
 * @return
 */
protected abstract Map<String, String> getLoadBalanceData(String key);      

这里的话就是加入了负载算法,其他的也就是从存数据的地方根据制定的数据格式,拿到对应的注册信息;关于负载算法,之后可以单独说;

这里是支持SPI机制来实现不同的负载均衡的,当然,rpc-core里面也提供了多种负载算法;

工厂提供接口

public class RegisterCenterFactory {

    public static RegisterCenter create(String registerType) {
        return getRegisterCenter(registerType);
    }

    private static RegisterCenter getRegisterCenter(String registerType) {
        // 根据SPI找到对应的注册中心,其实这里是会根据address进行解析然后判断是哪种类型; 
        return ExtensionLoader.getLoader(RegisterCenter.class).getExtension(registerType);
    }
}      

从这就可以拿到对应的注册中心,然后在进行其他操作,这里放到之后整合springboot的时候在结合起来看看;

注销服务信息

@Override
public Boolean unregister(HookEntity hookEntity) {
    List<String> rpcServiceNames = hookEntity.getRpcServiceNames();
    String fieldKey = hookEntity.getServerUrl() + SymbolConstant.UNDERLINE + hookEntity.getServerPort();
    AtomicReference<Long> hdel = new AtomicReference<>(0L);
    rpcServiceNames.forEach(name -> {
        hdel.set(jedis.hdel(name, fieldKey));
    });

    return hdel.get() > 0;
}      

这里是redis为注册中心的实现,就是删除信息;没什么可讲的;之后只需要挂在springboot容器退出的时候去调用即可;

抽象类里面的子实现这里就不把代码粘贴上来了;

总而言之,注册中心就是:​

​注册元信息​

​、​

​获取注册元信息​

​、​

​注销元信息​

​;

配置中心

配置加载器

这里rpc提供自定义配置能力的地方;同样提供SPI机制:

@SimpleRpcSPI(value = "simple-rpc-property")
public interface ConfigLoader {

    /**
     * 加载配置项
     *
     * @param key 配置的 key
     * @return 配置项的值,如果不存在,返回 null
     */
    String loadConfigItem(String key);
}      

这是loader提供的能力,也就是通过key获取value;

系统提供了两个子实现:

  • ​PropertiesConfigLoader​

    ​​:​

    ​setting = SettingUtil.get("simple-rpc.properties");​

    ​​使用hutool的工具类从指定文件名里面拿到对应的配置信息;​

    ​return setting.getStr(key);​

  • ​SystemPropertyLoader​

    ​​:​

    ​System.getProperty(key);​

    ​这个比较简单,就是直接拿系统相关的参数;

配置管理器

这里面的话就提供了文件加载、行数据加载、实体映射等;

构造器就是在做初始化操作:将目前系统提供的加载器放到缓存里面,然后在​

​loadConfigItem​

​​方法里面轮询获取,然后优先顺序是:​

​system-property​

​​ > ​

​simple-rpc-property​

​​ > ​

​spi-property​

​;前者为空的情况下,会顺延;之后测试一下;

private ConfigManager() {
    // 按照优先级放好
    List<String> configLoaderNames = Arrays.asList("system-property", "simple-rpc-property", "spi-property");
    configLoaders = new ArrayList<>(configLoaderNames.size());
    for (String loaderName : configLoaderNames) {
        ConfigLoader configLoader = ExtensionLoader.getLoader(ConfigLoader.class).getExtension(loaderName);
        configLoaders.add(configLoader);
    }
}

/**
 * 获取配置项
 *
 * @param key 配置项的 key
 * @return 如果获取不到,返回 null
 */
public String loadConfigItem(String key) {
    // 按照优先级,先获取到就返回
    for (ConfigLoader configLoader : configLoaders) {
        String value = configLoader.loadConfigItem(key);
        if (value != null) {
            return value;
        }
    }
    return null;
}      

下面代码做的是缓存操作,根据各个不同的配置类,返回不同的配置类加载器:

/**
 * 加载配置,有缓存
 *
 * @param clazz 配置类型
 * @param <T>   类型
 * @return 配置实体类
 */
@SuppressWarnings("unchecked")
public <T> T loadConfig(Class<T> clazz) {
    T config = (T) configCache.get(clazz);
    if (config == null) {
        config = loadAndCreateConfig(clazz);
        configCache.put(clazz, config);
    }
    return config;
}      

继续看看实体映射:

/**
 * 加载并创建配置类
 *
 * @param clazz 类型 class
 * @param <T>   类型
 * @return 配置类,load 不到的字段为 null
 */
private <T> T loadAndCreateConfig(Class<T> clazz) {
    SimpleRpcConfig configAnnotation = clazz.getAnnotation(SimpleRpcConfig.class);
    if (configAnnotation == null) {
        throw new IllegalStateException("config class " + clazz.getName() + " must has @SimpleRpcConfig annotation");
    }
    String prefix = configAnnotation.prefix();
    if (StrUtil.isBlank(prefix)) {
        throw new IllegalArgumentException("config class " + clazz.getName() + "@SimpleRpcConfig annotation must has prefix");
    }
    try {
        T configObject = clazz.newInstance();
        for (Field field : clazz.getDeclaredFields()) {
            // 忽略掉静态的
            if (Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            String configKey = prefix + "." + field.getName();
            String value = loadConfigItem(configKey);
            if (value == null) {
                continue;
            }

            Object convertedValue = Convert.convert(field.getType(), value);
            field.setAccessible(true);
            field.set(configObject, convertedValue);
        }
        return configObject;
    } catch (Exception e) {
        throw new IllegalStateException(e.getMessage(), e);
    }
}      

这里简单讲下几个操作:

  • 首先是拿到类上的​

    ​@SimpleRpcConfig​

    ​注解,用来判断其是否是配置类;
  • 拿到前缀信息,跟springboot里面的配置类使用方式其实差不多;
  • 创建该配置类,然后遍历其所有的字段,通过​

    ​String value = loadConfigItem(configKey);​

    ​获取从配置文件里面加载出来的key进行赋值;

整个逻辑就差不多是这样;

配置加载器测试

加载测试

在​

​resources​

​​资源目录下面新建文件​

​simple-rpc.properties​

​:

simple.rpc.register.address=redis://127.0.0.1:6379
simple.rpc.register.password=123456
simple.rpc.base.loadBalanceRule=round
simple.rpc.base.timeout=10000
simple.rpc.base.retryNum=10      

然后新建测试类:​

​com.simple.rpc.test.ConfigLoaderTest​

​,然后测试注册信息的获取:

@Test
public void getRegistryConfig() {
    RegistryConfig registryConfig = ConfigManager.getInstant().loadConfig(RegistryConfig.class);
    System.out.println(registryConfig);
    System.out.println(SimpleRpcUrl.toSimpleRpcUrl(registryConfig));
}      

可以看看结果:

手写RPC-具体实现细节详解(近4w字详情)

目前是已经加载到了信息;

加载器的优先级

@Test
public void getBaseConfig() {
    BaseConfig baseConfig = ConfigManager.getInstant().loadConfig(BaseConfig.class);
    System.out.println(baseConfig);
}      

看看结果:

手写RPC-具体实现细节详解(近4w字详情)

前面代码里面也看到过,​

​SPI​

​的优先级最低(这里先不关心SPI实现),所以最先加载,下面的输出就是加载RPC配置文件的结果,所以优先级就是如此,但是后面整合springboot的时候,几乎都是使用其默认的配置类,但是这个也可以作为一个扩展点,灵活性更高;

负载均衡

思路分析

simple-rpc的负载思路是:从注册中心区注册信息的那里进行处理,从多个url中选取一个,同样是抽象类里面做大部分逻辑:

public abstract class AbstractLoadBalance implements SimpleRpcLoadBalance {

    @Override
    public String loadBalance(Map<String, String> services) {
        if (CollectionUtil.isEmpty(services)) {
            return null;
        }
        Map<String, LoadBalanceParam> selectMap = new ConcurrentHashMap<>(4);
        Set<String> urls = services.keySet();
        for (String url : urls) {
            Request request = JSON.parseObject(services.get(url), Request.class);
            LoadBalanceParam param = new LoadBalanceParam();
            param.setWeights(request.getWeights());
            selectMap.put(url, param);
        }
        String selectUrl = select(selectMap);
        return services.get(selectUrl);
    }

    /**
     * 实际的负载算法
     *
     * @param urls
     * @return
     */
    public abstract String select(Map<String, LoadBalanceParam> urls);
}      

简单解析一下方法:

  • 这里是去构建一个​

    ​key = 127.0.0.1_41201,value= new LoadBalanceParam;​

    ​的map,这里是想着之后可以扩展,一些负载算法需要依赖于一些值,都可以放到后面这个实体里面传给对应的SPI实现;
  • 然后关于url的选择,则由具体的算法实现;下面就简单分析一下几种不同的负载策略;

负载算法分析

这里的话,简单的就不分析了,也就是​

​round​

​​、​

​random​

​;

轮询权重算法

算法思路:就是根据权重拉长范围,然后根据随机数分配,想当于权重大的被随机到的概率也越大;

public class RandomWeightRule extends AbstractLoadBalance {

    /**
     * 权重:
     * - 127.0.0.1_8881: 20
     * - 127.0.0.1_8882: 30
     * - 127.0.0.1_8883: 50
     * <p>
     * 拉长范围:|0 -- 127.0.0.1_8881 -- 20 | 21 —- 127.0.0.1_8882 -- 50 | 51 -- 127.0.0.1_8883 -- 100|
     *
     * @param urls
     * @return
     */
    @Override
    public String select(Map<String, LoadBalanceParam> urls) {
        int range = 0;
        for (LoadBalanceParam value : urls.values()) {
            range += value.getWeights();
        }
        int index = RandomUtil.randomInt(range);
        for (String url : urls.keySet()) {
            LoadBalanceParam param = urls.get(url);
            Integer weights = param.getWeights();
            if (index < weights) {
                return url;
            }
            index -= weights;
        }
        return "";
    }
}      

这里分析几个案例(注释里面的信息):

  • 随机数10,然后遍历到的​

    ​127.0.0.1_8881: 20​

    ​​,这个时候​

    ​index(10) < weights(20)​

    ​符合直接返回;
  • 随机数40,第一次遍历就是上面情况,不符合,然后计算​

    ​index(40) -= weights(20)​

    ​​ 这个时候开始第二轮,此时​

    ​index = 40 - 20 = 20​

    ​​,20是第一次url的权重,然后再次判断,​

    ​index(20) < weights(30)​

    ​​符合;证明轮询到了​

    ​127.0.0.1_8882: 30​

    ​;
  • 随机数70,根据上面两种情况,这里相减两次,最后符合条件已经是第三轮轮询,返回的则是​

    ​127.0.0.1_8883: 50​

普通的加权轮询算法

负载思路:其实这个也简单,就是求对应几个服务权重的最小公约数,然后根据用 权重 / 最小公约数 = 轮询次数;这里就是比较粗暴,现在前面的机器如果权重过大,可能大部分负载都压到该机器;如果权重是321这里的效果就是​

​AAABBC​

​;

public class NormalWeightRoundRule extends AbstractLoadBalance {

    private static Map<String, Integer> urlMap = new ConcurrentHashMap<>(4);

    /**
     * 权重:
     * - 127.0.0.1_8881: 20
     * - 127.0.0.1_8882: 30
     * - 127.0.0.1_8883: 50
     * <p>
     * 找到最小公约数:转换成 url : 20 / 10 = 2 这种数据格式,然后遍历map,在减去对应的次数
     *
     * @param urls
     * @return
     */
    @Override
    public String select(Map<String, LoadBalanceParam> urls) {
        List<Integer> weights = new ArrayList<>(10);
        for (String url : urls.keySet()) {
            weights.add(urls.get(url).getWeights());
        }
        int maxGys = ngcd(weights, weights.size());
        if (CollectionUtil.isEmpty(urlMap)) {
            // 将权重除以最小公约数
            urls.forEach((url, param) -> urlMap.put(url, param.getWeights() / maxGys));
        }
        for (String url : urlMap.keySet()) {
            Integer num = urlMap.get(url);
            if (num > 0) {
                urlMap.put(url, --num);
                return url;
            } else {
                urlMap.remove(url);
                if (CollectionUtil.isEmpty(urlMap)) {
                    // 将权重除以最小公约数
                    urls.forEach((url1, param) -> urlMap.put(url1, param.getWeights() / maxGys));
                    ArrayList<String> urlFirst = new ArrayList<>(urlMap.keySet());
                    Integer numFirst = urlMap.get(urlFirst.get(0));
                    urlMap.put(urlFirst.get(0), --numFirst);
                    return urlFirst.get(0);
                }
            }
        }
        return "";
    }

    public static int gcd(int x, int y) {
        return y == 0 ? x : gcd(y, x % y);
    }

    public static int ngcd(List<Integer> target, int z) {
        if (z == 1) {
            //真正返回的最大公约数
            return target.get(0);
        }
        //递归调用,两个数两个数的求
        return gcd(target.get(z - 1), ngcd(target, z - 1));
    }
}      

这里其实就两点:

  • 首先是通过递归求出权重的最小公约数
  • 拿到最小公约数后开始构建​

    ​key = url,value = 次数​

    ​的map,然后用于遍历,每次数量减1,全部轮询结束后重置缓存;

平滑加权轮询

负载思路:这个算法的话,其实是保证在一个周期内服务根据权重被轮询到的总次数是一致的,但是在周期内是随机的,跟上面一样,如果权重是321,那么这里的效果就是​

​ABACBA​

​;

public class SmoothWeightRoundRule extends AbstractLoadBalance {

    private static Map<String, ServiceWeight> weightPathMap = new ConcurrentHashMap<>(4);

    /**
     * 权重:
     * - 127.0.0.1_8881: 3
     * - 127.0.0.1_8882: 2
     * - 127.0.0.1_8883: 1
     * <p>
     * |------------|---------------|------------|----------------|--------------------------------|
     * | requestNum | currentWeight | currentMax |      url       | max(currentWeight)-totalWeight |
     * |------------|---------------|------------|----------------|--------------------------------|
     * |     1      |     3,2,1     |     3      | 127.0.0.1_8881 |             -3,2,1             |
     * |------------|---------------|------------|----------------|--------------------------------|
     * |     2      |     0,2,1     |     2      | 127.0.0.1_8882 |             0,-4,1             |
     * |------------|---------------|------------|----------------|--------------------------------|
     * |     3      |    3,-2,1     |     3      | 127.0.0.1_8881 |             -3,-2,1            |
     * |------------|---------------|------------|----------------|--------------------------------|
     * |     4      |     0,0,2     |     2      | 127.0.0.1_8883 |              0,0,-5            |
     * |------------|---------------|------------|----------------|--------------------------------|
     * |     5      |     3,2,-4    |     3      | 127.0.0.1_8881 |             -3,2,-4            |
     * |------------|---------------|------------|----------------|--------------------------------|
     * |     6      |     0,4,-3    |     4      | 127.0.0.1_8882 |              0,-2,-3           |
     * |------------|---------------|------------|----------------|--------------------------------|
     * <p>
     * - 每次从当前权重中选出最大权重代表的服务器作为返回结果
     * - 选择完具体服务器后,把当前最大权重减去权重总和,再把所有权重都跟初始权重(3,2,1)相加
     * - 3,2,1,选出3代表的127.0.0.1_8881服务器,然后减去6成为 -3,2,1,再加上3,2,1,成为0,2,1,下次的当前权重即为0,2,1
     *
     * @param urls
     * @return
     */
    @Override
    public String select(Map<String, LoadBalanceParam> urls) {
        return getServerPath(urls);
    }

    /**
     * 每个url对应一个权重
     *
     * @param urls
     * @return
     */
    public static String getServerPath(Map<String, LoadBalanceParam> urls) {
        // 拿到总权重
        int range = 0;
        for (LoadBalanceParam value : urls.values()) {
            range += value.getWeights();
        }

        // 动态权重Map为空,那么初始化端口权重map
        if (weightPathMap.isEmpty()) {
            urls.forEach((url, param) -> weightPathMap.put(url, new ServiceWeight(url, param.getWeights(), 0)));
        }
        // 当前权重 + 初始权重 = 下一轮的权重
        for (ServiceWeight serviceWeight : weightPathMap.values()) {
            serviceWeight.curWeight += serviceWeight.weight;
        }

        // 将最大的权重赋值
        ServiceWeight maxCurWeight = null;
        for (ServiceWeight weight : weightPathMap.values()) {
            // 找最大的权重值
            if (maxCurWeight == null || weight.curWeight > maxCurWeight.curWeight) {
                maxCurWeight = weight;
            }
        }
        // 减去总权重
        assert maxCurWeight != null;
        maxCurWeight.curWeight -= range;
        // 拿到端口
        return maxCurWeight.url;
    }
}      

大致就是下面几个步骤:

  • 每次从当前权重中选出最大权重代表的服务器作为返回结果
  • 选择完具体服务器后,把当前最大权重减去权重总和,再把所有权重都跟初始权重(3,2,1)相加
  • 3,2,1,选出3代表的127.0.0.1_8881服务器,然后减去6成为 -3,2,1,再加上3,2,1,成为0,2,1,下次的当前权重即为0,2,1

测试的话代码里面有,这就不粘贴上来了,然后负载均衡也是支持SPI机制的,作用于从注册中心获取注册信息的时候,会根据url进行选取;

压缩、序列化

这里的话,压缩是默认没有使用的,也就是直接依赖于序列化工具,但是rpc-core里面有实现​

​GZIP​

​​;序列化默认是使用​

​protostuff​

​,这个序列化性能还是不错的;

这里简单提下,直接实现对应的接口,就可以通过SPI机制加载指定的压缩工具或者序列化工具:

压缩

@SimpleRpcSPI(value = "default")
public interface Compressor {

    byte[] compress(byte[] bytes);

    byte[] decompress(byte[] bytes);
}      

序列化

@SimpleRpcSPI(value = "protostuff")
public interface Serializer {

    byte[] serialize(Object object);

    <T> T deserialize(byte[] bytes, Class<T> clazz);
}      

在​

​simple-rpc-test-common​

​​模块里面有对应的SPI实现:​

​com.simple.rpc.test.common.starter.spi.JSONSerializer​

RPC协议和通讯

这里大部分都算是netty的东西,先从服务端、客户端的启动说说,然后在到代理反射,最后就是双端通讯过程的编解码和handler处理;

netty双端

服务端和客户端的设计理念都是实现​

​Runnable​

​,然后核心就在run方法里面,先说说client端的启动:

public RpcServerSocket(Request request) {
    this.request = request;
}      

这里构建request,里面包含心跳时间,服务端的host和port,然后就是启动client;run方法里面的核心内容:

// 选用客户端的启动器
Bootstrap b = new Bootstrap();
b.group(workerGroup);
b.channel(NioSocketChannel.class);
b.option(ChannelOption.AUTO_READ, true);
b.option(ChannelOption.SO_KEEPALIVE, Boolean.TRUE);
b.option(ChannelOption.TCP_NODELAY, Boolean.TRUE);
b.handler(new ChannelInitializer<SocketChannel>() {...});
// 进行服务器连接
ChannelFuture f = b.connect(host, port).sync();
this.future = f;
f.channel().closeFuture().sync();      

关于那些参数什么意思,等下说server端的时候再讲,这里面比较重要的一个地方​

​this.future = f;​

​​这里的话在调用client启动的时候,会轮询去判断channel是否已经创建,这里也就是client端的重试次数,这里也仅仅用来判断连接是否成功,然后每次会讲此次的​

​channelFuture​

​缓存起来;

服务端的话,直接看代码:

@Override
public void run() {
    Long stopConnectTime = Objects.isNull(request.getStopConnectTime()) || request.getStopConnectTime() <= 0 ?
                30 : request.getStopConnectTime();
    // 这里如果自己不指定线程数,默认是当前cpu的两倍
    EventLoopGroup dealConnGroup = new NioEventLoopGroup(128);
    EventLoopGroup workerGroup = new NioEventLoopGroup();
    try {
        ServerBootstrap bootstrap = new ServerBootstrap();
        bootstrap.group(dealConnGroup, workerGroup)
                .channel(NioServerSocketChannel.class)
                // 系统用于临时存放已完成三次握手的请求的队列的最大长度。如果连接建立频繁,服务器处理创建新连接较慢,可以适当调大这个参数
                .option(ChannelOption.SO_BACKLOG, 128)
                // 程序进程非正常退出,内核需要一定的时间才能够释放此端口,不设置 SO_REUSEADDR 就无法正常使用该端口。
                .option(ChannelOption.SO_REUSEADDR, Boolean.TRUE)
                // TCP/IP协议中针对TCP默认开启了Nagle 算法。
                // Nagle 算法通过减少需要传输的数据包,来优化网络。在内核实现中,数据包的发送和接受会先做缓存,分别对应于写缓存和读缓存。
                // 启动 TCP_NODELAY,就意味着禁用了 Nagle 算法,允许小包的发送。
                // 对于延时敏感型,同时数据传输量比较小的应用,开启TCP_NODELAY选项无疑是一个正确的选择
                .childOption(ChannelOption.TCP_NODELAY, Boolean.TRUE)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel ch) {...});
        // 默认启动初始端口
        int port = Objects.isNull(request.getPort()) || request.getPort() <= 0 ? 41200 : request.getPort();
        while (NetUtil.isPortUsing(port)) {
            port++;
        }
        LocalAddressInfo.LOCAL_HOST = NetUtil.getHost();
        LocalAddressInfo.PORT = port;
        //注册服务
        this.f = bootstrap.bind(port).sync();
        // 返回请求等待结果,异步监听事件
        this.f.channel().closeFuture().sync();
       ...
}      

服务端也很简单,上面注释上面对各个参数都做了很详细的解释,这里会用的​

​LocalAddressInfo​

​对象缓存服务端的启动信息,端口由自己定义,通过配置传递进来;

所以这里也比较简单,然后只需要开个线程调用即可;

通讯和RPC协议

编解码

这里先说下自定义的编解码协议,在使用handler的时候会走自己定义的编解码协议,首先看看编码:首先是​

​RpcMessageEncoder extends MessageToByteEncoder<RpcMessage>​

​:

public class RpcMessageEncoder extends MessageToByteEncoder<RpcMessage> {

    @Override
    protected void encode(ChannelHandlerContext ctx, RpcMessage rpcMessage, ByteBuf out) {
        // 2B magic code(魔数)
        out.writeBytes(MessageFormatConstant.MAGIC);
        // 1B version(版本)
        out.writeByte(MessageFormatConstant.VERSION);
        // 4B full length(消息长度). 总长度先空着,后面填。
        out.writerIndex(out.writerIndex() + MessageFormatConstant.FULL_LENGTH_LENGTH);
        // 1B messageType(消息类型)
        out.writeByte(rpcMessage.getMessageType());
        // 1B codec(序列化类型)
        out.writeByte(rpcMessage.getSerializeType());
        // 1B compress(压缩类型)
        out.writeByte(rpcMessage.getCompressTye());
        // 8B requestId(请求的Id)
        out.writeLong(rpcMessage.getRequestId());
        // 写 body,返回 body 长度
        int bodyLength = writeBody(rpcMessage, out);

        // 当前写指针
        int writerIndex = out.writerIndex();
        out.writerIndex(MessageFormatConstant.MAGIC_LENGTH + MessageFormatConstant.VERSION_LENGTH);
        // 4B full length(消息长度)
        out.writeInt(MessageFormatConstant.HEADER_LENGTH + bodyLength);
        // 写指针复原
        out.writerIndex(writerIndex);
    }
}      

这里就是按照指定的格式进行编码:

  • 2B​

    ​magic​

    ​(魔数)
  • 1B​

    ​version​

    ​(版本)
  • 4B​

    ​full length​

    ​(消息长度)
  • 1B​

    ​messageType​

    ​(消息类型)
  • 1B​

    ​serialize​

    ​(序列化类型)
  • 1B​

    ​compress​

    ​(压缩类型)
  • 8B​

    ​requestId​

    ​(请求的Id)
  • ​body​

    ​(object类型数据)
0     1     2       3    4    5    6    7           8        9        10   11   12   13   14   15   16   17   18
+-----+-----+-------+----+----+----+----+-----------+---------+--------+----+----+----+----+----+----+----+---+
|   magic   |version|    full length    |messageType|serialize|compress|              RequestId               |
+-----+-----+-------+----+----+----+----+-----------+----- ---+--------+----+----+----+----+----+----+----+---+
|                                                                                                             |
|                                         body                                                                |
|                                                                                                             |
|                                        ... ...                                                              |
+-------------------------------------------------------------------------------------------------------------+      

这里的消息长度开始只是改变指针的位置,后面写body之后会计算字节数,在设置此次的消息总长度;

然后就是​

​writeBody​

​方法:

private int writeBody(RpcMessage rpcMessage, ByteBuf out) {
    byte messageType = rpcMessage.getMessageType();
    // 如果是 ping、pong 心跳类型的,没有 body,直接返回头部长度
    if (messageType == MessageType.HEARTBEAT.getValue()) {
        return 0;
    }
    // 序列化器
    SerializeType serializeType = SerializeType.fromValue(rpcMessage.getSerializeType());
    if (serializeType == null) {
        throw new IllegalArgumentException("codec type not found");
    }
    Serializer serializer = ExtensionLoader.getLoader(Serializer.class).getExtension(serializeType.getName());
    // 压缩器
    CompressType compressType = CompressType.fromValue(rpcMessage.getCompressTye());
    if (compressType == null) {
        throw new IllegalArgumentException("compress type not found");
    }
    Compressor compressor = ExtensionLoader.getLoader(Compressor.class).getExtension(compressType.getName());
    // 序列化
    byte[] notCompressBytes = serializer.serialize(rpcMessage.getData());
    // 压缩
    byte[] compressedBytes = compressor.compress(notCompressBytes);

    // 写 body
    out.writeBytes(compressedBytes);
    return compressedBytes.length;
}      

简单解析下:

  • 首先是判断是否是心跳消息,如果是的话,就直接跳过,不写body体;
  • 然后根据传进来的序列化、压缩类型在通过SPI机制进行加载,然后先进行序列化,在进行压缩操作,这里就返回压缩后的字节数;

编码就是这些,继续看看解码操作;

public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {

    public RpcMessageDecoder() {
        /**
         * 继承定长解码器,自己完成构造数据后,每次会将一次请求的消息一起解析,不会出现粘包、拆包
         * 这里的话假设数据是(二进制不补零):52 1 0001 1 1 00000001 1
         * maxFrameLength: 8M
         * lengthFieldOffset:3
         * lengthFieldLength:4
         * lengthAdjustment:-7
         * initialBytesToStrip:0
         *
         * 这里解析比较简单 首先是长度偏移量在 3字节后面,也就是跳过了 52 1;
         * 然后长度为 0001,对应4个字节;
         * 接下来就是调整消息读取位置,又到最前面去了,然后读取的初始位置就是0,从头还是读,所以52 1 0001 1 1 00000001 1都读取了
         */
        super(
                // 最大的长度,如果超过,会直接丢弃
                MAX_FRAME_LENGTH,
                // 描述长度的字段[4B full length(消息长度)]在哪个位置:在 [2B magic(魔数)]、[1B version(版本)] 后面
                MAGIC_LENGTH + VERSION_LENGTH,
                // 描述长度的字段[4B full length(消息长度)]本身的长度,也就是 4B 啦
                FULL_LENGTH_LENGTH,
                // LengthFieldBasedFrameDecoder 拿到消息长度之后,还会加上 [4B full length(消息长度)] 字段前面的长度
                // 因为我们的消息长度包含了这部分了,所以需要减回去
                -(MAGIC_LENGTH + VERSION_LENGTH + FULL_LENGTH_LENGTH),
                // initialBytesToStrip: 去除哪个位置前面的数据。因为我们还需要检测 魔数 和 版本号,所以不能去除
                0);
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        Object decoded = super.decode(ctx, in);
        if (decoded instanceof ByteBuf) {
            ByteBuf frame = (ByteBuf) decoded;
            if (frame.readableBytes() >= HEADER_LENGTH) {
                try {
                    return decodeFrame(frame);
                } catch (Exception ex) {
                    SimpleRpcLog.error("Decode frame error.", ex);
                } finally {
                    frame.release();
                }
            }
        }
        return decoded;
    }
}      

这里就涉及到了netty的定长解码器,这里的话不多说,注释里面写的挺详细的;可以看看文章:

​​https://zhuanlan.zhihu.com/p/95621344​​

继承了定长解码器,我们只需要关心我们每次接收的信息了,从版本开始解析,然后会做个消息的长度校验操作:​

​frame.readableBytes() >= HEADER_LENGTH​

​​;接下来就是解码:​

​decodeFrame(frame);​

readAndCheckMagic(in);
readAndCheckVersion(in);
int fullLength = in.readInt();
byte messageType = in.readByte();
byte codec = in.readByte();
byte compress = in.readByte();
long requestId = in.readLong();
...
// 是心跳则直接返回
if (messageType == MessageType.HEARTBEAT.getValue()) {
    return rpcMessage;
}
// 计算读取长度
int bodyLength = fullLength - HEADER_LENGTH;
if (bodyLength == 0) {
    return rpcMessage;
}
// 读取数据
byte[] bodyBytes = new byte[bodyLength];
in.readBytes(bodyBytes);      

这里的话粘贴了核心代码,拿到了消息体,后面就是判断序列化、解压类型了,然后就是逆流程了;

整个编解码协议就是这些了;

通讯工具类

就4个类,完成消息的发送:

├── send
│   ├── SyncWrite.java
│   ├── SyncWriteFuture.java
│   ├── SyncWriteMap.java
│   └── WriteFuture.java      
  • ​WriteFuture​

    ​​:这个类是一个异步接口,继承于java并发包里面的​

    ​Future​

    ​;除了其自带的一些方法,扩展了是否写成功、设置写结果、设置响应值、此次请求id等方法;
  • ​SyncWriteMap​

    ​​:这个其实就是个缓存,讲此次请求产生的异步类​

    ​WriteFuture​

    ​缓存起来,用请求id作为key,每次请求结束后删除此次请求操作;
  • ​SyncWriteFuture​

    ​​:这个类其实可以看看,也就是​

    ​WriteFuture​

    ​的实现类;
  • 是否超时,这里根据​

    ​isTimeout​

    ​方法进行判断;
  • ​get​

    ​获取响应的时候,会去尝试拿到锁;
private CountDownLatch latch = new CountDownLatch(1);

if (latch.await(timeout, unit)) {
    return response;
}      
上面会发生锁等待,然后在处理完结果后,写回响应的时候,会去释放锁,然后就能获取到响应值      
public void setResponse(Response response) {
    this.response = response;
    // 设置好了响应值之后,这里释放锁
    latch.countDown();
}      
然后其他的就没什么了,都是一些基本的方法;      
  • ​SyncWrite​

    ​​:这个方法里面就是封装了一层,其实底层就是​

    ​channel.writeAndFlush​

    ​写消息操作;这里大致讲下其流程;

    可以看看代码来说:

public Response writeAndSync(Channel channel, Request request, Long timeout) throws Exception {
    if (channel == null) {throw new NettyInitException("channel is null, please init channel");}
    if (request == null) {throw new NullPointerException("request");}
    if (timeout <= 0) {timeout = 30L;}
    long requestId = MessageFormatConstant.REQUEST_ID.incrementAndGet();
    request.setRequestId(requestId);
    // 记录此次请求id,并放入到缓存中
    WriteFuture<Response> future = new SyncWriteFuture(request.getRequestId());
    SyncWriteMap.syncKey.put(request.getRequestId(), future);
    // 构建请求数据
    RpcMessage rpcMessage = new RpcMessage();
    rpcMessage.setMessageType(MessageType.REQUEST.getValue());
    rpcMessage.setRequestId(requestId);
    byte serializer = SerializeType.fromName(request.getSerializer()).getValue();
    rpcMessage.setSerializeType(serializer);
    byte compressor = CompressType.fromName(request.getCompressor()).getValue();
    rpcMessage.setCompressTye(compressor);
    rpcMessage.setData(request);
    // 同步写数据
    Response response = doWriteAndSync(channel, rpcMessage, timeout, future);
    // 拿到响应值后,此前请求结束,那么可以移除此次请求
    SyncWriteMap.syncKey.remove(request.getRequestId());
    return response;
}      
这里大致分为几个步骤:

- 参数校验
- 通过原子自增获取请求id,然后构建此次请求的`WriteFuture`,并存入到缓存中;
- 然后就是构建`RpcMessage`,这个是Rpc通信协议的实体,这里构建了`序列化类型`、`压缩类型`、`请求id`等,然后将请求数据写入,然后标记此次消息类型为`请求类型`;
- 封装的写操作:`doWriteAndSync`,真正发送在这里,假设发送成功则会得到结果,然后将此次请求缓存的数据清楚;

然后看看`doWriteAndSync`方法:      
private Response doWriteAndSync(final Channel channel, final RpcMessage rpcMessage, final long timeout, final WriteFuture<Response> writeFuture) throws Exception {
    // 这里就不用lambda了,这里就是在channel写出一条数据之后,可以为其添加一个监听时间,也即操作完之后的一个回调方法
    // 每个 Netty 的出站 I/O 操作都将返回一个 ChannelFuture
    channel.writeAndFlush(rpcMessage).addListener(new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) {
            // 设置此次请求的状态
            writeFuture.setWriteResult(future.isSuccess());
            // 如果失败,此次的原因
            writeFuture.setCause(future.cause());
            // 失败移除此次请求
            if (!writeFuture.isWriteSuccess()) {
                SyncWriteMap.syncKey.remove(writeFuture.requestId());
            }
        }
    });

    // 请求完成之后,这里会去模拟等待,get的时候是无法去拿到资源的,这里设置一个等待事件
    Response response = writeFuture.get(timeout, TimeUnit.SECONDS);
    if (response == null) {
        // 已经超时则抛出异常
        if (writeFuture.isTimeout()) {throw new TimeoutException();
        ...
    }
    // 否则返回响应,此次类似 feign的调用,等到请求,过了一段时间还没有拿到结果,则抛出超时异常,否则成功
    return response;
}      
这里同样简单分析一下:

- 首先是真正的写操作`channel.writeAndFlush(rpcMessage)`;然后会为写操作做个监听,设置一些状态信息;
- 然后就是等待获取请求的结果,这里get等待;这里需要结合两个handler看,client-handler里面获取服务端的响应后会去设置,下面跟着心跳检测看看两个handler的内容;      

双handler和心跳

看看心跳处理,在启动双端的时候:

client的处理:

b.handler(new ChannelInitializer<SocketChannel>() {
    @Override
    public void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline().addLast(
                // 设定 IdleStateHandler 心跳检测每 5 秒进行一次写检测
                // write()方法超过 5 秒没调用,就调用 userEventTrigger
                new IdleStateHandler(0, beatTime, 0, TimeUnit.SECONDS),
                new RpcMessageDecoder(),
                new RpcMessageEncoder(),
                new ClientSocketHandler());
    }
});      

client-handler里面:

public class ClientSocketHandler extends SimpleChannelInboundHandler<RpcMessage> {

    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcMessage rpcMessage) throws Exception {
        // 拿到响应值
        Response msg = (Response) rpcMessage.getData();
        long requestId = rpcMessage.getRequestId();
        // 拿到此次请求的id,对应的缓存信息
        SyncWriteFuture future = (SyncWriteFuture) SyncWriteMap.syncKey.get(requestId);
        // 这里拿到了结果,就设置响应值
        if (future != null) {
            future.setResponse(msg);
        }
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            // 心跳
            IdleState state = ((IdleStateEvent) evt).state();
            if (state == IdleState.WRITER_IDLE) {
                SimpleRpcLog.info("write idle happen [{}]", ctx.channel().remoteAddress());
                Channel channel = ctx.channel();
                RpcMessage rpcMessage = new RpcMessage();
                rpcMessage.setSerializeType(SerializeType.PROTOSTUFF.getValue());
                rpcMessage.setCompressTye(CompressType.GZIP.getValue());
                rpcMessage.setMessageType(MessageType.HEARTBEAT.getValue());
                channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }
}      
  • ​new IdleStateHandler​

    ​​:这里启动的时候进行心跳连接,根据传进来的心跳时间进行判断,如果指定时间内没有写操作,则调用​

    ​userEventTriggered​

    ​;这里会去发送心跳消息维持连接;
  • 使用自定义的编解码,这里前面单独讲了;
  • 客户端的​

    ​channelRead0​

    ​​里面会设置响应,然后释放锁,同时上面讲的工具类里面的​

    ​get​

    ​方法就能拿到对应的结果,最终返回;

server的处理:

.childHandler(new ChannelInitializer<SocketChannel>() {
    @Override
    public void initChannel(SocketChannel ch) {
        ch.pipeline().addLast(
                // 30 秒之内没有收到客户端请求的话就关闭连接
                new IdleStateHandler(stopConnectTime, 0, 0, TimeUnit.SECONDS),
                new RpcMessageDecoder(),
                new RpcMessageEncoder(),
                new ServerSocketHandler());
    }
});      

server-handler的处理:

public class ServerSocketHandler extends SimpleChannelInboundHandler<RpcMessage> {

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcMessage rpcMessage) throws Exception {
        try {
            // 不理心跳消息
            if (rpcMessage.getMessageType() != MessageType.REQUEST.getValue()) {
                return;
            }
            // 拿到请求参数
            Request msg = (Request) rpcMessage.getData();
            //调用
            Class<?> classType = ClassLoaderUtils.forName(msg.getInterfaceName());
            Method addMethod = classType.getMethod(msg.getMethodName(), msg.getParamTypes());
            // 从缓存中里面获取bean信息
            Object objectBean = SimpleRpcServiceCache.getService(msg.getAlias());
            // 进行反射调用
            Object result = addMethod.invoke(objectBean, msg.getArgs());
            //反馈
            Response response = new Response();
            response.setRequestId(rpcMessage.getRequestId());
            response.setResult(result);
            // 构建返回值
            RpcMessage responseRpcMsg = RpcMessage.copy(rpcMessage);
            responseRpcMsg.setMessageType(MessageType.RESPONSE.getValue());
            responseRpcMsg.setData(response);
            ctx.writeAndFlush(responseRpcMsg);
            //释放
            ReferenceCountUtil.release(msg);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        // 处理空闲状态的
        if (evt instanceof IdleStateEvent) {
            IdleState state = ((IdleStateEvent) evt).state();
            if (state == IdleState.READER_IDLE) {
                SimpleRpcLog.info("idle check happen, so close the connection");
                ctx.close();
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }
}      
  • 服务端维持心跳的原理是​

    ​stopConnectTime​

    ​​对应时间内没有读到数据,那么会跟客户端断开连接,这里的话是可以进行配置的,如果没有读区到,在server-handler里面的​

    ​userEventTriggered​

    ​​会去断开连接​

    ​ctx.close()​

    ​;
  • ​channelRead0​

    ​里面做的事情就是拿到请求参数,然后从服务缓存中拿到bean,然后进行反射调用,拿到调用结果之后,在写出;这里就不需要做等待处理,直接入站经过client的handler;

代理整合发送功能

对外提供一个代理类​

​RpcProxy​

​,一般服务整合的过程中,需要用到指定接口的,可以在bean注入的时候选择用代理创建:

public class RpcProxy {
    public static <T> T invoke(Class<T> interfaceClass, CommonConfig config) {
        InvocationHandler handler = new RpcInvocationHandler(config);
        ClassLoader classLoader = ClassLoaderUtils.getCurrentClassLoader();
        T result = (T) Proxy.newProxyInstance(classLoader, new Class[]{interfaceClass}, handler);
        return result;
    }
}      

这里面的​

​CommonConfig​

​​是一些公共配置,主要的处理逻辑是在​

​RpcInvocationHandler​

​​里面,然后里面主要就两个​

​invoke​

​​和​

​connect​

private void connect(Request request) {
    channelFuture = ConnectCache.getChannelFuture(request);
    if (this.channelFuture != null && this.channelFuture.channel().isOpen()) {
        return;
    }
    synchronized (this) {
        if (this.channelFuture != null && this.channelFuture.channel().isOpen()) {
            return;
        }
        //获取通信channel
        if (null == this.channelFuture) {
            RpcClientSocket clientSocket = new RpcClientSocket(request);
            executorService.submit(clientSocket);
            int tryNum = Objects.isNull(request.getRetryNum()) || request.getRetryNum() <= 0 ? 100 : request.getRetryNum();
            for (int i = 0; i < tryNum; i++) {
                if (null != this.channelFuture) {break;}
                try {Thread.sleep(500);} catch (InterruptedException e) {e.printStackTrace();}
                this.channelFuture = clientSocket.getFuture();
            }
        }
        if (null == this.channelFuture) {
            throw new NettyInitException("客户端未连接上服务端,考虑增加重试次数");
        }
        request.setChannelFuture(channelFuture);
        ConnectCache.saveChannelFuture(request);
    }
}      

这里有几点:

  • 首先尝试从连接缓存中拿到对应的连接,如果存在且通道是开启的,那么不需要重新连接;
  • 反之需要从注册中心拿到对应服务器的信息,进行连接,并保存此次连接;
  • 这里存在一个配置,因为客户端连接服务端是异步连接操作,这里会根据​

    ​channelFuture​

    ​​是否返回来判断是否连接成功,所以这里​

    ​tryNum​

    ​来重试判断是否连接成功;
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // 连接客户端
    connect(buildRequest());
    String methodName = method.getName();
    Class[] paramTypes = method.getParameterTypes();
    // 排除Object的方法调用
    if (JavaKeywordConstant.TO_STRING.equals(methodName) && paramTypes.length == 0) {
        return this.toString();
    } else if (JavaKeywordConstant.HASHCODE.equals(methodName) && paramTypes.length == 0) {
        return this.hashCode();
    } else if (JavaKeywordConstant.EQUALS.equals(methodName) && paramTypes.length == 1) {
        return this.equals(args[0]);
    }
    Request request = new Request();
    ConsumerConfig consumerConfig = commonConfig.getConsumerConfig();
    BaseConfig baseConfig = commonConfig.getBaseConfig();
    //设置参数
    request.setMethodName(methodName);
    request.setParamTypes(paramTypes);
    request.setArgs(args);
    request.setBeanName(consumerConfig.getBeanName());
    request.setInterfaceName(consumerConfig.getInterfaceName());
    request.setChannel(channelFuture.channel());
    request.setAlias(consumerConfig.getAlias());
    request.setSerializer(baseConfig.getSerializer());
    request.setRegister(baseConfig.getRegister());
    request.setCompressor(baseConfig.getCompressor());
    request.setTimeout(baseConfig.getTimeout());
    // 发送请求
    Response response = null;
    try {
        response = new SyncWrite().writeAndSync(request.getChannel(), request,
                Objects.isNull(request.getTimeout()) ? 30L : request.getTimeout());
    } catch (Exception e) {
        e.printStackTrace();
        SimpleRpcLog.error(e.getMessage());
    }

    //异步调用
    return response.getResult();
}      
  • 首先是建立连接操作;
  • 然后是排除Object的方法调用;
  • 然后就是构建参数进行写操作:​

    ​new SyncWrite().writeAndSync​

    ​​;拿到​

    ​response​

    ​则返回结果,完成此次的调用;