基于netty进行对象传输(操作数据库)

基于netty进行对象传输

本文使用 Netty 搭建了一个 C/S 项目，用于远程访问数据库：客户端发送请求，服务端接收请求后执行相应的数据库操作，并将结果返回给客户端。请求与响应均以 Java 对象形式传输，编解码使用 Netty 自带的 ObjectEncoder / ObjectDecoder（基于 Java 序列化）。

项目结构

（此处为项目结构截图）

客户端代码

package com.syx.client;

import com.syx.client.handler.ClientHandler;
import com.syx.entity.SendMsg;
import com.syx.entity.User;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.LinkedHashMap;

public class HrSystemClient {

    // Address of the HrSystemServer to connect to.
    private final String host;
    private final int port;

    public HrSystemClient(String host, int port) {
        this.host = host;
        this.port = port;
    }


    /**
     * Connects to the server, runs the interactive console loop, and returns once the user
     * quits (or stdin reaches end-of-stream) and the connection has been closed.
     *
     * @throws InterruptedException if interrupted while connecting or closing
     */
    public void run() throws InterruptedException {
        NioEventLoopGroup group = new NioEventLoopGroup();

        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            // Requests/responses travel as Java-serialized objects.
                            pipeline.addLast(new ObjectEncoder())
                                    .addLast(new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)))
                                    .addLast(new ClientHandler());
                        }
                    });

            ChannelFuture future = bootstrap.connect(host, port).sync();
            try {
                Input(future);
            } catch (IOException e) {
                e.printStackTrace();
            }
            // The console loop has ended (user chose "5" or stdin was exhausted). Close the
            // connection ourselves; otherwise waiting on closeFuture() would block until the
            // server happened to hang up.
            future.channel().close().sync();
        } finally {
            group.shutdownGracefully();
        }
    }


    /**
     * Reads commands from stdin and sends the resulting requests over the connected channel.
     * Returns when the user enters 5 or stdin reaches end-of-stream.
     *
     * @param future the completed connect future whose channel requests are written to
     * @throws IOException if reading stdin fails
     */
    private void Input(ChannelFuture future) throws IOException {
        // Deliberately not closed: closing this reader would close System.in for the whole JVM.
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in, "UTF-8"));
        // Show the help text on the console.
        help();
        System.out.println("请输入你的操作>>>");
        String line;
        // readLine() returns null at end of stream (Ctrl+D, or piped input exhausted).
        while ((line = br.readLine()) != null) {
            String operator = line.trim();
            if (operator.equals("5")) {
                System.out.println("退出");
                return;
            }
            if (operator.equals("4")) {
                help();
                continue;
            }
            SendMsg msg = excute(operator, br);

            if (msg != null) {
                System.out.println("发送给服务器的内容");
                System.out.println(msg);
                future.channel().writeAndFlush(msg);
            }
        }
    }

    /** Prints the top-level command menu. */
    private void help() {
        System.out.println("****************help*********************");
        System.out.println("*             输入0:查询                 *");
        System.out.println("*             输入1:新增                 *");
        System.out.println("*             输入2:修改                 *");
        System.out.println("*             输入3:删除                 *");
        System.out.println("*             输入4:help                 *");
        System.out.println("*             输入5:退出                  *");
        System.out.println("*****************************************");

    }


    /**
     * Builds the request for one menu command, prompting for any extra input it needs.
     *
     * <p>Only query ("0") of type User ("0") is implemented. Every other command prints a
     * notice and yields {@code null}, so nothing is sent. Before this check, an empty request
     * (null SQL) was sent instead.
     *
     * @param operator the trimmed menu choice
     * @param br       console reader used for follow-up prompts
     * @return the request to send, or {@code null} if there is nothing to send
     * @throws IOException if reading the console fails
     */
    private SendMsg excute(String operator, BufferedReader br) throws IOException {
        if (!"0".equals(operator)) {
            // 1/2/3 (insert/update/delete) have no client-side implementation yet.
            System.out.println("暂不支持该操作");
            return null;
        }
        System.out.println("请输入要查询的类型");
        msg();
        String type = br.readLine();
        if (!"0".equals(type)) {
            System.out.println("暂不支持该查询类型");
            return null;
        }

        SendMsg sendMsg = new SendMsg();
        sendMsg.setOper(0);
        sendMsg.setClazz(User.class);
        System.out.println("是否根据Id查找? 输入y或者n");
        if ("y".equals(br.readLine())) {
            LinkedHashMap<String, String> params = new LinkedHashMap<>();
            System.out.println("请输入Id");
            params.put("Id", br.readLine());
            sendMsg.setSql("select * from user where id = ?");
            sendMsg.setParams(params);
        } else {
            sendMsg.setSql("select * from user");
            sendMsg.setParams(null);
        }
        return sendMsg;
    }

    /** Prints the query-type sub-menu. */
    private void msg() {
        System.out.println("*********************************************");
        System.out.println("*             输入0:查询User                 *");
        System.out.println("*          输入1:查询employee                *");
        System.out.println("*********************************************");
    }

    public static void main(String[] args) {
        HrSystemClient client = new HrSystemClient("127.0.0.1", 8888);

        try {
            client.run();
        } catch (InterruptedException e) {
            // Preserve the interrupt status before exiting.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }
}

客户端的 Handler

package com.syx.client.handler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

/**
 * Last handler in the client pipeline: prints whatever the server sends back.
 */
public class ClientHandler extends ChannelInboundHandlerAdapter {

    /**
     * Prints each decoded reply. HrSystemServer's ServerHandler replies with a String.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        System.out.println("服务器响应的信息");
        System.out.println(msg);
    }

    /**
     * Reports the failure, then closes the connection. Without the report, decode and IO
     * errors would be invisible to the console user.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        ctx.close();
    }

}

服务端代码

package com.syx.server;

import com.syx.server.handler.ServerHandler;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;

public class HrSystemServer {

    // Upper bound for one decoded request object. Requests carry only SQL text and a few
    // parameters, so 1 MiB is ample. An unbounded limit would let any client force huge
    // allocations.
    private static final int MAX_REQUEST_BYTES = 1024 * 1024;

    /**
     * Binds the server to {@code port}, then blocks until the server channel is closed.
     *
     * @param port TCP port to listen on
     * @throws Exception if binding fails or the thread is interrupted while waiting
     */
    public void run(final int port) throws Exception {
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            // SECURITY NOTE(review): ObjectDecoder uses Java deserialization on
                            // bytes from the network; only expose this server to trusted clients.
                            pipeline.addLast(new ObjectEncoder())
                                    .addLast(new ObjectDecoder(MAX_REQUEST_BYTES, ClassResolvers.cacheDisabled(null)))
                                    .addLast(new ServerHandler());

                        }
                    });

            // Wait only for the bind, announce startup, then block until the server closes.
            // The startup message previously appeared only after shutdown.
            ChannelFuture bindFuture = b.bind(port).sync();
            System.out.println("服务器已启动");
            bindFuture.channel().closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
    public static void main(String[] args) throws Exception {
        HrSystemServer server = new HrSystemServer();

        server.run(8888);
    }
}

