HTTP2: netty server端同一个端口支持 http1.1/http2

要在同一端口同时支持 HTTP/1.1 和 HTTP/2,比较推荐的做法是：明文(h2c)场景下通过 HttpServerUpgradeHandler 将 HTTP/1.1 升级到 HTTP/2,TLS 场景下通过 ALPN 协商协议,网上相关资料很多。本文采用的是另一种方式。

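具体做法是：在对接收到的请求字节进行解码时,先判断 client 使用的是 HTTP/1.1 还是 HTTP/2,再将相应的 ChannelHandler 添加到 ChannelPipeline 中。判断依据是:HTTP/2 客户端建立连接后首先发送固定的 24 字节连接前言(connection preface)`PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n`;HTTP/1.1 请求的请求行中则包含 `HTTP/1.1` 版本标识。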
netty 建立http2 server的代码见上一篇文章:HTTP2: netty http2 server demo

新增2个接口和4个实现类。

| 接口 | HTTP/1 实现 | HTTP/2 实现 | 说明 |
|------|-------------|-------------|------|
| HttpProtocolDetector | Http1ProtocolDetector2 | Http2ProtocolDetector | HTTP 协议探测器 |
| HttpServerConfigurator | Http1ServerConfigurator | Http2ServerConfigurator | HTTP 服务器配置器 |

相应代码如下:
HttpProtocolDetector:

import io.netty.buffer.ByteBuf;

public interface HttpProtocolDetector {

    /**
     * Inspects the readable bytes of {@code in} and reports whether they belong to the protocol
     * handled by this detector. Implementations should not consume bytes from {@code in}.
     *
     * @param in the bytes received so far on the connection
     * @return the detection verdict
     */
    Result detect(ByteBuf in);

    /** Outcome of {@link #detect(ByteBuf)}. */
    enum Result {
        /** The input matches this detector's protocol. */
        RECOGNIZED,
        /** The input can no longer match this detector's protocol. */
        UNRECOGNIZED,
        /** Not enough bytes have arrived yet to decide. */
        NEED_MORE_DATA
    }

    /**
     * Prefix match: compares the first {@code count} readable bytes of both buffers.
     * Neither buffer's reader index is modified.
     *
     * @param bufferA first buffer
     * @param bufferB second buffer
     * @param count   number of leading readable bytes to compare
     * @return {@code true} if both buffers have at least {@code count} readable bytes and those
     *         bytes are equal; {@code false} otherwise
     */
    static boolean prefixEquals(ByteBuf bufferA, ByteBuf bufferB, int count) {
        final int aLen = bufferA.readableBytes();
        final int bLen = bufferB.readableBytes();
        if (aLen < count || bLen < count) {
            return false;
        }

        int aIndex = bufferA.readerIndex();
        int bIndex = bufferB.readerIndex();

        for (int i = count; i > 0; i--) {
            if (bufferA.getByte(aIndex) != bufferB.getByte(bIndex)) {
                return false;
            }
            aIndex++;
            bIndex++;
        }

        return true;
    }


    /**
     * Keyword match: reports whether the readable bytes of {@code key} occur, in full, within the
     * first {@code count} readable bytes of {@code bufferA}. Neither buffer's reader index is modified.
     *
     * @param bufferA buffer to search in
     * @param key     bytes to search for (its readable region)
     * @param count   how many leading readable bytes of {@code bufferA} to search; clamped to the
     *                number of readable bytes actually available
     * @return {@code true} if {@code key} is found entirely inside the searched window (an empty
     *         key always matches); {@code false} otherwise
     */
    static boolean contains(ByteBuf bufferA, ByteBuf key, int count) {
        final int keyLen = key.readableBytes();
        if (keyLen == 0) {
            return true;
        }
        final int limit = Math.min(count, bufferA.readableBytes());
        if (limit < keyLen) {
            return false;
        }

        final int aStart = bufferA.readerIndex();
        final int keyStart = key.readerIndex();

        // Try every start offset at which the whole key still fits in the window. A mismatch
        // restarts the comparison at the next offset, so partial overlaps such as "HHTTP/1.1"
        // are still found.
        for (int offset = 0; offset <= limit - keyLen; offset++) {
            int matched = 0;
            while (matched < keyLen
                    && bufferA.getByte(aStart + offset + matched) == key.getByte(keyStart + matched)) {
                matched++;
            }
            if (matched == keyLen) {
                return true;
            }
        }
        return false;
    }

}

http1 协议探测器:

import io.netty.buffer.ByteBuf;

import static io.netty.buffer.Unpooled.directBuffer;
import static io.netty.buffer.Unpooled.unreleasableBuffer;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.lang.Math.min;

public class Http1ProtocolDetector2 implements HttpProtocolDetector {

    /** Protocol-version token searched for in the request line, e.g. {@code GET / HTTP/1.1}. */
    private static final byte[] HTTP1_VERSION_BYTES = "HTTP/1.1".getBytes(UTF_8);

    /** Shared read-only buffer wrapping {@link #HTTP1_VERSION_BYTES}, sized exactly to the token. */
    private static final ByteBuf HTTP1_VERSION_TOKEN =
            unreleasableBuffer(directBuffer(HTTP1_VERSION_BYTES.length).writeBytes(HTTP1_VERSION_BYTES))
                    .asReadOnly();

    /** Maximum number of leading bytes to scan before giving up on HTTP/1.1. */
    private final int tryMaxLength;

    /**
     * @param tryMaxLength how many leading bytes of a connection to scan for the
     *                     {@code HTTP/1.1} token; must be at least the token length
     * @throws IllegalArgumentException if {@code tryMaxLength} is too small to ever contain the token
     */
    public Http1ProtocolDetector2(int tryMaxLength) {
        if (tryMaxLength < HTTP1_VERSION_BYTES.length) {
            throw new IllegalArgumentException(
                    "tryMaxLength must be at least " + HTTP1_VERSION_BYTES.length + ": " + tryMaxLength);
        }
        this.tryMaxLength = tryMaxLength;
    }

    /**
     * Recognizes HTTP/1.1 when the {@code HTTP/1.1} token appears within the first
     * {@code tryMaxLength} readable bytes. Does not move the reader index of {@code in}.
     */
    @Override
    public Result detect(ByteBuf in) {
        int bytesToScan = min(in.readableBytes(), tryMaxLength);

        if (bytesToScan > 0 && HttpProtocolDetector.contains(in, HTTP1_VERSION_TOKEN, bytesToScan)) {
            return Result.RECOGNIZED;
        }
        // Token not found yet: keep waiting until tryMaxLength bytes have been inspected, then give up.
        return bytesToScan < tryMaxLength ? Result.NEED_MORE_DATA : Result.UNRECOGNIZED;
    }

}

http2协议探测器:

import io.netty.buffer.ByteBuf;

import static io.netty.buffer.Unpooled.directBuffer;
import static io.netty.buffer.Unpooled.unreleasableBuffer;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.lang.Math.min;

public class Http2ProtocolDetector implements HttpProtocolDetector {

    /** The 24-byte HTTP/2 client connection preface (RFC 7540, section 3.5). */
    private static final ByteBuf CONNECTION_PREFACE =
            unreleasableBuffer(directBuffer(24).writeBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".getBytes(UTF_8)))
                    .asReadOnly();

    /**
     * Compares the leading readable bytes of {@code in} with the connection preface.
     * Does not move the reader index of {@code in}.
     *
     * @return RECOGNIZED once the full preface has matched, NEED_MORE_DATA while the bytes seen
     *         so far are a (possibly empty) prefix of it, UNRECOGNIZED on the first mismatch
     */
    @Override
    public Result detect(ByteBuf in) {
        int prefaceLen = CONNECTION_PREFACE.readableBytes();
        int bytesToCompare = min(in.readableBytes(), prefaceLen);

        if (bytesToCompare == 0) {
            // Nothing received yet: HTTP/2 can be neither confirmed nor ruled out.
            return Result.NEED_MORE_DATA;
        }
        if (!HttpProtocolDetector.prefixEquals(in, CONNECTION_PREFACE, bytesToCompare)) {
            return Result.UNRECOGNIZED;
        }
        return bytesToCompare == prefaceLen ? Result.RECOGNIZED : Result.NEED_MORE_DATA;
    }
}

HttpServerConfigurator:

import io.netty.channel.Channel;

public interface HttpServerConfigurator {

    /**
     * Returns the detector that decides whether a new connection speaks the protocol served by
     * this configurator.
     *
     * @return the protocol detector
     */
    HttpProtocolDetector protocolDetector();

    /**
     * Installs the protocol-specific handlers on the pipeline of the given channel.
     *
     * @param ch the channel whose pipeline is configured
     * @throws Exception if the handlers cannot be set up
     */
    void configServer(Channel ch) throws Exception;

}

http1配置器:

import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.*;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

@Slf4j
public class Http1ServerConfigurator implements HttpServerConfigurator {

    /** Maximum aggregated request size accepted by {@link HttpObjectAggregator} (512 MiB). */
    private static final int MAX_CONTENT_LENGTH = 512 * 1024 * 1024;

    /** Scans up to the first 512 bytes of a connection for the HTTP/1.1 version token. */
    private final HttpProtocolDetector httpProtocolDetector = new Http1ProtocolDetector2(512);

    /**
     * Installs the HTTP/1.1 pipeline: debug logging, the server codec, request aggregation, and a
     * handler that answers every request with a plain-text body and then closes the connection.
     */
    @Override
    public void configServer(Channel ch) {
        ch.pipeline()
                .addLast(new LoggingHandler(HttpServer.class, LogLevel.DEBUG))
                .addLast(new HttpServerCodec())
                .addLast(new HttpObjectAggregator(MAX_CONTENT_LENGTH))
                .addLast(new ChannelInboundHandlerAdapter() {
                    @Override
                    public void channelReadComplete(ChannelHandlerContext ctx) {
                        ctx.flush();
                    }

                    @Override
                    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                        log.error("netty channel caught exception", cause);
                        ctx.close();
                    }

                    @Override
                    public void channelRead(ChannelHandlerContext ctx, Object msg) {
                        // This is the last handler, so it owns every message it receives and must
                        // release it on every path, including the "not a request" one.
                        try {
                            if (!(msg instanceof FullHttpRequest)) {
                                log.warn("msg is not http request, msg:{}", msg);
                                return;
                            }
                            FullHttpRequest httpRequest = (FullHttpRequest) msg;
                            log.info("access request:{}", httpRequest.uri());

                            byte[] body = String.format("response from %s", httpRequest.uri())
                                    .getBytes(StandardCharsets.UTF_8);
                            FullHttpResponse httpResponse = new DefaultFullHttpResponse(
                                    HttpVersion.HTTP_1_1,
                                    HttpResponseStatus.OK,
                                    ctx.alloc().buffer(body.length).writeBytes(body));
                            // Declare the body framing explicitly and tell the client the connection
                            // will be closed after this response.
                            httpResponse.headers()
                                    .set(HttpHeaderNames.CONTENT_TYPE, "text/plain; charset=UTF-8")
                                    .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
                                    .setInt(HttpHeaderNames.CONTENT_LENGTH, body.length);
                            ctx.writeAndFlush(httpResponse)
                                    .addListener(ChannelFutureListener.CLOSE);
                        } finally {
                            ReferenceCountUtil.release(msg);
                        }
                    }
                });

    }

    @Override
    public HttpProtocolDetector protocolDetector() {
        return httpProtocolDetector;
    }
}

http2配置器:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http2.*;
import io.netty.handler.logging.LogLevel;
import io.netty.util.CharsetUtil;

import javax.net.ssl.SSLException;
import java.security.cert.CertificateException;

public class Http2ServerConfigurator implements HttpServerConfigurator {

    private static final int DEFAULT_HEADER_TABLE_SIZE = 4096;
    private static final int MIB_8 = 1 << 23;
    /** 16 KiB, the initial SETTINGS_MAX_FRAME_SIZE value defined by HTTP/2. */
    private static final int DEFAULT_MAX_FRAME_SIZE = 1 << 14;
    private static final int DEFAULT_WINDOW_INIT_SIZE = MIB_8;
    private static final int KIB_32 = 1 << 15;
    private static final int DEFAULT_MAX_HEADER_LIST_SIZE = KIB_32;

    public static final Http2FrameLogger SERVER_LOGGER = new Http2FrameLogger(LogLevel.DEBUG, "H2_SERVER");

    /** Response body shared by all streams; copied per response via {@code duplicate()}. */
    static final ByteBuf RESPONSE_BYTES = Unpooled.unreleasableBuffer(Unpooled.copiedBuffer("Hello World", CharsetUtil.UTF_8));

    private final HttpProtocolDetector http2ProtocolDetector = new Http2ProtocolDetector();


    /**
     * Installs an {@link Http2FrameCodec} plus an {@link Http2MultiplexHandler} that gives each
     * HTTP/2 stream its own child channel handled by {@link CustHttp2Handler}.
     */
    @Override
    public void configServer(Channel ch) throws SSLException, CertificateException {
        final ChannelPipeline p = ch.pipeline();
        final Http2FrameCodec codec = Http2FrameCodecBuilder.forServer()
                .gracefulShutdownTimeoutMillis(10000)
                .initialSettings(new Http2Settings().headerTableSize(DEFAULT_HEADER_TABLE_SIZE)
                        .maxConcurrentStreams(Integer.MAX_VALUE)
                        .initialWindowSize(DEFAULT_WINDOW_INIT_SIZE)
                        .maxFrameSize(DEFAULT_MAX_FRAME_SIZE)
                        .maxHeaderListSize(DEFAULT_MAX_HEADER_LIST_SIZE * 2))
                .frameLogger(SERVER_LOGGER)
                .build();
        final Http2MultiplexHandler handler = new Http2MultiplexHandler(
                new ChannelInitializer<Channel>() {
                    @Override
                    protected void initChannel(Channel streamChannel) throws Exception {
                        streamChannel.pipeline().addLast(new CustHttp2Handler());
                    }
                });

        p.addLast(codec);
        p.addLast(handler);
    }

    @Override
    public HttpProtocolDetector protocolDetector() {
        return http2ProtocolDetector;
    }


    /**
     * Per-stream handler: once the request stream ends (END_STREAM on a HEADERS or DATA frame),
     * answers with a 200 HEADERS frame followed by a final DATA frame.
     */
    private static class CustHttp2Handler extends ChannelDuplexHandler {


        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof Http2HeadersFrame) {
                Http2HeadersFrame msgHeader = (Http2HeadersFrame) msg;
                if (msgHeader.isEndStream()) {
                    System.out.println("-hhhhh");
                    writeData(ctx, msgHeader.stream());
                } else {
                    System.out.println("hhhhh");
                }
            } else if (msg instanceof Http2DataFrame) {
                Http2DataFrame msgData = (Http2DataFrame) msg;
                try {
                    if (msgData.isEndStream()) {
                        System.out.println("-ddddd");
                        writeData(ctx, msgData.stream());
                    } else {
                        System.out.println("ddddd");
                    }
                } finally {
                    // DATA frames hold a reference-counted buffer; this handler consumes the frame
                    // instead of passing it on, so it must release it.
                    msgData.release();
                }
            } else {
                super.channelRead(ctx, msg);
            }
        }

        @Override
        public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
            // writeData() only writes; flush once per read batch so responses are sent promptly.
            ctx.flush();
            super.channelReadComplete(ctx);
        }

        private static void writeData(ChannelHandlerContext ctx, Http2FrameStream stream) {
            ByteBuf content = ctx.alloc().buffer(RESPONSE_BYTES.readableBytes());
            content.writeBytes(RESPONSE_BYTES.duplicate());
            Http2Headers headers = new DefaultHttp2Headers().status(HttpResponseStatus.OK.codeAsText())
                    .add("t1", "tttt")
                    .add("t2", "tttt");
            ctx.write(
                    new DefaultHttp2HeadersFrame(headers)
                            .stream(stream)
            );
            // endStream = true: this DATA frame completes the response.
            ctx.write(
                    new DefaultHttp2DataFrame(content, true)
                            .stream(stream)
            );
        }
    }
}

HttpServer 也需要做关键调整(代码中的 `...` 表示与上一篇文章相同、此处省略的部分):

public class HttpServer {
	...
	public void init() throws InterruptedException, CertificateException, SSLException {
        // Initialize SSL; sslCtx may stay null, in which case no SslHandler is installed below.
        initSsl();

        // Initialize the ServerBootstrap
        initServerBootstrap();

        bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
                if (sslCtx != null) {
                    ch.pipeline().addLast(sslCtx.newHandler(ch.alloc()));
                }

                // One-shot protocol sniffer: asks each configurator's detector about the bytes received
                // so far, installs the first matching configurator's handlers and removes itself.
                ch.pipeline().addLast(new ByteToMessageDecoder() {
                    @Override
                    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
                        boolean mayStillMatch = false;
                        for (HttpServerConfigurator configurator : httpServerConfigurators) {
                            in.markReaderIndex();
                            HttpProtocolDetector.Result result = configurator.protocolDetector().detect(in);
                            in.resetReaderIndex();
                            if (result == HttpProtocolDetector.Result.RECOGNIZED) {
                                configurator.configServer(ctx.channel());
                                // The bytes were not consumed; ByteToMessageDecoder forwards any unread
                                // cumulated bytes to the newly added handlers when it is removed.
                                ctx.pipeline().remove(this);
                                return;
                            }
                            if (result == HttpProtocolDetector.Result.NEED_MORE_DATA) {
                                mayStillMatch = true;
                            }
                        }
                        if (!mayStillMatch) {
                            // No configurator can ever accept this connection: drop the buffered bytes
                            // and close, instead of accumulating input forever.
                            in.skipBytes(in.readableBytes());
                            ctx.close();
                        }
                    }
                });
            }
        });
        // bind
        doBind();
    }

    /**
     * Binds the server to localhost:8080 and blocks until the server channel is closed, then
     * shuts down both event loop groups.
     */
    private void doBind() {
        String bindIp = "localhost";
        int bindPort = 8080;
        try {
            InetSocketAddress bindAddress = new InetSocketAddress(bindIp, bindPort);
            // sync() rethrows a bind failure, so reaching the next line means the bind succeeded.
            ChannelFuture channelFuture = bootstrap.bind(bindAddress).sync();
            log.info("http server start at port {}", bindPort);
            channel = channelFuture.channel();
            channel.closeFuture().sync();
            log.info("http server shutdown");
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            log.error("http server interrupted,", e);
        } catch (Exception e) {
            log.error("http server start exception,", e);
        } finally {
            log.info("http server shutdown bossEventLoopGroup&workerEventLoopGroup gracefully");
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
	...
}

参考

HTTP/2 in Netty
netty系列之:搭建客户端使用http1.1的方式连接http2服务器
apn服务器源码,使用Netty实现HTTP2服务器/客户端的源码和教程 - Baeldung
netty系列之:使用netty实现支持http2的服务器

你可能感兴趣的:(http2,netty)