本项目涵盖了 Dubbo 中的 Protocol、Transport、Exchange、Serialize 这几层，底层都基于 Netty 实现网络通信。
代码github地址
本文封装了三个最简单的实体类:
● RpcMessage：传输消息，是请求、响应以及心跳在网络上传输时的统一封装
● RpcRequest:RPC请求
● RpcResponse:RPC响应
package github.javaguide.remoting.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.ToString;
/**
* Transport-level envelope exchanged between client and server. The codec
* (RpcMessageEncoder / RpcMessageDecoder) maps these fields to and from the
* binary message header, and {@link #data} carries the body.
*
* @author wangtao
* @createTime 2020-10-02 12:33
*/
@AllArgsConstructor
@NoArgsConstructor
@Getter
@Setter
@Builder
@ToString
public class RpcMessage {
/**
* rpc message type: compared against RpcConstants.REQUEST_TYPE, RESPONSE_TYPE,
* HEARTBEAT_REQUEST_TYPE and HEARTBEAT_RESPONSE_TYPE by the codec and handlers
*/
private byte messageType;
/**
* serialization type code (e.g. SerializationTypeEnum.HESSIAN.getCode()); used to
* pick the Serializer extension by name
*/
private byte codec;
/**
* compress type code (e.g. CompressTypeEnum.GZIP.getCode()); used to pick the
* Compress extension by name
*/
private byte compress;
/**
* request id carried in the message header.
* NOTE(review): RpcMessageEncoder currently writes its own incrementing counter into
* the header instead of this value; request/response correlation uses RpcRequest's
* String requestId instead.
*/
private int requestId;
/**
* message payload: an RpcRequest or RpcResponse for normal messages,
* RpcConstants.PING / PONG for heartbeats
*/
private Object data;
}
● 上面 RpcMessage 的消息类型（messageType）有四种，分别是：
○ REQUEST_TYPE:请求类型,编码为1
○ RESPONSE_TYPE:响应类型,编码为2
○ HEARTBEAT_REQUEST_TYPE:心跳请求,编码为3
○ HEARTBEAT_RESPONSE_TYPE:心跳响应,编码为4
package github.javaguide.remoting.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.io.Serializable;
/**
 * An RPC invocation request: identifies the target service method and carries its arguments.
 *
 * @author shuang.kou
 * @createTime 2020-05-10 08:24:00
 */
@ToString
@Builder
@Getter
@NoArgsConstructor
@AllArgsConstructor
public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1905122041950251207L;
    /** Identifier echoed back in the matching RpcResponse. */
    private String requestId;
    /** Fully qualified name of the target service interface. */
    private String interfaceName;
    /** Name of the method to invoke. */
    private String methodName;
    /** Actual invocation arguments. */
    private Object[] parameters;
    /** Declared parameter types, used to resolve overloaded methods. */
    private Class<?>[] paramTypes;
    /** Service version, used together with group to select an implementation. */
    private String version;
    /** Service group, used together with version to select an implementation. */
    private String group;

    /**
     * Builds the service key used to look up the target service.
     *
     * @return interface name, group and version concatenated with no separator;
     *         a {@code null} component contributes the literal text {@code "null"}
     */
    public String getRpcServiceName() {
        StringBuilder serviceKey = new StringBuilder();
        serviceKey.append(this.getInterfaceName());
        serviceKey.append(this.getGroup());
        serviceKey.append(this.getVersion());
        return serviceKey.toString();
    }
}
● RPC请求加上了version和group,进行分组以及版本的匹配过滤
● RPC 请求当然要包括：请求的 id（唯一标识），以及要调用的方法信息（接口名称、方法名称、方法参数、方法参数类型）
package github.javaguide.remoting.dto;
import github.javaguide.enums.RpcResponseCodeEnum;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import java.io.Serializable;
/**
 * An RPC response: a status code, a status message and the returned payload.
 *
 * @param <T> type of the response payload
 * @author shuang.kou
 * @createTime 2020-05-12 16:15:00
 */
@ToString
@Builder
@Setter
@Getter
@NoArgsConstructor
@AllArgsConstructor
public class RpcResponse<T> implements Serializable {
    private static final long serialVersionUID = 715745410605631233L;
    /** Id of the request this response answers. */
    private String requestId;
    /** Response status code. */
    private Integer code;
    /** Response status message. */
    private String message;
    /** Response body. */
    private T data;

    /**
     * Creates a successful response.
     *
     * @param data      the payload; may be {@code null}
     * @param requestId id of the request being answered
     * @param <T>       payload type
     * @return a response carrying {@link RpcResponseCodeEnum#SUCCESS}'s code and message
     */
    public static <T> RpcResponse<T> success(T data, String requestId) {
        RpcResponseCodeEnum ok = RpcResponseCodeEnum.SUCCESS;
        return RpcResponse.<T>builder()
                .requestId(requestId)
                .code(ok.getCode())
                .message(ok.getMessage())
                .data(data)
                .build();
    }

    /**
     * Creates a failed response with no payload and no request id.
     *
     * @param rpcResponseCodeEnum the failure code/message to report
     * @param <T>                 payload type
     * @return a response carrying the given code and message
     */
    public static <T> RpcResponse<T> fail(RpcResponseCodeEnum rpcResponseCodeEnum) {
        return RpcResponse.<T>builder()
                .code(rpcResponseCodeEnum.getCode())
                .message(rpcResponseCodeEnum.getMessage())
                .build();
    }
}
● RPC 响应可以参考 HTTP 响应进行设计，包括状态码、返回信息以及响应数据。
● 最后需要加上构建成功和失败的方法,方便使用
本文实现了三种高性能的序列化方式：Hessian、Kryo 和 Protostuff，相比于 Java 自带的序列化机制，可以大大提升效率。
package github.javaguide.serialize.hessian;
import com.caucho.hessian.io.HessianInput;
import com.caucho.hessian.io.HessianOutput;
import github.javaguide.exception.SerializeException;
import github.javaguide.serialize.Serializer;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
/**
 * {@link Serializer} backed by Hessian, a dynamically-typed binary serialization protocol
 * designed for object-oriented transmission.
 *
 * @author Vinlee Xiao
 * @createTime 2022/2/23 21:11
 */
public class HessianSerializer implements Serializer {

    /**
     * Serializes {@code obj} with {@link HessianOutput}.
     *
     * @param obj the object to serialize
     * @return the Hessian-encoded bytes
     * @throws SerializeException if Hessian fails; the original exception is attached as the cause
     */
    @Override
    public byte[] serialize(Object obj) {
        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
            HessianOutput hessianOutput = new HessianOutput(byteArrayOutputStream);
            hessianOutput.writeObject(obj);
            return byteArrayOutputStream.toByteArray();
        } catch (Exception e) {
            throw withCause("Serialization failed", e);
        }
    }

    /**
     * Deserializes {@code bytes} with {@link HessianInput} and casts the result to {@code clazz}.
     *
     * @param bytes Hessian-encoded bytes
     * @param clazz expected result type
     * @return the decoded object
     * @throws SerializeException if decoding or the cast fails; the original exception is attached as the cause
     */
    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes)) {
            HessianInput hessianInput = new HessianInput(byteArrayInputStream);
            Object o = hessianInput.readObject();
            return clazz.cast(o);
        } catch (Exception e) {
            throw withCause("Deserialization failed", e);
        }
    }

    /**
     * Builds a SerializeException that keeps the underlying failure as its cause, so the root
     * problem (e.g. a non-serializable field or a ClassCastException) is not lost.
     * Uses initCause so only the single-argument constructor is required.
     */
    private static SerializeException withCause(String message, Exception cause) {
        SerializeException exception = new SerializeException(message);
        exception.initCause(cause);
        return exception;
    }
}
package github.javaguide.serialize.kyro;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import github.javaguide.exception.SerializeException;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.serialize.Serializer;
import lombok.extern.slf4j.Slf4j;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
/**
 * {@link Serializer} backed by Kryo. Kryo serialization is very efficient, but its format
 * is only readable from Java.
 *
 * @author shuang.kou
 * @createTime 2020-05-13 19:29:00
 */
@Slf4j
public class KryoSerializer implements Serializer {
    /**
     * Kryo instances are not thread safe, so each thread gets its own instance via ThreadLocal.
     * RpcResponse and RpcRequest are registered up front.
     */
    private final ThreadLocal<Kryo> kryoThreadLocal = ThreadLocal.withInitial(() -> {
        Kryo kryo = new Kryo();
        kryo.register(RpcResponse.class);
        kryo.register(RpcRequest.class);
        return kryo;
    });

    /**
     * Serializes {@code obj} with this thread's Kryo instance.
     *
     * @param obj the object to serialize
     * @return the Kryo-encoded bytes
     * @throws SerializeException if Kryo fails (the failure is logged and attached as the cause)
     */
    @Override
    public byte[] serialize(Object obj) {
        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
             Output output = new Output(byteArrayOutputStream)) {
            Kryo kryo = kryoThreadLocal.get();
            // Object -> byte[]
            kryo.writeObject(output, obj);
            // Output buffers internally; flush before reading the underlying stream.
            output.flush();
            return byteArrayOutputStream.toByteArray();
        } catch (Exception e) {
            log.error("Serialization failed", e);
            throw new SerializeException("Serialization failed", e);
        }
    }

    /**
     * Deserializes {@code bytes} into an instance of {@code clazz} with this thread's Kryo instance.
     *
     * @param bytes Kryo-encoded bytes
     * @param clazz expected result type
     * @return the decoded object
     * @throws SerializeException if Kryo fails (the failure is logged and attached as the cause)
     */
    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
             Input input = new Input(byteArrayInputStream)) {
            Kryo kryo = kryoThreadLocal.get();
            // byte[] -> Object
            return kryo.readObject(input, clazz);
        } catch (Exception e) {
            log.error("Deserialization failed", e);
            throw new SerializeException("Deserialization failed", e);
        }
    }

    /**
     * Thrown when Kryo serialization or deserialization fails.
     *
     * <p>Declared {@code static}: as a non-static inner class, every instance silently captured a
     * reference to the enclosing serializer and could not be created without one.
     *
     * <p>NOTE(review): this nested type shadows the imported
     * {@code github.javaguide.exception.SerializeException} inside this class, so Kryo failures are
     * a different type from what HessianSerializer throws. Consider moving this definition into the
     * project exception class and deleting it from here.
     */
    public static class SerializeException extends RuntimeException {
        public SerializeException(String message) {
            super(message);
        }

        public SerializeException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}
package github.javaguide.serialize.protostuff;
import github.javaguide.serialize.Serializer;
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;
/**
 * {@link Serializer} backed by Protostuff with runtime-generated schemas.
 *
 * @author TangMinXuan
 * @createTime 2020-11-09 20:13
 */
public class ProtostuffSerializer implements Serializer {
    /**
     * Per-thread scratch buffer, so each serialization does not allocate a new one.
     * LinkedBuffer is mutable and not thread safe. A single static buffer shared by all threads
     * let concurrent serialize() calls interleave writes and clear() each other's data.
     */
    private static final ThreadLocal<LinkedBuffer> BUFFER =
            ThreadLocal.withInitial(() -> LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE));

    /**
     * Serializes {@code obj} using the runtime schema of its concrete class.
     *
     * @param obj the object to serialize
     * @return the Protostuff-encoded bytes
     */
    @Override
    public byte[] serialize(Object obj) {
        // The schema is derived from obj's own class, so treating it as Schema<Object> is safe.
        @SuppressWarnings("unchecked")
        Schema<Object> schema = (Schema<Object>) RuntimeSchema.getSchema(obj.getClass());
        LinkedBuffer buffer = BUFFER.get();
        try {
            return ProtostuffIOUtil.toByteArray(obj, schema, buffer);
        } finally {
            // Reset so the next call on this thread starts with an empty buffer.
            buffer.clear();
        }
    }

    /**
     * Deserializes {@code bytes} into a new instance of {@code clazz}.
     *
     * @param bytes Protostuff-encoded bytes
     * @param clazz target type
     * @return the decoded object
     */
    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        Schema<T> schema = RuntimeSchema.getSchema(clazz);
        T obj = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, obj, schema);
        return obj;
    }
}
package github.javaguide.serialize;
import github.javaguide.extension.SPI;
/**
 * Serialization extension point; every serializer implementation must implement this interface.
 * Annotated with {@link SPI}, so implementations are selected by name at runtime.
 *
 * @author shuang.kou
 * @createTime 2020-05-13 19:29:00
 */
@SPI
public interface Serializer {
    /**
     * Serializes an object.
     *
     * @param obj the object to serialize
     * @return the serialized bytes
     */
    byte[] serialize(Object obj);

    /**
     * Deserializes bytes back into an object.
     *
     * @param bytes the serialized bytes
     * @param clazz the target class
     * @param <T>   the type modeled by {@code clazz}. For example, {@code String.class} has type
     *              {@code Class<String>}. If the type is not known, use {@code Class<?>}.
     * @return the deserialized object
     */
    <T> T deserialize(byte[] bytes, Class<T> clazz);
}
此处需要注意的是 Kryo：Kryo 对象不是线程安全的，所以最好为每一个线程单独创建一个序列化器对象（上面的代码通过 ThreadLocal 实现）。
本文使用了Gzip进行了传输压缩
package github.javaguide.compress.gzip;
import github.javaguide.compress.Compress;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
/**
 * {@link Compress} implementation built on the JDK's GZIP streams.
 *
 * @author wangtao .
 * @createTime on 2020/10/3
 */
public class GzipCompress implements Compress {

    /** Size of the scratch buffer used while inflating. */
    private static final int BUFFER_SIZE = 1024 * 4;

    /**
     * GZIP-compresses the given bytes.
     *
     * @param bytes the raw bytes; must not be {@code null}
     * @return the compressed bytes
     * @throws NullPointerException if {@code bytes} is {@code null}
     * @throws RuntimeException     wrapping any I/O error from the GZIP stream
     */
    @Override
    public byte[] compress(byte[] bytes) {
        requireBytes(bytes);
        ByteArrayOutputStream sink = new ByteArrayOutputStream();
        try (GZIPOutputStream deflater = new GZIPOutputStream(sink)) {
            deflater.write(bytes);
            deflater.flush();
            // finish() emits the GZIP trailer so the buffer holds a complete stream before we copy it.
            deflater.finish();
            return sink.toByteArray();
        } catch (IOException ioe) {
            throw new RuntimeException("gzip compress error", ioe);
        }
    }

    /**
     * Inflates GZIP-compressed bytes.
     *
     * @param bytes the compressed bytes; must not be {@code null}
     * @return the decompressed bytes
     * @throws NullPointerException if {@code bytes} is {@code null}
     * @throws RuntimeException     wrapping any I/O error, including malformed GZIP input
     */
    @Override
    public byte[] decompress(byte[] bytes) {
        requireBytes(bytes);
        try (ByteArrayOutputStream sink = new ByteArrayOutputStream();
             GZIPInputStream inflater = new GZIPInputStream(new ByteArrayInputStream(bytes))) {
            byte[] chunk = new byte[BUFFER_SIZE];
            for (int read = inflater.read(chunk); read != -1; read = inflater.read(chunk)) {
                sink.write(chunk, 0, read);
            }
            return sink.toByteArray();
        } catch (IOException ioe) {
            throw new RuntimeException("gzip decompress error", ioe);
        }
    }

    /** Rejects a null input with the same message for both directions. */
    private static void requireBytes(byte[] bytes) {
        if (bytes == null) {
            throw new NullPointerException("bytes is null");
        }
    }
}
package github.javaguide.compress;
import github.javaguide.extension.SPI;
/**
* Compression extension point applied to serialized message bodies. Annotated with
* {@link SPI}; implementations are looked up by name through ExtensionLoader.
*
* @author wangtao .
* @createTime on 2020/10/3
*/
@SPI
public interface Compress {
/**
* Compresses the given bytes.
*
* @param bytes the raw bytes
* @return the compressed bytes
*/
byte[] compress(byte[] bytes);
/**
* Reverses {@link #compress(byte[])}.
*
* @param bytes the compressed bytes
* @return the original bytes
*/
byte[] decompress(byte[] bytes);
}
● 本文采用了netty进行通信
● 同时采用了TCP协议进行客户端和服务端的通信协议
● 为了解决粘包/拆包问题，采用了基于长度字段的解码器（LengthFieldBasedFrameDecoder），而不是固定长度的解码器。
package github.javaguide.remoting.transport.netty.server;
import github.javaguide.remoting.transport.CustomShutdownHook;
import github.javaguide.config.RpcServiceConfig;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.provider.ServiceProvider;
import github.javaguide.provider.impl.ZkServiceProviderImpl;
import github.javaguide.remoting.transport.netty.codec.RpcMessageDecoder;
import github.javaguide.remoting.transport.netty.codec.RpcMessageEncoder;
import github.javaguide.utils.RuntimeUtil;
import github.javaguide.utils.concurrent.threadpool.ThreadPoolFactoryUtil;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;
/**
 * Netty RPC server. Receives client messages, invokes the corresponding method and writes the
 * result back to the client.
 *
 * @author shuang.kou
 * @createTime 2020-05-25 16:42:00
 */
@Slf4j
@Component
public class NettyRpcServer {
    public static final int PORT = 9998;
    private final ServiceProvider serviceProvider = SingletonFactory.getInstance(ZkServiceProviderImpl.class);

    /**
     * Publishes a service through the ZooKeeper-backed {@link ServiceProvider}.
     *
     * @param rpcServiceConfig the service to publish
     */
    public void registerService(RpcServiceConfig rpcServiceConfig) {
        serviceProvider.publishService(rpcServiceConfig);
    }

    /**
     * Binds to this host's address on {@link #PORT} and blocks until the server channel closes.
     * All event-loop groups are shut down before this method returns. If the waiting thread is
     * interrupted, the interrupt flag is restored before returning.
     */
    @SneakyThrows
    public void start() {
        CustomShutdownHook.getCustomShutdownHook().clearAll();
        String host = InetAddress.getLocalHost().getHostAddress();
        // Three thread pools:
        // 1. boss group: accepts incoming connections
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        // 2. worker group: handles read/write events on accepted connections
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        // 3. business group: runs NettyRpcServerHandler off the I/O threads
        DefaultEventExecutorGroup serviceHandlerGroup = new DefaultEventExecutorGroup(
                RuntimeUtil.cpus() * 2,
                ThreadPoolFactoryUtil.createThreadFactory("service-handler-group", false)
        );
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    // TCP enables Nagle's algorithm by default, which batches small writes into larger
                    // segments. TCP_NODELAY=true disables it so small RPC messages go out immediately.
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    // Enable TCP-level keepalive probes.
                    .childOption(ChannelOption.SO_BACKLOG == null ? ChannelOption.SO_KEEPALIVE : ChannelOption.SO_KEEPALIVE, true)
                    // Maximum queue length for connections that finished the three-way handshake but
                    // are not yet accepted; raise it if connections arrive in bursts faster than accepted.
                    .option(ChannelOption.SO_BACKLOG, 128)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    // Invoked for each accepted connection to build its pipeline.
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            ChannelPipeline p = ch.pipeline();
                            // Fire READER_IDLE after 30s without inbound data; NettyRpcServerHandler then closes the connection.
                            p.addLast(new IdleStateHandler(30, 0, 0, TimeUnit.SECONDS));
                            p.addLast(new RpcMessageEncoder());
                            p.addLast(new RpcMessageDecoder());
                            p.addLast(serviceHandlerGroup, new NettyRpcServerHandler());
                        }
                    });
            // Bind the port and wait synchronously for the bind to succeed.
            ChannelFuture f = b.bind(host, PORT).sync();
            // Block until the server's listening channel is closed.
            f.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            log.error("occur exception when start server:", e);
        } finally {
            // Normal shutdown path, not an error.
            log.info("shutdown bossGroup and workerGroup");
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
            serviceHandlerGroup.shutdownGracefully();
        }
    }
}
package github.javaguide.remoting.transport.netty.server;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.RpcResponseCodeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.handler.RpcRequestHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;
/**
 * Server-side ChannelHandler that processes messages sent by the client.
 *
 * <p>This class extends {@link ChannelInboundHandlerAdapter}, so it releases each inbound message
 * itself in a {@code finally} block. A handler extending {@link SimpleChannelInboundHandler} would
 * have the message released automatically instead.
 *
 * @author shuang.kou
 * @createTime 2020-05-25 20:44:00
 */
@Slf4j
public class NettyRpcServerHandler extends ChannelInboundHandlerAdapter {
    private final RpcRequestHandler rpcRequestHandler;

    public NettyRpcServerHandler() {
        this.rpcRequestHandler = SingletonFactory.getInstance(RpcRequestHandler.class);
    }

    /**
     * Answers heartbeat requests with PONG, and RPC requests with the result of invoking the target
     * method through {@link RpcRequestHandler}. Responses are written with Hessian + GZIP, and the
     * channel is closed if the write fails.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        try {
            if (msg instanceof RpcMessage) {
                log.info("server receive msg: [{}] ", msg);
                RpcMessage request = (RpcMessage) msg;
                byte messageType = request.getMessageType();
                RpcMessage rpcMessage = new RpcMessage();
                rpcMessage.setCodec(SerializationTypeEnum.HESSIAN.getCode());
                rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode());
                if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
                    // Heartbeat request: reply with PONG directly.
                    rpcMessage.setMessageType(RpcConstants.HEARTBEAT_RESPONSE_TYPE);
                    rpcMessage.setData(RpcConstants.PONG);
                } else {
                    // RPC call: invoke the target method and return its result.
                    RpcRequest rpcRequest = (RpcRequest) request.getData();
                    Object result = rpcRequestHandler.handle(rpcRequest);
                    // Parameterized logging: result is null for void methods, and result.toString()
                    // would throw a NullPointerException and the client would never get a reply.
                    log.info("server get result: {}", result);
                    rpcMessage.setMessageType(RpcConstants.RESPONSE_TYPE);
                    RpcResponse<Object> rpcResponse;
                    if (ctx.channel().isActive() && ctx.channel().isWritable()) {
                        rpcResponse = RpcResponse.success(result, rpcRequest.getRequestId());
                    } else {
                        rpcResponse = RpcResponse.fail(RpcResponseCodeEnum.FAIL);
                        // Carry the request id so the client can correlate the failure with its call.
                        rpcResponse.setRequestId(rpcRequest.getRequestId());
                        log.error("not writable now, message dropped");
                    }
                    rpcMessage.setData(rpcResponse);
                }
                ctx.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
            }
        } finally {
            // Ensure reference-counted messages are released, otherwise memory may leak.
            ReferenceCountUtil.release(msg);
        }
    }

    /**
     * Closes the connection when the IdleStateHandler reports that no data was read in time.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleState state = ((IdleStateEvent) evt).state();
            if (state == IdleState.READER_IDLE) {
                log.info("idle check happen, so close the connection");
                ctx.close();
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Logs the failure through the logger (instead of printing to stderr) and closes the channel.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.error("server catch exception", cause);
        ctx.close();
    }
}
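上面代码的核心逻辑是：收到心跳请求时直接返回 pong；收到 RPC 请求时，通过 RpcRequestHandler 反射调用目标方法，并把结果封装成 RpcResponse 写回客户端。
另外，为了解决 TCP 的粘包/拆包问题，使用了基于长度字段的编/解码器：消息头中记录了整条消息的长度（full length），解码器据此切分出完整的一帧。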
package github.javaguide.remoting.transport.netty.codec;
import github.javaguide.compress.Compress;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.extension.ExtensionLoader;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.serialize.Serializer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Custom encoder that writes an {@link RpcMessage} in the following wire format.
 *
 * <pre>
 *   0     1     2     3     4        5     6     7     8     9           10      11        12    13    14    15    16
 *   +-----+-----+-----+-----+--------+-----+-----+-----+-----+-----------+-------+---------+-----+-----+-----+-----+
 *   |      magic code       |version |      full length      |messageType| codec |compress |       RequestId       |
 *   +-----------------------+--------+-----------------------+-----------+-------+---------+-----------------------+
 *   |                                                   body                                                       |
 *   |                                                  ... ...                                                     |
 *   +--------------------------------------------------------------------------------------------------------------+
 * </pre>
 * 4B magic code, 1B version, 4B full length (header + body), 1B messageType,
 * 1B codec (serialization type), 1B compress (compression type), 4B requestId,
 * followed by the body (the serialized and compressed payload; absent for heartbeats).
 *
 * <p>The matching decoder is RpcMessageDecoder (a LengthFieldBasedFrameDecoder).
 *
 * @author WangTao
 * @createTime on 2020/10/2
 */
@Slf4j
public class RpcMessageEncoder extends MessageToByteEncoder<RpcMessage> {
    private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);

    /**
     * Encodes one message into {@code out}.
     *
     * <p>The body is built before anything is written, so a serialization or compression failure
     * never leaves a half-written header (with an unset length field) in the outbound buffer.
     * Failures are logged and rethrown, so Netty fails the write promise instead of flushing a
     * corrupt frame to the peer.
     *
     * @throws IllegalStateException if serialization or compression of the body fails
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcMessage rpcMessage, ByteBuf out) {
        try {
            byte messageType = rpcMessage.getMessageType();
            // Heartbeat messages carry no body; fullLength is then just the header length.
            byte[] bodyBytes = isHeartbeat(messageType) ? null : encodeBody(rpcMessage);
            int fullLength = RpcConstants.HEAD_LENGTH + (bodyBytes == null ? 0 : bodyBytes.length);

            out.writeBytes(RpcConstants.MAGIC_NUMBER);
            out.writeByte(RpcConstants.VERSION);
            out.writeInt(fullLength);
            out.writeByte(messageType);
            out.writeByte(rpcMessage.getCodec());
            // Advertise the compression actually applied to the body. Previously this was always
            // GZIP even though the body was compressed with rpcMessage.getCompress().
            out.writeByte(rpcMessage.getCompress());
            out.writeInt(ATOMIC_INTEGER.getAndIncrement());
            if (bodyBytes != null) {
                out.writeBytes(bodyBytes);
            }
        } catch (Exception e) {
            log.error("Encode request error!", e);
            throw new IllegalStateException("Encode request error!", e);
        }
    }

    /** Returns true for heartbeat request/response types, which are sent without a body. */
    private static boolean isHeartbeat(byte messageType) {
        return messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE
                || messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE;
    }

    /** Serializes the payload with the message's codec, then compresses it with its compress type. */
    private static byte[] encodeBody(RpcMessage rpcMessage) {
        String codecName = SerializationTypeEnum.getName(rpcMessage.getCodec());
        log.info("codec name: [{}] ", codecName);
        Serializer serializer = ExtensionLoader.getExtensionLoader(Serializer.class)
                .getExtension(codecName);
        byte[] bodyBytes = serializer.serialize(rpcMessage.getData());
        String compressName = CompressTypeEnum.getName(rpcMessage.getCompress());
        Compress compress = ExtensionLoader.getExtensionLoader(Compress.class)
                .getExtension(compressName);
        return compress.compress(bodyBytes);
    }
}
package github.javaguide.remoting.transport.netty.codec;
import github.javaguide.compress.Compress;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.extension.ExtensionLoader;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.serialize.Serializer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays;
/**
* 用户自定义解码器,为了防止粘包问题
*
* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
* +-----+-----+-----+-----+--------+----+----+----+------+-----------+-------+----- --+-----+-----+-------+
* | magic code |version | full length | messageType| codec|compress| RequestId |
* +-----------------------+--------+---------------------+-----------+-----------+-----------+------------+
* | |
* | body |
* | |
* | ... ... |
* +-------------------------------------------------------------------------------------------------------+
* 4B magic code(魔法数) 1B version(版本) 4B full length(消息长度) 1B messageType(消息类型)
* 1B compress(压缩类型) 1B codec(序列化类型) 4B requestId(请求的Id)
* body(object类型数据)
*
*
* {@link LengthFieldBasedFrameDecoder} is a length-based decoder , used to solve TCP unpacking and sticking problems.
*
*
* @author wangtao
* @createTime on 2020/10/2
* @see LengthFieldBasedFrameDecoder解码器
*/
@Slf4j
public class RpcMessageDecoder extends LengthFieldBasedFrameDecoder {
public RpcMessageDecoder() {
// lengthFieldOffset: magic code is 4B, and version is 1B, and then full length. so value is 5
// lengthFieldLength: full length is 4B. so value is 4
// lengthAdjustment: full length include all data and read 9 bytes before, so the left length is (fullLength-9). so values is -9
// initialBytesToStrip: we will check magic code and version manually, so do not strip any bytes. so values is 0
this(RpcConstants.MAX_FRAME_LENGTH, 5, 4, -9, 0);
}
/**
* @param maxFrameLength Maximum frame length. It decide the maximum length of data that can be received.
* If it exceeds, the data will be discarded.
* @param lengthFieldOffset Length field offset. The length field is the one that skips the specified length of byte.
* @param lengthFieldLength The number of bytes in the length field.
* @param lengthAdjustment The compensation value to add to the value of the length field
* @param initialBytesToStrip Number of bytes skipped.
* If you need to receive all of the header+body data, this value is 0
* if you only want to receive the body data, then you need to skip the number of bytes consumed by the header.
*/
public RpcMessageDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength,
int lengthAdjustment, int initialBytesToStrip) {
super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
}
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
Object decoded = super.decode(ctx, in);
if (decoded instanceof ByteBuf) {
ByteBuf frame = (ByteBuf) decoded;
if (frame.readableBytes() >= RpcConstants.TOTAL_LENGTH) {
try {
return decodeFrame(frame);
} catch (Exception e) {
log.error("Decode frame error!", e);
throw e;
} finally {
frame.release();
}
}
}
return decoded;
}
private Object decodeFrame(ByteBuf in) {
// note: must read ByteBuf in order
checkMagicNumber(in);
checkVersion(in);
int fullLength = in.readInt();
// build RpcMessage object
byte messageType = in.readByte();
byte codecType = in.readByte();
byte compressType = in.readByte();
int requestId = in.readInt();
RpcMessage rpcMessage = RpcMessage.builder()
.codec(codecType)
.requestId(requestId)
.messageType(messageType).build();
if (messageType == RpcConstants.HEARTBEAT_REQUEST_TYPE) {
rpcMessage.setData(RpcConstants.PING);
return rpcMessage;
}
if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
rpcMessage.setData(RpcConstants.PONG);
return rpcMessage;
}
int bodyLength = fullLength - RpcConstants.HEAD_LENGTH;
if (bodyLength > 0) {
byte[] bs = new byte[bodyLength];
in.readBytes(bs);
// decompress the bytes
String compressName = CompressTypeEnum.getName(compressType);
Compress compress = ExtensionLoader.getExtensionLoader(Compress.class)
.getExtension(compressName);
bs = compress.decompress(bs);
// deserialize the object
String codecName = SerializationTypeEnum.getName(rpcMessage.getCodec());
log.info("codec name: [{}] ", codecName);
Serializer serializer = ExtensionLoader.getExtensionLoader(Serializer.class)
.getExtension(codecName);
if (messageType == RpcConstants.REQUEST_TYPE) {
RpcRequest tmpValue = serializer.deserialize(bs, RpcRequest.class);
rpcMessage.setData(tmpValue);
} else {
RpcResponse tmpValue = serializer.deserialize(bs, RpcResponse.class);
rpcMessage.setData(tmpValue);
}
}
return rpcMessage;
}
private void checkVersion(ByteBuf in) {
// read the version and compare
byte version = in.readByte();
if (version != RpcConstants.VERSION) {
throw new RuntimeException("version isn't compatible" + version);
}
}
private void checkMagicNumber(ByteBuf in) {
// read the first 4 bit, which is the magic number, and compare
int len = RpcConstants.MAGIC_NUMBER.length;
byte[] tmp = new byte[len];
in.readBytes(tmp);
for (int i = 0; i < len; i++) {
if (tmp[i] != RpcConstants.MAGIC_NUMBER[i]) {
throw new IllegalArgumentException("Unknown magic code: " + Arrays.toString(tmp));
}
}
}
}
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.enums.ServiceDiscoveryEnum;
import github.javaguide.extension.ExtensionLoader;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.registry.ServiceDiscovery;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcRequest;
import github.javaguide.remoting.dto.RpcResponse;
import github.javaguide.remoting.transport.RpcRequestTransport;
import github.javaguide.remoting.transport.netty.codec.RpcMessageDecoder;
import github.javaguide.remoting.transport.netty.codec.RpcMessageEncoder;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.IdleStateHandler;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
/**
 * Netty-based RPC transport: initializes and closes the {@link Bootstrap} and sends requests.
 *
 * @author shuang.kou
 * @createTime 2020-05-29 17:51:00
 */
@Slf4j
public final class NettyRpcClient implements RpcRequestTransport {
    private final ServiceDiscovery serviceDiscovery;
    /**
     * Pending requests; each waits on a CompletableFuture for its response.
     */
    private final UnprocessedRequests unprocessedRequests;
    /**
     * Channel cache, so a new TCP connection is not opened for every request.
     */
    private final ChannelProvider channelProvider;
    private final Bootstrap bootstrap;
    private final EventLoopGroup eventLoopGroup;

    public NettyRpcClient() {
        // Initialize resources such as EventLoopGroup and Bootstrap.
        eventLoopGroup = new NioEventLoopGroup();
        bootstrap = new Bootstrap();
        // Bootstrap keeps a single handler, so a LoggingHandler set alongside this
        // ChannelInitializer would be overwritten; none is configured here.
        bootstrap.group(eventLoopGroup)
                .channel(NioSocketChannel.class)
                // Connection timeout: if the connection cannot be established within this time, it fails.
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline p = ch.pipeline();
                        // Fire WRITER_IDLE if nothing is written to the server within 5 seconds
                        // (the client handler is expected to send a heartbeat then).
                        p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));
                        p.addLast(new RpcMessageEncoder());
                        p.addLast(new RpcMessageDecoder());
                        p.addLast(new NettyRpcClientHandler());
                    }
                });
        this.serviceDiscovery = ExtensionLoader.getExtensionLoader(ServiceDiscovery.class).getExtension(ServiceDiscoveryEnum.ZK.getName());
        this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
        this.channelProvider = SingletonFactory.getInstance(ChannelProvider.class);
    }

    /**
     * Connects to the server and returns the channel, so RPC messages can be sent over it.
     *
     * @param inetSocketAddress server address
     * @return the connected channel
     * @throws java.util.concurrent.ExecutionException (sneaky-thrown) wrapping an
     *         IllegalStateException whose cause is the connection failure
     */
    @SneakyThrows
    public Channel doConnect(InetSocketAddress inetSocketAddress) {
        CompletableFuture<Channel> completableFuture = new CompletableFuture<>();
        bootstrap.connect(inetSocketAddress).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.info("The client has connected [{}] successful!", inetSocketAddress.toString());
                completableFuture.complete(future.channel());
            } else {
                // Fail the future rather than throwing inside the listener: an exception thrown here
                // is only logged by Netty, and completableFuture.get() below would block forever.
                completableFuture.completeExceptionally(
                        new IllegalStateException("Failed to connect to " + inetSocketAddress, future.cause()));
            }
        });
        return completableFuture.get();
    }

    /**
     * Sends a request and blocks until its response arrives.
     *
     * @param rpcRequest the request to send
     * @return the response future's result
     * @throws IllegalStateException if the channel to the discovered server is not active
     * @throws RuntimeException      if waiting fails; the original exception is kept as the cause
     */
    @Override
    public Object sendRpcRequest(RpcRequest rpcRequest) {
        CompletableFuture<RpcResponse<Object>> resultFuture = new CompletableFuture<>();
        // 1. Look up the service address.
        InetSocketAddress inetSocketAddress = serviceDiscovery.lookupService(rpcRequest);
        // 2. Get (or open) a channel to it.
        Channel channel = getChannel(inetSocketAddress);
        if (!channel.isActive()) {
            throw new IllegalStateException("Channel to " + inetSocketAddress + " is not active");
        }
        // 3. Register the pending request, then send it.
        unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
        RpcMessage rpcMessage = RpcMessage.builder().data(rpcRequest)
                .codec(SerializationTypeEnum.HESSIAN.getCode())
                .compress(CompressTypeEnum.GZIP.getCode())
                .messageType(RpcConstants.REQUEST_TYPE).build();
        channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.info("client send message: [{}]", rpcMessage);
            } else {
                future.channel().close();
                resultFuture.completeExceptionally(future.cause());
                log.error("Send failed:", future.cause());
            }
        });
        // 4. Wait for the response.
        try {
            return resultFuture.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the caller can observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("rpc请求失败," + e.getMessage(), e);
        } catch (ExecutionException e) {
            throw new RuntimeException("rpc请求失败," + e.getMessage(), e);
        }
    }

    /**
     * Returns a cached active channel for the address, connecting and caching a new one if needed.
     *
     * @param inetSocketAddress server address
     * @return a channel to the server
     */
    public Channel getChannel(InetSocketAddress inetSocketAddress) {
        Channel channel = channelProvider.get(inetSocketAddress);
        if (channel == null) {
            channel = doConnect(inetSocketAddress);
            channelProvider.set(inetSocketAddress, channel);
        }
        return channel;
    }

    /** Shuts down the client's event loop group. */
    public void close() {
        eventLoopGroup.shutdownGracefully();
    }
}
核心逻辑就是 sendRpcRequest 方法，它实现了 RpcRequestTransport 接口。
下面是三个辅助类：
ChannelProvider：缓存通道
package github.javaguide.remoting.transport.netty.client;
import io.netty.channel.Channel;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Stores and returns cached {@link Channel} objects, keyed by server address.
 *
 * @author shuang.kou
 * @createTime 2020-05-29 16:36:00
 */
@Slf4j
public class ChannelProvider {
    private final Map<String, Channel> channelMap;

    public ChannelProvider() {
        channelMap = new ConcurrentHashMap<>();
    }

    /**
     * Returns the cached channel for the address if it is still active.
     * A cached but inactive channel is evicted.
     *
     * @param inetSocketAddress server address
     * @return the active cached channel, or {@code null} if none is available
     */
    public Channel get(InetSocketAddress inetSocketAddress) {
        String key = inetSocketAddress.toString();
        Channel channel = channelMap.get(key);
        if (channel == null) {
            return null;
        }
        if (channel.isActive()) {
            return channel;
        }
        // Conditional remove: evict only the stale channel just inspected. An unconditional
        // remove(key) could discard a fresh channel that another thread stored in the meantime.
        channelMap.remove(key, channel);
        return null;
    }

    /**
     * Caches the channel for the address, replacing any previous entry.
     *
     * @param inetSocketAddress server address
     * @param channel           the channel to cache
     */
    public void set(InetSocketAddress inetSocketAddress, Channel channel) {
        String key = inetSocketAddress.toString();
        channelMap.put(key, channel);
    }

    /**
     * Removes any cached channel for the address.
     *
     * @param inetSocketAddress server address
     */
    public void remove(InetSocketAddress inetSocketAddress) {
        String key = inetSocketAddress.toString();
        channelMap.remove(key);
        log.info("Channel map size :[{}]", channelMap.size());
    }
}
NettyRpcClientHandler:客户端处理器,负责与服务端通信,解析服务端发来的消息并得到对应的 RpcResponse
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.enums.CompressTypeEnum;
import github.javaguide.enums.SerializationTypeEnum;
import github.javaguide.factory.SingletonFactory;
import github.javaguide.remoting.constants.RpcConstants;
import github.javaguide.remoting.dto.RpcMessage;
import github.javaguide.remoting.dto.RpcResponse;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;
import java.net.InetSocketAddress;
/**
 * Client-side inbound handler that processes the messages sent by the server.
 *
 * <p>It does three things:
 * <ul>
 *   <li>logs heartbeat responses;</li>
 *   <li>completes the pending request future when an RPC response arrives;</li>
 *   <li>sends a heartbeat (PING) when the channel becomes writer-idle.</li>
 * </ul>
 *
 * <p>A handler extending {@link SimpleChannelInboundHandler} would not need to release inbound
 * messages itself, because its {@code channelRead} releases them automatically and so avoids
 * possible memory leaks. This class extends {@link ChannelInboundHandlerAdapter} instead, so it
 * releases the message explicitly in {@link #channelRead}. See the book
 * "Netty进阶之路 跟着案例学 Netty" for details.
 *
 * @author shuang.kou
 * @createTime 2020-05-25 20:50:00
 */
@Slf4j
public class NettyRpcClientHandler extends ChannelInboundHandlerAdapter {
    private final UnprocessedRequests unprocessedRequests;
    private final NettyRpcClient nettyRpcClient;

    public NettyRpcClientHandler() {
        this.unprocessedRequests = SingletonFactory.getInstance(UnprocessedRequests.class);
        this.nettyRpcClient = SingletonFactory.getInstance(NettyRpcClient.class);
    }

    /**
     * Reads a message transmitted by the server.
     *
     * <p>Heartbeat responses are only logged. RPC responses complete the matching pending
     * request in {@link UnprocessedRequests}. Other inputs are ignored.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        try {
            log.info("client receive msg: [{}]", msg);
            if (msg instanceof RpcMessage) {
                RpcMessage tmp = (RpcMessage) msg;
                byte messageType = tmp.getMessageType();
                if (messageType == RpcConstants.HEARTBEAT_RESPONSE_TYPE) {
                    log.info("heart [{}]", tmp.getData());
                } else if (messageType == RpcConstants.RESPONSE_TYPE) {
                    // getData() is typed as Object, so this generic cast cannot be checked at runtime.
                    @SuppressWarnings("unchecked")
                    RpcResponse<Object> rpcResponse = (RpcResponse<Object>) tmp.getData();
                    unprocessedRequests.complete(rpcResponse);
                }
            }
        } finally {
            // This handler is not a SimpleChannelInboundHandler, so release the message here.
            ReferenceCountUtil.release(msg);
        }
    }

    /**
     * Sends a PING heartbeat when the channel reports a writer-idle event; delegates all
     * other events to the superclass.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleState state = ((IdleStateEvent) evt).state();
            if (state == IdleState.WRITER_IDLE) {
                log.info("write idle happen [{}]", ctx.channel().remoteAddress());
                Channel channel = nettyRpcClient.getChannel((InetSocketAddress) ctx.channel().remoteAddress());
                RpcMessage rpcMessage = new RpcMessage();
                rpcMessage.setCodec(SerializationTypeEnum.PROTOSTUFF.getCode());
                rpcMessage.setCompress(CompressTypeEnum.GZIP.getCode());
                rpcMessage.setMessageType(RpcConstants.HEARTBEAT_REQUEST_TYPE);
                rpcMessage.setData(RpcConstants.PING);
                // Close the channel if the heartbeat cannot be written.
                channel.writeAndFlush(rpcMessage).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
            }
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    /**
     * Called when an exception occurs while processing a client message. Logs the error
     * and closes the channel.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        // The logger already records the full stack trace; printStackTrace() would duplicate it on stderr.
        log.error("client catch exception:", cause);
        ctx.close();
    }
}
UnprocessedRequests:保存尚未收到响应的请求,使用 CompletableFuture 等待结果,提高代码的可读性
package github.javaguide.remoting.transport.netty.client;
import github.javaguide.remoting.dto.RpcResponse;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Registry of requests that have been sent but whose responses have not yet been processed.
 *
 * <p>Each pending request is represented by a {@link CompletableFuture} keyed by its request id.
 * The future is completed when the matching {@link RpcResponse} arrives.
 *
 * @author shuang.kou
 * @createTime 2020-06-04 17:30:00
 */
public class UnprocessedRequests {
    /** Pending futures keyed by request id. Static, so shared by every instance of this class. */
    private static final Map<String, CompletableFuture<RpcResponse<Object>>> UNPROCESSED_RESPONSE_FUTURES = new ConcurrentHashMap<>();

    /**
     * Registers a future to be completed when the response for {@code requestId} arrives.
     *
     * @param requestId id of the request that was sent
     * @param future    future to complete with the response
     */
    public void put(String requestId, CompletableFuture<RpcResponse<Object>> future) {
        UNPROCESSED_RESPONSE_FUTURES.put(requestId, future);
    }

    /**
     * Removes the pending future whose id matches the response's request id and completes it
     * with {@code rpcResponse}.
     *
     * @param rpcResponse response received from the server
     * @throws IllegalStateException if no request with that id is pending
     */
    public void complete(RpcResponse<Object> rpcResponse) {
        CompletableFuture<RpcResponse<Object>> future = UNPROCESSED_RESPONSE_FUTURES.remove(rpcResponse.getRequestId());
        if (future == null) {
            // Include the id so unexpected, duplicate or late responses can be diagnosed.
            throw new IllegalStateException("No pending request for requestId: " + rpcResponse.getRequestId());
        }
        future.complete(rpcResponse);
    }
}
RpcRequestTransport:RPC 远程请求的传输接口
RpcRequestTransport 接口可以根据不同的客户端、服务端实现进行扩展,比如使用 Socket 而不是 Netty 发送消息。
@SPI
pub
lic interface RpcRequestTransport {
/**
* send rpc request to server and get result
*
* @param rpcRequest message body
* @return data from server
*/
Object sendRpcRequest(RpcRequest rpcRequest) throws Exception;
}