服务端的 Handler

package com.syx.server.handler;
import com.syx.constant.Constant;
import com.syx.entity.SendMsg;
import com.syx.entity.User;
import com.syx.utils.CommonUtil;
import com.syx.utils.DBUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.dbutils.handlers.BeanListHandler;

import java.util.LinkedHashMap;
import java.util.List;

/**
 * Server-side request handler: executes the SQL carried by each {@link SendMsg} and replies
 * to the same channel with a String result.
 *
 * <p>Each request's result is computed in a local variable and returned. Nothing is kept in
 * static state, so concurrent connections on different worker threads cannot overwrite each
 * other's replies, and one request never sees a stale result from an earlier one.
 *
 * <p>SECURITY NOTE(review): the SQL text comes verbatim from the client, so any connected
 * client can run arbitrary statements against the database. Only deploy with trusted clients,
 * or replace free-form SQL with server-side, whitelisted statements.
 *
 * <p>NOTE(review): the JDBC calls run on the Netty I/O thread and block it. Consider
 * offloading them to a separate executor if load grows.
 */
public class ServerHandler extends ChannelInboundHandlerAdapter {
    // Every client channel this handler has been added to.
    public static ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    // Reply used when a request could not be handled at all (bad message or unknown operation).
    private static final String FAIL = "fail";

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        // Register the newly connected client's channel.
        channelGroup.add(ctx.channel());
    }

    /**
     * Executes one request and writes the String result back to the requesting channel.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof SendMsg)) {
            ctx.channel().writeAndFlush(FAIL);
            return;
        }
        SendMsg message = (SendMsg) msg;
        System.out.println(message);

        String result = start(message);
        if (result == null) {
            result = FAIL;
        }
        System.out.println(result);
        // Reply only to the client that sent the request.
        ctx.channel().writeAndFlush(result);
    }

    /** Unpacks the request and dispatches it; returns the reply text, or null if unhandled. */
    private String start(SendMsg message) {
        return execute(message.getOper(), message.getSql(), message.getParams(), message.getClazz());
    }


    /**
     * Dispatches on the operation code.
     *
     * @param oper   one of the {@code Constant} operation codes
     * @param sql    statement to execute
     * @param params positional parameters, in insertion order (may be null)
     * @param clazz  entity class used to map query results
     * @return the reply text, or {@code null} for an unknown operation code
     */
    private static String execute(int oper, String sql, LinkedHashMap<String, String> params, Class clazz) {
        QueryRunner queryRunner = DBUtil.queryRunner();
        switch (oper) {
            case Constant.select: // query
                return query(sql, params, clazz, queryRunner);
            case Constant.add: // insert
                return insert(sql, params, queryRunner);
            case Constant.update: // update
                return update(sql, params, queryRunner);
            case Constant.delete: // delete
                return delete(sql, params, queryRunner);
            default:
                return null;
        }
    }


    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Full trace: getMessage() alone is often null and hides where the failure happened.
        cause.printStackTrace();
        ctx.close();
    }


    /** Insert: requires non-empty params containing "Id"; values bind in insertion order. */
    private static String insert(String sql, LinkedHashMap<String, String> params, QueryRunner queryRunner) {
        if (params == null || params.isEmpty() || params.get("Id") == null) {
            return "参数为空,新增失败";
        }
        try {
            queryRunner.update(sql, CommonUtil.Map2Array(params));
            return "新增成功";
        } catch (Exception e) {
            e.printStackTrace();
            return "新增失败";
        }
    }

    /**
     * Query: without params, runs {@code sql} and returns the bean list's toString. With
     * params, ignores {@code sql} and looks up a single row by the "Id" value via DBUtil.find.
     */
    private static String query(String sql, LinkedHashMap<String, String> params, Class clazz, QueryRunner queryRunner) {
        if (params == null || params.isEmpty()) {
            try {
                List query = (List) queryRunner.query(sql, new BeanListHandler<>(clazz));
                return query.toString();
            } catch (Exception e) {
                e.printStackTrace();
                return "查询失败";
            }
        }
        String id = params.get("Id");
        if (id == null) {
            return "请输入查询的Id";
        }
        Object o = DBUtil.find(clazz, id);
        return o == null ? "查询不存在" : o.toString();
    }

    /** Update: requires non-empty params; values bind in insertion order. */
    private static String update(String sql, LinkedHashMap<String, String> params, QueryRunner queryRunner) {
        if (params == null || params.isEmpty()) {
            return "参数为空,修改失败";
        }
        try {
            queryRunner.update(sql, CommonUtil.Map2Array(params));
            return "更新成功";
        } catch (Exception e) {
            e.printStackTrace();
            return "更新失败";
        }
    }

    /** Delete: requires "Id" in params; the parameters are bound to the statement's placeholders. */
    private static String delete(String sql, LinkedHashMap<String, String> params, QueryRunner queryRunner) {
        if (params == null || params.isEmpty()) {
            return "参数为空删除失败";
        }
        if (params.get("Id") == null) {
            return "缺少主键Id,无法删除";
        }
        try {
            // Bind the parameters. Running the bare sql would leave its "?" placeholders
            // unbound, and the delete would fail.
            queryRunner.update(sql, CommonUtil.Map2Array(params));
            return "删除成功";
        } catch (Exception e) {
            e.printStackTrace();
            return "删除失败";
        }
    }

    /** Manual smoke test: lists all users. */
    public static void main(String[] args) {
        System.out.println(execute(0, "select * from user", null, User.class));
    }
}

