Hand-Writing an RPC Framework (Part 2): Client-Server Communication

1. Netty Network Communication

This part covers the responsibilities of Dubbo's Protocol, Transport, Exchange and Serialize layers. In this project, all of them are implemented on top of Netty.

Source code: GitHub repository

2. Request and Response Entities

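This article defines three minimal entity classes:
● RpcMessage: the message envelope sent over the wire (wraps a request, a response, or a heartbeat)
● RpcRequest: an RPC request
● RpcResponse: an RPC response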

package github.javaguide.remoting.dto;


import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.ToString;

/**
 * @author wangtao
 * @createTime 2020-10-02 12:33
 */
@AllArgsConstructor
@NoArgsConstructor
@Getter
@Setter
@Builder
@ToString
public class RpcMessage {

    /**
     * rpc message type
     */
    private byte messageType;
    /**
     * serialization type
     */
    private byte codec;
    /**
     * compress type
     */
    private byte compress;
    /**
     * request id
     */
    private int requestId;
    /**
     * request data
     */
    private Object data;

}

● The messageType field above takes one of four values:
○ REQUEST_TYPE: RPC request, code 1
○ RESPONSE_TYPE: RPC response, code 2
○ HEARTBEAT_REQUEST_TYPE: heartbeat request, code 3
○ HEARTBEAT_RESPONSE_TYPE: heartbeat response, code 4
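The four codes are plain byte constants. The sketch below shows one way to declare them and to tell heartbeat frames (which carry no serialized body) apart from business frames. The class name MessageTypes and its helpers are hypothetical; in the project these constants live in RpcConstants, which this article does not show.

```java
// Hypothetical holder for the four messageType codes listed above.
// In the project they live in RpcConstants (not shown in this article).
final class MessageTypes {
    static final byte REQUEST_TYPE = 1;
    static final byte RESPONSE_TYPE = 2;
    static final byte HEARTBEAT_REQUEST_TYPE = 3;
    static final byte HEARTBEAT_RESPONSE_TYPE = 4;

    private MessageTypes() {
    }

    // Heartbeat frames have no body, so the codec can skip (de)serialization for them.
    static boolean isHeartbeat(byte messageType) {
        return messageType == HEARTBEAT_REQUEST_TYPE || messageType == HEARTBEAT_RESPONSE_TYPE;
    }

    // Human-readable name, handy in log output.
    static String nameOf(byte messageType) {
        switch (messageType) {
            case REQUEST_TYPE:
                return "REQUEST";
            case RESPONSE_TYPE:
                return "RESPONSE";
            case HEARTBEAT_REQUEST_TYPE:
                return "HEARTBEAT_REQUEST";
            case HEARTBEAT_RESPONSE_TYPE:
                return "HEARTBEAT_RESPONSE";
            default:
                throw new IllegalArgumentException("unknown messageType: " + messageType);
        }
    }
}
```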

package github.javaguide.remoting.dto;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;

import java.io.Serializable;

/**
 * @author shuang.kou
 * @createTime 2020-05-10 08:24:00
 */
@AllArgsConstructor
@NoArgsConstructor
@Getter
@Builder
@ToString
public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1905122041950251207L;
    private String requestId;
    private String interfaceName;
    private String methodName;
    private Object[] parameters;
    private Class<?>[] paramTypes;
    private String version;
    private String group;

    public String getRpcServiceName() {
        return this.getInterfaceName() + this.getGroup() + this.getVersion();
    }
}


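● RpcRequest carries version and group, so the provider can match a request to the right group and version of an interface implementation.
● An RPC request naturally also contains a request id (a unique identifier) and a description of the method to call: interface name, method name, arguments, and parameter types.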
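getRpcServiceName() simply concatenates interfaceName + group + version, and that string is what the provider looks services up by. A minimal sketch with hypothetical service names shows how two implementations of one interface are told apart by group; it also shows that concatenating without a separator can, in principle, make different (group, version) pairs produce the same key. A separator would be one possible refinement; that is an observation, not what the project does.

```java
import java.util.HashMap;
import java.util.Map;

final class ServiceKeyDemo {
    // Same rule as RpcRequest#getRpcServiceName(): interfaceName + group + version.
    static String rpcServiceName(String interfaceName, String group, String version) {
        return interfaceName + group + version;
    }

    public static void main(String[] args) {
        // Hypothetical registry: two implementations of one interface, distinguished by group.
        Map<String, String> providers = new HashMap<>();
        providers.put(rpcServiceName("demo.HelloService", "test1", "version1"), "HelloServiceImpl");
        providers.put(rpcServiceName("demo.HelloService", "test2", "version1"), "HelloServiceImpl2");

        // A request for group "test2" / version "version1" resolves to the second implementation.
        System.out.println(providers.get(rpcServiceName("demo.HelloService", "test2", "version1"))); // HelloServiceImpl2
        // Without a separator, ("a", "bc") and ("ab", "c") as (group, version) give the same key.
        System.out.println(rpcServiceName("I", "a", "bc").equals(rpcServiceName("I", "ab", "c"))); // true
    }
}
```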

package github.javaguide.remoting.dto;

import github.javaguide.enums.RpcResponseCodeEnum;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.ToString;

import java.io.Serializable;

/**
 * @author shuang.kou
 * @createTime 2020-05-12 16:15:00
 */
@AllArgsConstructor
@NoArgsConstructor
@Getter
@Setter
@Builder
@ToString
public class RpcResponse<T> implements Serializable {

    private static final long serialVersionUID = 715745410605631233L;
    private String requestId;
    /**
     * response code
     */
    private Integer code;
    /**
     * response message
     */
    private String message;
    /**
     * response body
     */
    private T data;

    public static <T> RpcResponse<T> success(T data, String requestId) {
        RpcResponse<T> response = new RpcResponse<>();
        response.setCode(RpcResponseCodeEnum.SUCCESS.getCode());
        response.setMessage(RpcResponseCodeEnum.SUCCESS.getMessage());
        response.setRequestId(requestId);
        if (null != data) {
            response.setData(data);
        }
        return response;
    }

    public static <T> RpcResponse<T> fail(RpcResponseCodeEnum rpcResponseCodeEnum) {
        RpcResponse<T> response = new RpcResponse<>();
        response.setCode(rpcResponseCodeEnum.getCode());
        response.setMessage(rpcResponseCodeEnum.getMessage());
        return response;
    }

}

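● RpcResponse is modeled on an HTTP response: a status code, a message, and the response data.
● Finally, the static factory methods success and fail make it convenient to build successful and failed responses.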

3. 序列化算法

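This article implements three high-performance serializers: Hessian, Kryo and Protostuff. Compared with Java's built-in serialization, they are generally much faster and produce smaller output.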
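For reference, the Java built-in serialization these libraries are compared against looks roughly like this when written against the same serialize/deserialize contract as the Serializer interface shown later in this section. It is a comparison baseline only, not part of the project; JdkSerializer and Greeting are hypothetical names. Note that it only accepts types implementing java.io.Serializable, which is one reason RpcRequest and RpcResponse implement that interface.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Baseline serializer using the JDK's built-in object streams.
final class JdkSerializer {
    byte[] serialize(Object obj) {
        try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
             ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(obj);
            oos.flush(); // push ObjectOutputStream's internal buffer into bos before reading it
            return bos.toByteArray();
        } catch (IOException e) {
            throw new IllegalStateException("Serialization failed", e);
        }
    }

    <T> T deserialize(byte[] bytes, Class<T> clazz) {
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return clazz.cast(ois.readObject());
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Deserialization failed", e);
        }
    }
}

// Hypothetical payload; JDK serialization rejects types that are not Serializable.
final class Greeting implements Serializable {
    private static final long serialVersionUID = 1L;
    final String text;

    Greeting(String text) {
        this.text = text;
    }
}
```

JDK streams embed full class descriptors in every payload, which is a large part of why they are bulkier than Kryo or Protostuff output.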

package github.javaguide.serialize.hessian;


import com.caucho.hessian.io.HessianInput;
import com.caucho.hessian.io.HessianOutput;
import github.javaguide.exception.SerializeException;
import github.javaguide.serialize.Serializer;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

/**
 * Hessian is a dynamically-typed, binary serialization and Web Services protocol designed for object-oriented transmission.
 *
 * @author Vinlee Xiao
 * @createTime 2022/2/23 21:11
 */
public class HessianSerializer implements Serializer {
    @Override
    public byte[] serialize(Object obj) {
        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
            HessianOutput hessianOutput = new HessianOutput(byteArrayOutputStream);
            hessianOutput.writeObject(obj);

            return byteArrayOutputStream.toByteArray();
        } catch (Exception e) {
            throw new SerializeException("Serialization failed");
        }

    }

    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {

        try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes)) {
            HessianInput hessianInput = new HessianInput(byteArrayInputStream);
            Object o = hessianInput.readObject();

            return clazz.cast(o);

        } catch (Exception e) {
            throw new SerializeException("Deserialization failed");
        }

    }
}

package github.javaguide.serialize.kyro;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import github.javaguide.exception.SerializeException;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.serialize.Serializer;
import lombok.extern.slf4j.Slf4j;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

/**
 * Kryo serialization class, Kryo serialization efficiency is very high, but only compatible with Java language
 *
 * @author shuang.kou
 * @createTime 2020-05-13 19:29:00
 */
@Slf4j
public class KryoSerializer implements Serializer {

    /**
     * Because Kryo is not thread safe, each thread gets its own Kryo instance.
     * ThreadLocal is used to store one Kryo object per thread.
     */
    private final ThreadLocal<Kryo> kryoThreadLocal = ThreadLocal.withInitial(() -> {
        Kryo kryo = new Kryo();
        kryo.register(RpcResponse.class);
        kryo.register(RpcRequest.class);
        return kryo;
    });

    @Override
    public byte[] serialize(Object obj) {
        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
             Output output = new Output(byteArrayOutputStream)) {
            Kryo kryo = kryoThreadLocal.get();
            // Object -> byte[]: serialize the object into a byte array
            kryo.writeObject(output, obj);
            output.flush();
            return byteArrayOutputStream.toByteArray();
        } catch (Exception e) {
            log.error("Serialization failed", e);
            throw new SerializeException("Serialization failed", e);
        }
    }

    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
             Input input = new Input(byteArrayInputStream)) {
            Kryo kryo = kryoThreadLocal.get();
            // byte[] -> Object: deserialize the object from the byte array
            return kryo.readObject(input, clazz);
        } catch (Exception e) {
            log.error("Deserialization failed", e);
            throw new SerializeException("Deserialization failed", e);
        }
    }

    public class SerializeException extends RuntimeException {
        public SerializeException(String message) {
            super(message);
        }

        public SerializeException(String message, Throwable cause) {
            super(message, cause);
        }
    }

}

package github.javaguide.serialize.protostuff;

import github.javaguide.serialize.Serializer;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;

/**
 * @author TangMinXuan
 * @createTime 2020-11-09 20:13
 */
public class ProtostuffSerializer implements Serializer {

    /**
     * Avoid re-allocating buffer space on every serialization.
     * LinkedBuffer is not thread-safe, so keep one buffer per thread instead of one shared static buffer.
     */
    private static final ThreadLocal<LinkedBuffer> BUFFER =
            ThreadLocal.withInitial(() -> LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE));

    @Override
    public byte[] serialize(Object obj) {
        Class<?> clazz = obj.getClass();
        Schema schema = RuntimeSchema.getSchema(clazz);
        LinkedBuffer buffer = BUFFER.get();
        byte[] bytes;
        try {
            bytes = ProtostuffIOUtil.toByteArray(obj, schema, buffer);
        } finally {
            // reset the buffer so the next call on this thread starts empty
            buffer.clear();
        }
        return bytes;
    }

    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        Schema<T> schema = RuntimeSchema.getSchema(clazz);
        T obj = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, obj, schema);
        return obj;
    }
}


package github.javaguide.serialize;

import github.javaguide.extension.SPI;

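/**
 * Serialization interface; every serializer must implement it
 *
 * @author shuang.kou
 * @createTime 2020-05-13 19:29:00
 */
@SPI
public interface Serializer {
    /**
     * Serialize
     *
     * @param obj the object to serialize
     * @return the byte array
     */
    byte[] serialize(Object obj);

    /**
     * Deserialize
     *
     * @param bytes the serialized byte array
     * @param clazz the target class
     * @param <T>   the type of the class. For example, the type of {@code String.class} is {@code Class<String>}.
     *              If the type is not known, use {@code Class<?>}
     * @return the deserialized object
     */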
    <T> T deserialize(byte[] bytes, Class<T> clazz);
}

Note that Kryo is not thread-safe, so each thread should use its own Kryo instance (here provided through a ThreadLocal).
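The pattern KryoSerializer relies on is ThreadLocal.withInitial: each thread lazily creates its own instance on first use, so a non-thread-safe object is never shared. The sketch below shows the idea with a hypothetical stand-in class (StatefulCodec) instead of Kryo, so it runs without the Kryo dependency; PerThreadCodec is likewise a made-up name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for Kryo: reuses mutable scratch state, so it is NOT thread-safe.
final class StatefulCodec {
    private final StringBuilder scratch = new StringBuilder();

    String encode(String s) {
        scratch.setLength(0); // fine on one thread, a data race if the instance were shared
        return scratch.append('[').append(s).append(']').toString();
    }
}

final class PerThreadCodec {
    // One StatefulCodec per thread, created the first time that thread calls get().
    private static final ThreadLocal<StatefulCodec> CODEC = ThreadLocal.withInitial(StatefulCodec::new);

    private PerThreadCodec() {
    }

    static StatefulCodec current() {
        return CODEC.get();
    }

    // Starts n threads and returns how many distinct codec instances they observed.
    static int distinctInstances(int n) {
        Set<StatefulCodec> seen = ConcurrentHashMap.newKeySet();
        List<Thread> threads = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            Thread t = new Thread(() -> {
                StatefulCodec codec = current();
                codec.encode("x");
                seen.add(codec);
            });
            threads.add(t);
            t.start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return seen.size();
    }
}
```

A ThreadLocal value lives as long as its thread, which suits Netty's long-lived event-loop threads: each one builds its serializer once and reuses it.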

4. 压缩算法

This article uses GZIP to compress the message body before transmission.

package github.javaguide.compress.gzip;

import github.javaguide.compress.Compress;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * @author wangtao .
 * @createTime on 2020/10/3
 */

public class GzipCompress implements Compress {


    private static final int BUFFER_SIZE = 1024 * 4;

    @Override
    public byte[] compress(byte[] bytes) {
        if (bytes == null) {
            throw new NullPointerException("bytes is null");
        }
        try (ByteArrayOutputStream out = new ByteArrayOutputStream();
             GZIPOutputStream gzip = new GZIPOutputStream(out)) {
            gzip.write(bytes);
            gzip.flush();
            gzip.finish();
            return out.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException("gzip compress error", e);
        }
    }

    @Override
    public byte[] decompress(byte[] bytes) {
        if (bytes == null) {
            throw new NullPointerException("bytes is null");
        }
        try (ByteArrayOutputStream out = new ByteArrayOutputStream();
             GZIPInputStream gunzip = new GZIPInputStream(new ByteArrayInputStream(bytes))) {
            byte[] buffer = new byte[BUFFER_SIZE];
            int n;
            while ((n = gunzip.read(buffer)) > -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException("gzip decompress error", e);
        }
    }
}

package github.javaguide.compress;

import github.javaguide.extension.SPI;

/**
 * @author wangtao .
 * @createTime on 2020/10/3
 */

@SPI
public interface Compress {

    byte[] compress(byte[] bytes);


    byte[] decompress(byte[] bytes);
}
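GzipCompress is a thin wrapper over java.util.zip. A self-contained round trip using the same JDK classes (GzipRoundTrip is a hypothetical name) shows the two properties the codec relies on: decompress(compress(x)) returns x, and every GZIP stream starts with the magic bytes 0x1f 0x8b. Because GZIP adds at least a 10-byte header and an 8-byte trailer, very small payloads can come out larger than they went in, so compressing every message is a trade-off.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

final class GzipRoundTrip {
    static byte[] compress(byte[] bytes) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (GZIPOutputStream gzip = new GZIPOutputStream(out)) {
            gzip.write(bytes);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // close() has already called finish(), so the GZIP trailer is in `out`
        return out.toByteArray();
    }

    static byte[] decompress(byte[] bytes) {
        try (GZIPInputStream gunzip = new GZIPInputStream(new ByteArrayInputStream(bytes));
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            byte[] buffer = new byte[4096];
            int n;
            while ((n = gunzip.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```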

5. Server-Side Networking

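● Communication is built on Netty.
● Client and server communicate over TCP.
● To handle TCP packet sticking and splitting, a length-field-based frame decoder (LengthFieldBasedFrameDecoder) is used.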
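TCP is a byte stream: two messages written back to back may arrive in one read (sticking), and one message may be split across reads (splitting). A length-prefixed protocol solves this by having the receiver buffer bytes until a whole frame is available. The sketch below illustrates the idea in plain Java with a hypothetical 4-byte prefix that counts only the payload; it is a simplified picture of what Netty's LengthFieldBasedFrameDecoder does, not the project's own code.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Simplified framer for [4-byte big-endian payload length][payload].
final class LengthPrefixFramer {
    private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

    static byte[] frame(byte[] payload) {
        return ByteBuffer.allocate(4 + payload.length).putInt(payload.length).put(payload).array();
    }

    // Feed whatever chunk the socket delivered; returns every complete payload found so far.
    List<byte[]> feed(byte[] chunk) {
        pending.write(chunk, 0, chunk.length);
        byte[] buf = pending.toByteArray();
        List<byte[]> frames = new ArrayList<>();
        int pos = 0;
        while (buf.length - pos >= 4) {
            int len = ByteBuffer.wrap(buf, pos, 4).getInt();
            if (len < 0) {
                throw new IllegalStateException("negative frame length: " + len);
            }
            if (buf.length - pos - 4 < len) {
                break; // only part of the frame has arrived: wait for more bytes
            }
            frames.add(Arrays.copyOfRange(buf, pos + 4, pos + 4 + len));
            pos += 4 + len;
        }
        // keep only the unconsumed tail for the next call
        pending.reset();
        pending.write(buf, pos, buf.length - pos);
        return frames;
    }
}
```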

  • Netty server code
package github.javaguide.remoting.transport.netty.server;

import github.javaguide.remoting.transport.CustomShutdownHook;
import github.javaguide.config.RpcServiceConfig;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.provider.ServiceProvider;
import github.javaguide.provider.impl.ZkServiceProviderImpl;
import github.javaguide.remoting.transport.netty.codec.RpcMessageDecoder;
import github.javaguide.remoting.transport.netty.codec.RpcMessageEncoder;
import github.javaguide.utils.RuntimeUtil;
import github.javaguide.utils.concurrent.threadpool.ThreadPoolFactoryUtil;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

/**
 * Server. Receive the client message, call the corresponding method according to the client message,
 * and then return the result to the client.
 *
 * @author shuang.kou
 * @createTime 2020-05-25 16:42:00
 */
@Slf4j
@Component
public class NettyRpcServer {

    public static final int PORT = 9998;

    private final ServiceProvider serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);

    public void registerService(RpcServiceConfig rpcServiceConfig) {
        serviceProvider.publishService(rpcServiceConfig);
    }

    @SneakyThrows
    public void start() {
        CustomShutdownHook.getCustomShutdownHook().clearAll();
        String host = InetAddress.getLocalHost().getHostAddress();
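        // 1. Three thread pools
        // 1.1 Boss group (main reactor): listens for and accepts new connections
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        // 1.2 Worker group: listens for read/write events on accepted connections
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        // 1.3 Business thread pool that runs the handler, so slow service calls do not block the I/O threads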
        DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup(
                RuntimeUtil.cpus() * 2,
                ThreadPoolFactoryUtil.createThreadFactory("service-handler-group", false)
        );
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
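                    // TCP enables Nagle's algorithm by default, which coalesces small writes into larger segments to reduce network traffic.
                    // TCP_NODELAY controls whether Nagle is used; true disables it, lowering latency for small RPC messages.
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    // Enable TCP-level keepalive probes
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    // Maximum length of the queue holding connections that have completed the three-way handshake but are not yet accepted.
                    // If connections arrive frequently and the server accepts them slowly, this value can be increased.
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    // initChannel runs once for every newly accepted client connection
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            // Fire a READER_IDLE event if nothing is read from the client for 30 seconds; the handler then closes the connection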
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
                            p.addLast(new RpcMessageEncoder());
                            p.addLast(new RpcMessageDecoder());
                            p.addLast(serviceHandlerGroup, new NettyRpcServerHandler());
                        }
                    });

            // Bind the port and wait synchronously until binding succeeds
            ChannelFuture f = b.bind(host, PORT).sync();
            // Block until the server channel is closed
            f.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            log.error("occur exception when start server:", e);
        } finally {
            log.error("shutdown bossGroup and workerGroup");
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
            serviceHandlerGroup.shutdownGracefully();
        }
    }


}

  • The corresponding handler, which processes the decoded RPC messages

package github.javaguide.remoting.transport.netty.server;

import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.RpcResponseCodeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.handler.RpcRequestHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;

/**
 * Customize the ChannelHandler of the server to process the data sent by the client.
 * <p>
 * If you extend SimpleChannelInboundHandler, you don't need to release the ByteBuf yourself: the channelRead
 * method inside {@link SimpleChannelInboundHandler} releases it for you, avoiding possible memory leaks.
 * See 《Netty进阶之路 跟着案例学 Netty》 for details.
 *
 * @author shuang.kou
 * @createTime 2020-05-25 20:44:00
 */
@Slf4j
public class NettyRpcServerHandler extends ChannelInboundHandlerAdapter {

    private final RpcRequestHandler rpcRequestHandler;

    public NettyRpcServerHandler() {
        this.rpcRequestHandler = SingletonFactory.getInstance(RpcRequestHandler.class);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        try {
            if (msg instanceof RpcMessage) {
                log.info("server receive msg: [{}] ", msg);
                byte messageType = ((RpcMessage) msg).getMessageType();
                RpcMessage rpcMessage = new RpcMessage();
                rpcMessage.setCodec(SerializationTypeEnum.HESSIAN.getCode());
                rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode());
                if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
                    // heartbeat request: reply with a pong directly
                    rpcMessage.setMessageType(RpcConstants.HEARTBEAT_RESPONSE_TYPE);
                    rpcMessage.setData(RpcConstants.PONG);
                } else {
                    // RPC call request
                    RpcRequest rpcRequest = (RpcRequest) ((RpcMessage) msg).getData();
                    // Execute the target method (the method the client needs to execute) via reflection
                    // and get the method result
                    Object result = rpcRequestHandler.handle(rpcRequest);
                    // placeholder logging: result.toString() would throw NPE for void/null results
                    log.info("server get result: [{}]", result);
                    rpcMessage.setMessageType(RpcConstants.RESPONSE_TYPE);
                    if (ctx.channel().isActive() && ctx.channel().isWritable()) {
                        RpcResponse<Object> rpcResponse = RpcResponse.success(result, rpcRequest.getRequestId());
                        rpcMessage.setData(rpcResponse);
                    } else {
                        RpcResponse<Object> rpcResponse = RpcResponse.fail(RpcResponseCodeEnum.FAIL);
                        // fail() does not carry the request id; set it so the client can match this response
                        rpcResponse.setRequestId(rpcRequest.getRequestId());
                        rpcMessage.setData(rpcResponse);
                        log.error("not writable now, message dropped");
                    }
                }
                ctx.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
            }
        } finally {
            // Ensure that ByteBuf is released, otherwise there may be memory leaks
            ReferenceCountUtil.release(msg);
        }
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleState state = ((IdleStateEvent) evt).state();
            if (state == IdleState.READER_IDLE) {
                log.info("idle check happen, so close the connection");
                ctx.close();
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.error("server catch exception", cause);
        ctx.close();
    }
}

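The core logic of the code above:

  1. When the server (the provider side) receives a heartbeat request, it replies directly with a heartbeat response.
  2. Otherwise it hands the request to RpcRequestHandler, which is covered in a later article. In short, when a service is registered, the server keeps a local map from service name to service object; the name carried by the RpcRequest is used to look up the service object, whose method then handles the request.
  3. The result is wrapped in an RpcResponse and written back to the client.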
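Step 2 can be pictured as a map lookup followed by a reflective call. The sketch below is a rough approximation, not the project's RpcRequestHandler: ServiceRegistry, HelloService and HelloServiceImpl are hypothetical names, and the real handler takes an RpcRequest rather than separate arguments.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical service interface and implementation used only in this sketch.
interface HelloService {
    String hello(String name);
}

final class HelloServiceImpl implements HelloService {
    @Override
    public String hello(String name) {
        return "Hello, " + name;
    }
}

// Minimal stand-in for the provider-side {service name -> service object} map plus reflective invocation.
final class ServiceRegistry {
    private final Map<String, Object> services = new ConcurrentHashMap<>();

    void publish(String rpcServiceName, Object service) {
        services.put(rpcServiceName, service);
    }

    Object handle(String rpcServiceName, String methodName, Class<?>[] paramTypes, Object[] parameters) {
        Object service = services.get(rpcServiceName);
        if (service == null) {
            throw new IllegalStateException("service not found: " + rpcServiceName);
        }
        try {
            // Resolve the method by name + parameter types: the same fields RpcRequest carries.
            Method method = service.getClass().getMethod(methodName, paramTypes);
            return method.invoke(service, parameters);
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new IllegalStateException("cannot invoke " + methodName, e);
        } catch (InvocationTargetException e) {
            // unwrap the exception thrown by the target method itself
            throw new IllegalStateException("target method failed", e.getCause());
        }
    }
}
```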

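In addition, to handle TCP packet sticking and splitting, the encoder writes a length field into a fixed 16-byte header and the decoder frames messages by that field (length-field-based, not fixed-length, framing):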

  • Encoder
package github.javaguide.remoting.transport.netty.codec;


import github.javaguide.compress.Compress;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.extension.ExtensionLoader;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.serialize.Serializer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.atomic.AtomicInteger;


/**
 * Custom encoder.
 * <pre>
 *   +---------------+---------+---------------+-------------+-------+----------+---------------+
 *   |  magic code   | version |  full length  | messageType | codec | compress |   requestId   |
 *   +---------------+---------+---------------+-------------+-------+----------+---------------+
 *   |                                           body                                           |
 *   +------------------------------------------------------------------------------------------+
 * </pre>
 * Byte offsets: 0-3 magic code (4B), 4 version (1B), 5-8 full length (4B, header + body),
 * 9 messageType (1B), 10 codec (1B, serialization type), 11 compress (1B, compression type),
 * 12-15 requestId (4B). The body holds the serialized and compressed data object.
 *
 * @author WangTao
 * @createTime on 2020/10/2
 * @see io.netty.handler.codec.LengthFieldBasedFrameDecoder
 */
@Slf4j
public class RpcMessageEncoder extends MessageToByteEncoder<RpcMessage> {
    private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);

    @Override
    protected void encode(ChannelHandlerContext ctx, RpcMessage rpcMessage, ByteBuf out) {
        try {
            out.writeBytes(RpcConstants.MAGIC_NUMBER);
            out.writeByte(RpcConstants.VERSION);
            // leave a place to write the value of full length
            out.writerIndex(out.writerIndex() + 4);
            byte messageType = rpcMessage.getMessageType();
            out.writeByte(messageType);
            out.writeByte(rpcMessage.getCodec());
            // write the compress type actually used below, so the decoder picks the same algorithm
            out.writeByte(rpcMessage.getCompress());
            out.writeInt(ATOMIC_INTEGER.getAndIncrement());
            // build full length
            byte[] bodyBytes = null;
            int fullLength = RpcConstants.HEAD_LENGTH;
            // if messageType is not heartbeat message, fullLength = head length + body length
            if (messageType != RpcConstants.HEARTBEAT_REQUEST_TYPE
                    && messageType != RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
                // serialize the object
                String codecName = SerializationTypeEnum.getName(rpcMessage.getCodec());
                log.info("codec name: [{}] ", codecName);
                Serializer serializer = ExtensionLoader.getExtensionLoader(Serializer.class)
                        .getExtension(codecName);
                bodyBytes = serializer.serialize(rpcMessage.getData());
                // compress the bytes
                String compressName = CompressTypeEnum.getName(rpcMessage.getCompress());
                Compress compress = ExtensionLoader.getExtensionLoader(Compress.class)
                        .getExtension(compressName);
                bodyBytes = compress.compress(bodyBytes);
                fullLength += bodyBytes.length;
            }
            if (bodyBytes != null) {
                out.writeBytes(bodyBytes);
            }
            int writeIndex = out.writerIndex();
            // jump back to offset 5 (4B magic code + 1B version) and fill in the full length
            out.writerIndex(writeIndex - fullLength + RpcConstants.MAGIC_NUMBER.length + 1);
            out.writeInt(fullLength);
            out.writerIndex(writeIndex);
        } catch (Exception e) {
            log.error("Encode request error!", e);
        }
    }
}
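The encoder's key trick is to reserve 4 bytes for full length, write the rest of the header and the body, then jump back to offset 5 and fill the length in. The same steps can be sketched with java.nio.ByteBuffer, no Netty needed. HeaderSketch is a hypothetical name, and the magic bytes and version below are placeholders, since RpcConstants is not shown in this article.

```java
import java.nio.ByteBuffer;

// Sketch of the 16-byte header written by RpcMessageEncoder, using ByteBuffer instead of ByteBuf.
final class HeaderSketch {
    static final byte[] MAGIC = {(byte) 'd', (byte) 'e', (byte) 'm', (byte) 'o'}; // placeholder, not the real magic number
    static final byte VERSION = 1;                                                 // placeholder version
    static final int HEAD_LENGTH = 16;

    static byte[] encode(byte messageType, byte codec, byte compress, int requestId, byte[] body) {
        ByteBuffer out = ByteBuffer.allocate(HEAD_LENGTH + body.length);
        out.put(MAGIC);                                 // bytes 0-3
        out.put(VERSION);                               // byte 4
        int lengthPos = out.position();                 // byte 5: where full length will go
        out.position(lengthPos + 4);                    // reserve 4 bytes for full length
        out.put(messageType).put(codec).put(compress);  // bytes 9, 10, 11
        out.putInt(requestId);                          // bytes 12-15
        out.put(body);
        int fullLength = out.position();                // header + body, since we started at 0
        out.putInt(lengthPos, fullLength);              // absolute put: back-fill without moving position
        return out.array();
    }
}
```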
  • Decoder
package github.javaguide.remoting.transport.netty.codec;

import github.javaguide.compress.Compress;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.extension.ExtensionLoader;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.serialize.Serializer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import lombok.extern.slf4j.Slf4j;

import java.util.Arrays;

/**
 * Custom decoder that frames the byte stream by the length field, to handle TCP packet sticking and splitting.
 * <pre>
 *   +---------------+---------+---------------+-------------+-------+----------+---------------+
 *   |  magic code   | version |  full length  | messageType | codec | compress |   requestId   |
 *   +---------------+---------+---------------+-------------+-------+----------+---------------+
 *   |                                           body                                           |
 *   +------------------------------------------------------------------------------------------+
 * </pre>
 * Byte offsets: 0-3 magic code (4B), 4 version (1B), 5-8 full length (4B, header + body),
 * 9 messageType (1B), 10 codec (1B, serialization type), 11 compress (1B, compression type),
 * 12-15 requestId (4B). The body holds the serialized and compressed data object.
 * <p>
 * {@link LengthFieldBasedFrameDecoder} is a length-based decoder, used to solve TCP unpacking and sticking problems.
 *
 * @author wangtao
 * @createTime on 2020/10/2
 * @see LengthFieldBasedFrameDecoder
 */
@Slf4j
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {
    public RpcMessageDecoder() {
        // lengthFieldOffset: magic code is 4B, and version is 1B, and then full length. so value is 5
        // lengthFieldLength: full length is 4B. so value is 4
        // lengthAdjustment: full length includes all data and 9 bytes precede the end of the length field,
        //                   so the remaining length is (fullLength - 9). so value is -9
        // initialBytesToStrip: we will check magic code and version manually, so do not strip any bytes. so value is 0
        this(RpcConstants.MAX_FRAME_LENGTH, 5, 4, -9, 0);
    }

    /**
     * @param maxFrameLength      Maximum frame length. It decides the maximum length of data that can be received.
     *                            If it exceeds, the data will be discarded.
     * @param lengthFieldOffset   Length field offset: the number of bytes to skip before the length field.
     * @param lengthFieldLength   The number of bytes in the length field.
     * @param lengthAdjustment    The compensation value to add to the value of the length field.
     * @param initialBytesToStrip Number of bytes skipped.
     *                            If you need to receive all of the header+body data, this value is 0;
     *                            if you only want to receive the body data, skip the number of bytes consumed by the header.
     */
    public RpcMessageDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength,
                             int lengthAdjustment, int initialBytesToStrip) {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        Object decoded = super.decode(ctx, in);
        if (decoded instanceof ByteBuf) {
            ByteBuf frame = (ByteBuf) decoded;
            if (frame.readableBytes() >= RpcConstants.TOTAL_LENGTH) {
                try {
                    return decodeFrame(frame);
                } catch (Exception e) {
                    log.error("Decode frame error!", e);
                    throw e;
                } finally {
                    frame.release();
                }
            }
        }
        return decoded;
    }

    private Object decodeFrame(ByteBuf in) {
        // note: must read ByteBuf in order
        checkMagicNumber(in);
        checkVersion(in);
        int fullLength = in.readInt();
        // build RpcMessage object
        byte messageType = in.readByte();
        byte codecType = in.readByte();
        byte compressType = in.readByte();
        int requestId = in.readInt();
        RpcMessage rpcMessage = RpcMessage.builder()
                .codec(codecType)
                .compress(compressType)
                .requestId(requestId)
                .messageType(messageType).build();
        if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
            rpcMessage.setData(RpcConstants.PING);
            return rpcMessage;
        }
        if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
            rpcMessage.setData(RpcConstants.PONG);
            return rpcMessage;
        }
        int bodyLength = fullLength - RpcConstants.HEAD_LENGTH;
        if (bodyLength > 0) {
            byte[] bs = new byte[bodyLength];
            in.readBytes(bs);
            // decompress the bytes
            String compressName = CompressTypeEnum.getName(compressType);
            Compress compress = ExtensionLoader.getExtensionLoader(Compress.class)
                    .getExtension(compressName);
            bs = compress.decompress(bs);
            // deserialize the object
            String codecName = SerializationTypeEnum.getName(rpcMessage.getCodec());
            log.info("codec name: [{}] ", codecName);
            Serializer serializer = ExtensionLoader.getExtensionLoader(Serializer.class)
                    .getExtension(codecName);
            if (messageType == RpcConstants.REQUEST_TYPE) {
                RpcRequest tmpValue = serializer.deserialize(bs, RpcRequest.class);
                rpcMessage.setData(tmpValue);
            } else {
                RpcResponse tmpValue = serializer.deserialize(bs, RpcResponse.class);
                rpcMessage.setData(tmpValue);
            }
        }
        return rpcMessage;
    }

    private void checkVersion(ByteBuf in) {
        // read the version and compare
        byte version = in.readByte();
        if (version != RpcConstants.VERSION) {
            throw new RuntimeException("version isn't compatible: " + version);
        }
    }

    private void checkMagicNumber(ByteBuf in) {
        // read the first 4 bytes, which are the magic number, and compare
        int len = RpcConstants.MAGIC_NUMBER.length;
        byte[] tmp = new byte[len];
        in.readBytes(tmp);
        for (int i = 0; i < len; i++) {
            if (tmp[i] != RpcConstants.MAGIC_NUMBER[i]) {
                throw new IllegalArgumentException("Unknown magic code: " + Arrays.toString(tmp));
            }
        }
    }
}
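The four constructor arguments can be checked against the rule LengthFieldBasedFrameDecoder applies: frame length = length-field value + lengthAdjustment + (lengthFieldOffset + lengthFieldLength). A small arithmetic sketch (FrameLengthMath is a hypothetical name) confirms that (5, 4, -9) yields exactly fullLength, and shows what lengthAdjustment would have to be in a hypothetical variant where the length field counted only the body.

```java
// Arithmetic behind RpcMessageDecoder's constructor arguments (no Netty needed).
final class FrameLengthMath {
    static final int HEAD_LENGTH = 16;

    // Total number of bytes LengthFieldBasedFrameDecoder cuts out for one frame.
    static long frameLength(int lengthFieldOffset, int lengthFieldLength, int lengthAdjustment, long lengthFieldValue) {
        int lengthFieldEndOffset = lengthFieldOffset + lengthFieldLength;
        return lengthFieldValue + lengthAdjustment + lengthFieldEndOffset;
    }

    public static void main(String[] args) {
        int body = 100;
        int fullLength = HEAD_LENGTH + body;
        // This protocol: the length field already counts header + body, hence lengthAdjustment = -9.
        System.out.println(frameLength(5, 4, -9, fullLength)); // 116
        // Hypothetical variant: the field counts only the body, so the 7 header bytes after it
        // (messageType, codec, compress, requestId) must be added back: lengthAdjustment = 7.
        System.out.println(frameLength(5, 4, 7, body)); // 116
    }
}
```

With initialBytesToStrip = 0 the whole 116-byte frame, header included, reaches decodeFrame, which is why it can check the magic code and version itself.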

6. Client-Side Communication
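The client below is asynchronous: a CompletableFuture is parked per request (the unprocessedRequests field), and it is completed when the response with the matching request id arrives. UnprocessedRequests itself is not shown here; the following is a minimal sketch of that bookkeeping, assuming it is a map from requestId to future. PendingRequests and its methods are hypothetical names, and String stands in for RpcResponse to keep the sketch self-contained.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch of requestId -> future bookkeeping for in-flight RPC calls.
final class PendingRequests {
    private final Map<String, CompletableFuture<String>> futures = new ConcurrentHashMap<>();

    // Called by the sending thread before the request is written to the channel.
    CompletableFuture<String> register(String requestId) {
        CompletableFuture<String> future = new CompletableFuture<>();
        futures.put(requestId, future);
        return future;
    }

    // Called from the I/O thread when a response carrying this requestId arrives.
    boolean complete(String requestId, String responseBody) {
        CompletableFuture<String> future = futures.remove(requestId);
        if (future == null) {
            return false; // unknown or already completed id, e.g. a late duplicate
        }
        future.complete(responseBody);
        return true;
    }

    int pendingCount() {
        return futures.size();
    }
}
```

Because matching is done purely by request id, a response must echo the id of its request; RpcResponse.fail(...) above does not set it by itself.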

package github.javaguide.remoting.transport.netty.client;


import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.enums.ServiceDiscoveryEnum;
import github.javaguide.extension.ExtensionLoader;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.RpcRequestTransport;
import github.javaguide.remoting.transport.netty.codec.RpcMessageDecoder;
import github.javaguide.remoting.transport.netty.codec.RpcMessageEncoder;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.IdleStateHandler;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

/**
 * initialize and close Bootstrap object
 *
 * @author shuang.kou
 * @createTime 2020-05-29 17:51:00
 */
@Slf4j
public final class NettyRpcClient implements RpcRequestTransport {
    private final ServiceDiscovery serviceDiscovery;
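    /**
     * Pending requests: one CompletableFuture per request, completed when its response arrives
     * */
    private final UnprocessedRequests unprocessedRequests;
    /**
     * Caches channels, so a new TCP connection is not needed for every request
     * */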
    private final ChannelProvider channelProvider;
    private final Bootstrap bootstrap;
    private final EventLoopGroup eventLoopGroup;

    public NettyRpcClient() {
        // initialize resources such as EventLoopGroup, Bootstrap
        eventLoopGroup = new NioEventLoopGroup();
        bootstrap = new Bootstrap();
        bootstrap.group(eventLoopGroup)
                .channel(NioSocketChannel.class)
                //  The timeout period of the connection.
                //  If this time is exceeded or the connection cannot be established, the connection fails.
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
                // Bootstrap holds a single handler, so a second .handler(...) call would silently replace
                // an earlier LoggingHandler; add the LoggingHandler to the pipeline instead.
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline p = ch.pipeline();
                        p.addLast(new LoggingHandler(LogLevel.INFO));
                        // If no data is written to the server within 5 seconds, a heartbeat request is sent
                        p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
                        p.addLast(new RpcMessageEncoder());
                        p.addLast(new RpcMessageDecoder());
                        p.addLast(new NettyRpcClientHandler());
                    }
                });
        this.serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension(ServiceDiscoveryEnum.ZK.getName());
        this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
        this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
    }

    /**
     * Connects to the server and returns the channel used to send RPC messages to it
     *
     * @param inetSocketAddress server address
     * @return the channel
     */
    @SneakyThrows
    public Channel doConnect(InetSocketAddress inetSocketAddress) {
        CompletableFuture<Channel> completableFuture = new CompletableFuture<>();
        bootstrap.connect(inetSocketAddress).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.info("The client has connected [{}] successful!", inetSocketAddress.toString());
                completableFuture.complete(future.channel());
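            } else {
                // Complete exceptionally instead of throwing: an exception thrown inside the listener
                // never reaches the caller, so completableFuture.get() below would block forever
                completableFuture.completeExceptionally(future.cause());
            }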
        });
        return completableFuture.get();
    }

    @Override
    public Object sendRpcRequest(RpcRequest rpcRequest) {
        CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
        // 1. Look up the service address
        InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcRequest);
        // 2. Get a channel
        Channel channel = getChannel(inetSocketAddress);
        if (channel.isActive()) {
            // 3. Register the pending request, then send it
            unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
            RpcMessage rpcMessage = RpcMessage.builder().data(rpcRequest)
                    .codec(SerializationTypeEnum.HESSIAN.getCode())
                    .compress(CompressTypeEnum.GZIP.getCode())
                    .messageType(RpcConstants.REQUEST_TYPE).build();
            channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> {
                if (future.isSuccess()) {
                    log.info("client send message: [{}]", rpcMessage);
                } else {
                    future.channel().close();
                    resultFuture.completeExceptionally(future.cause());
                    log.error("Send failed:", future.cause());
                }
            });
        } else {
            throw new IllegalStateException();
        }
        // 4. Wait for the response result
        try {
            return resultFuture.get();
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException("RPC request failed: " + e.getMessage(), e);
        }
    }

    public Channel getChannel(InetSocketAddress inetSocketAddress) {
        Channel channel = channelProvider.get(inetSocketAddress);
        if (channel == null) {
            channel = doConnect(inetSocketAddress);
            channelProvider.set(inetSocketAddress, channel);
        }
        return channel;
    }

    public void close() {
        eventLoopGroup.shutdownGracefully();
    }
}


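The core logic is sendRpcRequest, which implements the RpcRequestTransport interface:

  1. First, query the registry and use a load-balancing algorithm to obtain the server's ip:port.
  2. Next, get a channel from the cache (ChannelProvider), connecting only when no active channel exists.
  3. Then send the RPC request and wait for the Netty handler to receive the RpcResponse (NettyRpcClientHandler).
  4. Finally, obtain the response result (UnprocessedRequests, which matches each response to its pending request by requestId).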
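The four steps can be modeled in plain Java to show why the future is registered before the request is written. In this minimal sketch, `lookupService` and `sendAsync` are hypothetical stand-ins for the ZooKeeper lookup and the Netty channel write, port 9998 and `HelloService` are arbitrary, and a daemon thread plays the role of the event loop that runs NettyRpcClientHandler. If the response could arrive before `PENDING.put`, the handler would find no future to complete.

```java
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class RpcFlowSketch {
    // Pending requests keyed by requestId (the role UnprocessedRequests plays)
    static final Map<String, CompletableFuture<String>> PENDING = new ConcurrentHashMap<>();

    // Daemon thread standing in for the Netty event loop that runs the client handler
    static final ExecutorService IO = Executors.newSingleThreadExecutor(r -> {
        Thread t = new Thread(r, "fake-event-loop");
        t.setDaemon(true);
        return t;
    });

    // Hypothetical step 1: a real client asks the registry and applies load balancing
    static InetSocketAddress lookupService(String serviceName) {
        return InetSocketAddress.createUnresolved("127.0.0.1", 9998);
    }

    // Hypothetical step 3: "write" the request; the fake server answers on the event-loop thread
    static void sendAsync(InetSocketAddress address, String requestId, String payload) {
        IO.execute(() -> onResponse(requestId, "echo:" + payload));
    }

    // Mirrors NettyRpcClientHandler + UnprocessedRequests.complete: find and complete the future
    static void onResponse(String requestId, String result) {
        CompletableFuture<String> future = PENDING.remove(requestId);
        if (future != null) {
            future.complete(result);
        }
    }

    static String call(String serviceName, String payload) throws Exception {
        String requestId = UUID.randomUUID().toString();
        CompletableFuture<String> resultFuture = new CompletableFuture<>();
        InetSocketAddress address = lookupService(serviceName); // 1. locate the server
        // 2. channel lookup is omitted in this sketch
        PENDING.put(requestId, resultFuture);                     // register BEFORE sending
        sendAsync(address, requestId, payload);                   // 3. send the request
        return resultFuture.get(3, TimeUnit.SECONDS);             // 4. wait for the response
    }

    public static void main(String[] args) throws Exception {
        System.out.println(call("HelloService", "hi")); // prints echo:hi
    }
}
```

The 3-second timeout on `get` is an addition of this sketch; the client above calls `get()` with no timeout.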

The three supporting classes follow, along with the RpcRequestTransport interface.

  • ChannelProvider: caches channels
package github.javaguide.remoting.transport.netty.client;

import io.netty.channel.Channel;
import lombok.extern.slf4j.Slf4j;

import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * store and get Channel object
 *
 * @author shuang.kou
 * @createTime 2020年05月29日 16:36:00
 */
@Slf4j
public class ChannelProvider {

    private final Map<String, Channel> channelMap;

    public ChannelProvider() {
        channelMap = new ConcurrentHashMap<>();
    }

    public Channel get(InetSocketAddress inetSocketAddress) {
        String key = inetSocketAddress.toString();
        // determine if there is a connection for the corresponding address
        if (channelMap.containsKey(key)) {
            Channel channel = channelMap.get(key);
            // if so, determine if the connection is available, and if so, get it directly
            if (channel != null && channel.isActive()) {
                return channel;
            } else {
                channelMap.remove(key);
            }
        }
        return null;
    }

    public void set(InetSocketAddress inetSocketAddress, Channel channel) {
        String key = inetSocketAddress.toString();
        channelMap.put(key, channel);
    }

    public void remove(InetSocketAddress inetSocketAddress) {
        String key = inetSocketAddress.toString();
        channelMap.remove(key);
        log.info("Channel map size :[{}]", channelMap.size());
    }
}
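ChannelProvider.get followed by NettyRpcClient.getChannel is a check-then-act sequence: two threads that miss the cache at the same moment can both connect, and the second `set` overwrites the first channel without closing it. One way to avoid this, sketched below under stated assumptions, is to cache the connect *future* per address with `ConcurrentHashMap.computeIfAbsent`. `Conn`, `FakeConn` and `ConnectionCache` are hypothetical names; `Conn` stands in for Netty's `Channel`, and the connector stands in for `bootstrap.connect`, which only starts the connection.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class ConnectionCacheSketch {
    // Minimal stand-in for io.netty.channel.Channel
    interface Conn {
        boolean isActive();
    }

    static final class FakeConn implements Conn {
        volatile boolean active = true;

        @Override
        public boolean isActive() {
            return active;
        }
    }

    // Caches one connect attempt per address (the job ChannelProvider + getChannel do together)
    static final class ConnectionCache<K> {
        private final Map<K, CompletableFuture<Conn>> cache = new ConcurrentHashMap<>();
        // Must only START the connection and return quickly, like Bootstrap.connect
        private final Function<K, CompletableFuture<Conn>> connector;

        ConnectionCache(Function<K, CompletableFuture<Conn>> connector) {
            this.connector = connector;
        }

        CompletableFuture<Conn> get(K key) {
            // computeIfAbsent applies the connector at most once per key, even under contention
            CompletableFuture<Conn> f = cache.computeIfAbsent(key, connector);
            boolean dead = f.isCompletedExceptionally() || (f.isDone() && !f.join().isActive());
            if (dead) {
                // remove(key, f) evicts only this exact future, so a newer entry from another thread survives
                cache.remove(key, f);
                f = cache.computeIfAbsent(key, connector);
            }
            return f;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        AtomicInteger connects = new AtomicInteger();
        ConnectionCache<String> cache = new ConnectionCache<>(address -> {
            connects.incrementAndGet();
            return CompletableFuture.<Conn>completedFuture(new FakeConn());
        });

        Thread[] threads = new Thread[8];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(() -> cache.get("127.0.0.1:9998").join());
            threads[i].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        System.out.println("connects after 8 concurrent gets: " + connects.get()); // 1

        ((FakeConn) cache.get("127.0.0.1:9998").join()).active = false; // the channel dies
        cache.get("127.0.0.1:9998").join();
        System.out.println("connects after reconnect: " + connects.get()); // 2
    }
}
```

Callers that arrive while a connect is still in flight share the same future instead of opening a second connection.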

  • NettyRpcClientHandler: the client-side handler that reads messages from the server and obtains the corresponding RpcResponse
package github.javaguide.remoting.transport.netty.client;

import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcResponse;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;

import java.net.InetSocketAddress;

/**
 * Customize the client ChannelHandler to process the data sent by the server
 *
 * <p>
 * If the handler extends {@link SimpleChannelInboundHandler}, you do not need to release the ByteBuf yourself:
 * its channelRead method releases the ByteBuf for you, avoiding a possible memory leak.
 * See 《Netty进阶之路 跟着案例学 Netty》 for details.
 *
 * @author shuang.kou
 * @createTime 2020年05月25日 20:50:00
 */
@Slf4j
public class NettyRpcClientHandler extends ChannelInboundHandlerAdapter {
    private final UnprocessedRequests unprocessedRequests;
    private final NettyRpcClient nettyRpcClient;

    public NettyRpcClientHandler() {
        this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
        this.nettyRpcClient = SingletonFactory.getInstance(NettyRpcClient.class);
    }

    /**
     * Read the message transmitted by the server
     */
    @Override
    @SuppressWarnings("unchecked")
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        try {
            log.info("client receive msg: [{}]", msg);
            if (msg instanceof RpcMessage) {
                RpcMessage tmp = (RpcMessage) msg;
                byte messageType = tmp.getMessageType();
                if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
                    log.info("heart [{}]", tmp.getData());
                } else if (messageType == RpcConstants.RESPONSE_TYPE) {
                    RpcResponse<Object> rpcResponse = (RpcResponse<Object>) tmp.getData();
                    unprocessedRequests.complete(rpcResponse);
                }
            }
        } finally {
            ReferenceCountUtil.release(msg);
        }
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleState state = ((IdleStateEvent) evt).state();
            if (state == IdleState.WRITER_IDLE) {
                log.info("write idle happen [{}]", ctx.channel().remoteAddress());
                Channel channel = nettyRpcClient.getChannel((InetSocketAddress) ctx.channel().remoteAddress());
                RpcMessage rpcMessage = new RpcMessage();
                rpcMessage.setCodec(SerializationTypeEnum.PROTOSTUFF.getCode());
                rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode());
                rpcMessage.setMessageType(RpcConstants.HEARTBEAT_REQUEST_TYPE);
                rpcMessage.setData(RpcConstants.PING);
                channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Called when an exception occurs in processing a client message
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.error("client catch exception:", cause);
        cause.printStackTrace();
        ctx.close();
    }
}
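The heartbeat relies on `IdleStateHandler(0, 5, 0, TimeUnit.SECONDS)`: when nothing has been written for 5 seconds, a WRITER_IDLE event reaches `userEventTriggered`, which sends a HEARTBEAT_REQUEST_TYPE message. The sketch below models only that timing rule with an explicit clock so it can be checked deterministically. It is a simplified model, not Netty's IdleStateHandler implementation; `WriterIdleSketch`, `write` and `tick` are hypothetical names, and the type codes 1 and 3 come from the message-type list at the top of this article.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

public class WriterIdleSketch {
    // Message-type codes from the list at the top of this article
    static final byte REQUEST_TYPE = 1;
    static final byte HEARTBEAT_REQUEST_TYPE = 3;

    private final long writerIdleNanos;
    private long lastWriteNanos;
    final List<Byte> sent = new ArrayList<>(); // message types "written" so far

    WriterIdleSketch(long writerIdle, TimeUnit unit, long startNanos) {
        this.writerIdleNanos = unit.toNanos(writerIdle);
        this.lastWriteNanos = startNanos;
    }

    // Any outbound message resets the writer-idle timer
    void write(byte messageType, long nowNanos) {
        sent.add(messageType);
        lastWriteNanos = nowNanos;
    }

    // Called periodically; models WRITER_IDLE -> userEventTriggered -> send PING
    void tick(long nowNanos) {
        if (nowNanos - lastWriteNanos >= writerIdleNanos) {
            write(HEARTBEAT_REQUEST_TYPE, nowNanos);
        }
    }

    public static void main(String[] args) {
        long s = TimeUnit.SECONDS.toNanos(1);
        WriterIdleSketch idle = new WriterIdleSketch(5, TimeUnit.SECONDS, 0);
        idle.tick(3 * s);                // 3 s idle: nothing sent
        idle.write(REQUEST_TYPE, 4 * s); // a real request resets the timer
        idle.tick(8 * s);                // only 4 s since the request: nothing sent
        idle.tick(9 * s);                // 5 s idle: heartbeat
        idle.tick(12 * s);               // 3 s since the heartbeat: nothing sent
        System.out.println(idle.sent);   // [1, 3]
    }
}
```

Ordinary requests also count as writes, so a busy client sends no heartbeats at all.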

  • UnprocessedRequests: uses CompletableFuture to wait for results, which keeps the calling code readable

package github.javaguide.remoting.transport.netty.client;

import github.javaguide.remoting.dto.RpcResponse;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

/**
 * unprocessed requests by the server.
 *
 * @author shuang.kou
 * @createTime 2020年06月04日 17:30:00
 */
public class UnprocessedRequests {
    private static final Map<String, CompletableFuture<RpcResponse<Object>>> UNPROCESSED_RESPONSE_FUTURES = new ConcurrentHashMap<>();

    public void put(String requestId, CompletableFuture<RpcResponse<Object>> future) {
        UNPROCESSED_RESPONSE_FUTURES.put(requestId, future);
    }

    public void complete(RpcResponse<Object> rpcResponse) {
        CompletableFuture<RpcResponse<Object>> future = UNPROCESSED_RESPONSE_FUTURES.remove(rpcResponse.getRequestId());
        if (null != future) {
            future.complete(rpcResponse);
        } else {
            throw new IllegalStateException();
        }
    }
}
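As written, an entry in UNPROCESSED_RESPONSE_FUTURES is removed only when a response arrives, and sendRpcRequest calls `get()` without a timeout, so a server that never answers blocks the caller forever and leaks the entry. The sketch below shows one possible extension, assuming Java 9+ for `CompletableFuture.orTimeout`. `TimedUnprocessedRequests`, `register` and `pendingCount` are hypothetical names, and returning `false` for an unknown response (instead of throwing IllegalStateException) is a design choice of this sketch.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

public class TimedUnprocessedRequests<R> {
    private final Map<String, CompletableFuture<R>> pending = new ConcurrentHashMap<>();

    // Registers a request and returns the future the caller should wait on
    public CompletableFuture<R> register(String requestId, long timeout, TimeUnit unit) {
        CompletableFuture<R> future = new CompletableFuture<>();
        pending.put(requestId, future);
        // orTimeout (Java 9+) fails the future with TimeoutException if nothing completes it in time.
        // The returned stage completes only after the cleanup below has run.
        return future.orTimeout(timeout, unit)
                .whenComplete((result, error) -> pending.remove(requestId, future));
    }

    // Called by the client handler; a late or unknown response returns false instead of throwing
    public boolean complete(String requestId, R response) {
        CompletableFuture<R> future = pending.remove(requestId);
        return future != null && future.complete(response);
    }

    public int pendingCount() {
        return pending.size();
    }

    public static void main(String[] args) throws InterruptedException {
        TimedUnprocessedRequests<String> requests = new TimedUnprocessedRequests<>();

        CompletableFuture<String> answered = requests.register("req-1", 1, TimeUnit.SECONDS);
        requests.complete("req-1", "pong");
        System.out.println(answered.join()); // pong

        CompletableFuture<String> lost = requests.register("req-2", 100, TimeUnit.MILLISECONDS);
        try {
            lost.get();
        } catch (ExecutionException e) {
            System.out.println(e.getCause().getClass().getSimpleName()); // TimeoutException
        }
        System.out.println(requests.pendingCount());                // 0
        System.out.println(requests.complete("req-2", "too late")); // false
    }
}
```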

  • RpcRequestTransport: the interface for sending remote RPC requests

The RpcRequestTransport interface can be implemented for different client and server types, for example sending messages over a plain Socket instead of Netty.
@SPI
public interface RpcRequestTransport {
    /**
     * send rpc request to server and get result
     *
     * @param rpcRequest message body
     * @return data from server
     */
    Object sendRpcRequest(RpcRequest rpcRequest) throws Exception;
}
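To make the Socket alternative concrete, here is a minimal blocking implementation of the same interface shape. It is a sketch, not part of the project: `Request` is a simplified stand-in for RpcRequest, `SocketRpcClient` and `startServer` are hypothetical, each call opens a new TCP connection, and plain Java serialization replaces the Hessian/Protostuff codecs and the RpcMessage framing used by the Netty client.

```java
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Arrays;

public class SocketTransportSketch {
    // Simplified stand-in for the article's RpcRequest
    static final class Request implements Serializable {
        private static final long serialVersionUID = 1L;
        final String methodName;
        final Object[] parameters;

        Request(String methodName, Object... parameters) {
            this.methodName = methodName;
            this.parameters = parameters;
        }
    }

    // Same shape as the RpcRequestTransport interface above
    interface RpcRequestTransport {
        Object sendRpcRequest(Request request) throws Exception;
    }

    // Hypothetical blocking implementation: one TCP connection per call, Java serialization on the wire
    static final class SocketRpcClient implements RpcRequestTransport {
        private final String host;
        private final int port;

        SocketRpcClient(String host, int port) {
            this.host = host;
            this.port = port;
        }

        @Override
        public Object sendRpcRequest(Request request) throws Exception {
            try (Socket socket = new Socket(host, port)) {
                // Create the output stream first: ObjectInputStream's constructor blocks on the peer's header
                ObjectOutputStream out = new ObjectOutputStream(socket.getOutputStream());
                out.writeObject(request);
                out.flush();
                ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
                return in.readObject();
            }
        }
    }

    // Toy server on a daemon thread: replies "methodName:[params]" to every request
    static void startServer(ServerSocket server) {
        Thread t = new Thread(() -> {
            while (!server.isClosed()) {
                try (Socket s = server.accept()) {
                    ObjectInputStream in = new ObjectInputStream(s.getInputStream());
                    Request req = (Request) in.readObject();
                    ObjectOutputStream out = new ObjectOutputStream(s.getOutputStream());
                    out.writeObject(req.methodName + ":" + Arrays.toString(req.parameters));
                    out.flush();
                } catch (Exception e) {
                    // socket closed or malformed request: the loop condition decides whether to stop
                }
            }
        });
        t.setDaemon(true);
        t.start();
    }

    public static void main(String[] args) throws Exception {
        try (ServerSocket server = new ServerSocket(0)) { // port 0 picks any free port
            startServer(server);
            RpcRequestTransport transport = new SocketRpcClient("127.0.0.1", server.getLocalPort());
            System.out.println(transport.sendRpcRequest(new Request("hello", "a", "b"))); // hello:[a, b]
        }
    }
}
```

Because callers depend only on the interface, swapping this in for NettyRpcClient changes the transport without touching the proxy code; the trade-off is one blocking connection per request instead of a cached, multiplexed channel.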