用于传输的实体类

package com.syx.entity;
import java.io.Serializable;
import java.util.LinkedHashMap;
/**
 * Request sent from the client to the server. It is Java-serialized by the pipelines'
 * ObjectEncoder/ObjectDecoder, so this class must be identical on both sides.
 */
public class SendMsg implements Serializable {
    // Explicit so that separately built client and server agree on the serialized form
    // even after unrelated edits (new methods, etc.) change the computed default UID.
    private static final long serialVersionUID = 1L;

    // Operation code; the server switches on Constant.select/add/update/delete (the client sends 0 for query).
    private int oper;
    // SQL text the server executes.
    private String sql;
    // Positional SQL parameters in insertion order; the server treats key "Id" specially.
    private LinkedHashMap<String,String> params;
    // Entity class the server maps query results onto.
    private Class clazz;

    public SendMsg() {
    }

    public int getOper() {
        return oper;
    }

    public void setOper(int oper) {
        this.oper = oper;
    }

    public String getSql() {
        return sql;
    }

    public void setSql(String sql) {
        this.sql = sql;
    }

    public LinkedHashMap<String, String> getParams() {
        return params;
    }

    public void setParams(LinkedHashMap<String, String> params) {
        this.params = params;
    }

    public Class getClazz() {
        return clazz;
    }

    public void setClazz(Class clazz) {
        this.clazz = clazz;
    }

    @Override
    public String toString() {
        return "SendMsg{" +
                "oper=" + oper +
                ", sql='" + sql + '\'' +
                ", params=" + params +
                ", clazz=" + clazz +
                '}';
    }
}

DBUtil 工具类

package com.syx.utils;

import com.mchange.v2.c3p0.DataSources;
import com.syx.annotation.ID;
import com.syx.entity.User;
import org.apache.commons.dbutils.QueryRunner;
import org.apache.commons.dbutils.handlers.BeanHandler;
import org.apache.commons.dbutils.handlers.BeanListHandler;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

public class DBUtil {
    // Data source, created once in the static initializer below; stays null if creation failed.
    private static DataSource dataSource = null;
    // Shared QueryRunner, created lazily by queryRunner().
    private static QueryRunner queryRunner = null;
    // Name of the database configuration file read through PropertiesUtil.
    private static final String dbconfig = "dbconfig.properties";

    /*
     * Initialize the (unpooled) data source from jdbc.url / jdbc.user / jdbc.password.
     */
    static {
        try {
            dataSource = DataSources.unpooledDataSource(PropertiesUtil.propertyValue(dbconfig, "jdbc.url"),
                    PropertiesUtil.propertyValue(dbconfig, "jdbc.user"),
                    PropertiesUtil.propertyValue(dbconfig, "jdbc.password"));
        } catch (SQLException e) {
            e.printStackTrace();
        }

    }

    /**
     * Returns the shared QueryRunner, creating it on first use. Callers may also use its
     * methods directly.
     *
     * <p>Synchronized because it can be called concurrently, e.g. from several Netty worker
     * threads.
     *
     * @return the shared QueryRunner bound to this class's data source
     */
    public static synchronized QueryRunner queryRunner() {
        if (queryRunner == null) {
            queryRunner = new QueryRunner(dataSource);
        }
        return queryRunner;
    }


    /**
     * Inserts {@code t} as one row of the table named after its simple class name. Each
     * non-static declared field becomes a column, and its current value is the bound parameter.
     *
     * @param t   the entity to insert
     * @param <T> entity type
     * @return the number of affected rows, or 0 if the SQL or reflective access failed (stack trace printed)
     * @throws IllegalArgumentException if the class declares no non-static fields
     */
    public static <T> int insert(T t) {
        QueryRunner runner = queryRunner();
        Class<?> clazz = t.getClass();
        // Build the column list and the values from ONE field list, so column i always
        // receives the value of field i.
        List<Field> columns = columnFields(clazz);
        if (columns.isEmpty()) {
            throw new IllegalArgumentException("No insertable fields declared on " + clazz.getName());
        }
        String sql = insertSql(clazz.getSimpleName(), columns);

        try {
            return runner.update(sql, columnValues(columns, t));
        } catch (SQLException | IllegalAccessException e) {
            e.printStackTrace();
        }
        return 0;
    }

    /**
     * Returns the declared instance fields of {@code clazz} that map to table columns.
     * Static fields (e.g. serialVersionUID) and compiler-generated synthetic fields are skipped.
     */
    private static List<Field> columnFields(Class<?> clazz) {
        List<Field> columns = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            columns.add(field);
        }
        return columns;
    }

    /**
     * Reads the current value of each column field from {@code t}, in the same order as
     * {@code columns}; used to fill the "?" placeholders of the insert statement.
     */
    private static Object[] columnValues(List<Field> columns, Object t) throws IllegalAccessException {
        Object[] values = new Object[columns.size()];
        for (int i = 0; i < values.length; i++) {
            Field field = columns.get(i);
            // Entity fields are normally private.
            field.setAccessible(true);
            values[i] = field.get(t);
        }
        return values;
    }


    /**
     * Builds {@code insert into <table>(c1,c2,...) values (?,?,...)} for the given columns.
     *
     * @param tableName table to insert into
     * @param columns   non-empty column fields, in binding order
     * @return the parameterized insert statement
     */
    private static String insertSql(String tableName, List<Field> columns) {
        StringBuilder names = new StringBuilder();
        StringBuilder marks = new StringBuilder();
        for (Field field : columns) {
            if (names.length() > 0) {
                names.append(',');
                marks.append(',');
            }
            names.append(field.getName());
            marks.append('?');
        }
        return "insert into " + tableName + "(" + names + ") values (" + marks + ")";
    }


    /**
     * Loads one row by primary key. The table is {@code t}'s simple class name, and the key
     * column is the field annotated with {@code @ID}.
     *
     * @param t   entity class
     * @param id  primary-key value
     * @param <T> entity type
     * @return the mapped entity, or null if the query failed (stack trace printed); presumably
     *         also null when no row matches (BeanHandler behaviour, verify)
     * @throws IllegalArgumentException if {@code t} has no {@code @ID} field
     */
    public static <T> T find(Class<T> t, String id) {
        QueryRunner runner = queryRunner();
        String sql = querySql(t, false);
        try {
            return runner.query(sql, new BeanHandler<>(t), id);
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return null;
    }


    /**
     * Loads every row of the table named after {@code t}'s simple class name.
     *
     * @param t   entity class
     * @param <T> entity type
     * @return the mapped entities, or null if the query failed (stack trace printed)
     */
    public static <T> List<T> findAll(Class<T> t) {
        QueryRunner runner = queryRunner();
        String sql = querySql(t, true);
        try {
            return runner.query(sql, new BeanListHandler<>(t));
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Builds a select statement for {@code t}'s table.
     *
     * @param t    entity class
     * @param flag true for "select all", false for "select by @ID column"
     * @param <T>  entity type
     * @return the SQL text
     * @throws IllegalArgumentException if {@code flag} is false and no field carries {@code @ID}
     */
    private static <T> String querySql(Class<T> t, boolean flag) {
        String tableName = t.getSimpleName();
        if (flag) {
            return "select * from " + tableName;
        }
        // If several fields are annotated, the last one found wins (as before).
        String idColumn = null;
        for (Field field : t.getDeclaredFields()) {
            if (field.isAnnotationPresent(ID.class)) {
                idColumn = field.getName();
            }
        }
        if (idColumn == null) {
            // Without this check the SQL would be the malformed "... where  = ?".
            throw new IllegalArgumentException(t.getName() + " has no field annotated with @ID");
        }
        return "select * from " + tableName + " where " + idColumn + " = ?";
    }


    /** Manual smoke test: prints the runner and queries all users. */
    public static void main(String[] args) throws SQLException {
        System.out.println(queryRunner());

        queryRunner().query("select * from user", new BeanListHandler<>(User.class));

    }
}

CommonUtil 工具类

package com.syx.utils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

/**
 * Small static helpers shared by the client and server code.
 */
public class CommonUtil {

    /**
     * Generates a random UUID as 32 lowercase hex characters without dashes.
     *
     * @return a new random identifier
     */
    public static String generateUUid() {
        // Locale.ROOT keeps lowercasing locale-independent.
        return UUID.randomUUID().toString().replace("-", "").toLowerCase(Locale.ROOT);
    }


    /**
     * Returns the map's values as an array, in the map's iteration order. Pass a
     * LinkedHashMap when the order must match SQL "?" placeholders.
     *
     * @param param source map; may be null
     * @return the values, or an empty array when {@code param} is null
     */
    public static Object[] Map2Array(Map<String, ?> param) {
        if (param == null) {
            return new Object[0];
        }
        return param.values().toArray();
    }


    /**
     * Returns the map's values as a new mutable list, in the map's iteration order.
     *
     * @param map source map of any key/value types; may be null
     * @return the values, or an empty list when {@code map} is null
     */
    public static List<Object> Map2List(Map<?, ?> map) {
        if (map == null) {
            return new ArrayList<>();
        }
        return new ArrayList<>(map.values());
    }
}

你可能感兴趣的:(Java工具,netty)