Fate-Serving Inference Service: A Source Code Walkthrough

https://fate-serving.readthedocs.io/en/develop/?query=guest
What is Fate-Serving
fate-serving is the online half of FATE. After federated modeling is completed with FATE, fate-serving performs online joint prediction, including single-record, batch, and multi-host prediction.
Model initialization flow
After a model is built in FATE, the fate-flow model-push script pushes it to serving-server. Once the push succeeds, serving-server registers the model's prediction interfaces in zookeeper, so that external systems can obtain the interface addresses through service discovery and call them.
Roles of the participants
When the online prediction interface is called, FATE requires the data consumer (Guest) and the data provider (Host) to predict jointly. The Guest side applies its business logic to the model and feature data, and the sendToRemoteFeatureData field of the Guest's request is forwarded to the Host. The Host side uses a custom Adaptor to interact with its own business systems (e.g. by calling a remote RPC interface or reading from storage) to obtain feature data, hands those features to the algorithm module for computation, and the merged prediction result is finally returned to the Guest.
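To make the Host-side Adaptor idea concrete, here is a minimal sketch of such an adaptor. The class and method names below are purely illustrative and are not the real fate-serving Adaptor API; a real adaptor would query the Host's own RPC service or storage with the ids sent over by the Guest:

import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only -- not the real fate-serving Adaptor interface.
// A Host-side adaptor receives the feature ids forwarded from the Guest
// (e.g. the content of sendToRemoteFeatureData) and returns the Host's features.
public class DemoFeatureAdaptor {

    public Map<String, Object> getFeatureData(Map<String, Object> featureIds) {
        Map<String, Object> features = new HashMap<>();
        // In production this would call the Host's own RPC service or storage
        // (Redis/HBase/...) keyed by featureIds; hard-coded values are used here.
        features.put("x0", 0.5);
        features.put("x1", 1.2);
        return features;
    }
}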

Comparing the Fate and Secretflow inference services

Fate-serving uses JDK 1.8 + SSM and offers both HTTP and RPC (gRPC) interfaces, with roughly 110k lines of code; Secretflow-serving uses C++17 + brpc, offers only RPC interfaces, and has roughly 10k lines of code.
Secretflow-serving splits model execution into Executors and schedules them dynamically; Fate does not have this capability.
Capabilities Fate has that SecretFlow does not:

  1. Fault recovery / restart of the service: the server's version records are kept and can be restored from backup;
  2. Service registration, discovery, and authentication
  3. Dynamic model loading and unloading (hot update) [lower memory footprint]
  4. Elastic scaling, load balancing, high availability

Service discovery works along two dimensions: one is the queryModel-style interface inside serving, the other is the zk watch-and-callback capability built on Curator. The second capability is not exposed to users and is only used internally.
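The second capability is built on Apache Curator's watch mechanism. Independent of fate-serving's own wrapper classes, a minimal Curator sketch of "watch a zk path and get a callback when its children change" looks roughly like this (the connect string and path are placeholders):

import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.CuratorFrameworkFactory;
import org.apache.curator.framework.api.CuratorWatcher;
import org.apache.curator.retry.RetryNTimes;

import java.util.List;

public class ZkWatchDemo {
    public static void main(String[] args) throws Exception {
        CuratorFramework client = CuratorFrameworkFactory.builder()
                .connectString("127.0.0.1:2181")        // placeholder zk address
                .retryPolicy(new RetryNTimes(1, 1000))
                .build();
        client.start();

        String path = "/FATE-SERVICES/serving";          // placeholder path
        // One-shot watcher: fires when children under `path` are added or removed.
        // A real registry re-reads the children and re-registers the watcher in the callback.
        CuratorWatcher watcher = event -> System.out.println("children changed at " + event.getPath());
        List<String> children = client.getChildren().usingWatcher(watcher).forPath(path);
        System.out.println("current children: " + children);
    }
}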

The model module in serving-server (inference execution)

model is the core of the inference part, so let's look at it first. The official FATE docs explain the inference algorithms (https://fate-serving.readthedocs.io/en/develop/algo/base/), so here we only focus on the scheduling path.
The architecture of the model module is as follows:

Note that BaseComponent only implements LocalInferenceAware.

Initialization of PipelineModelProcessor

The model is loaded by ModelLoader in the server module; ModelLoader calls initModel to initialize it.
Each model corresponds to one PipelineModelProcessor.
After serving-server receives a model-push request, it creates a PipelineModelProcessor instance in memory. Much like SecretFlow splits a model into executors, PipelineModelProcessor splits the model into components. However, Fate's components are not the smallest scheduling unit, so Fate lacks SecretFlow's dynamic execution capability.
Fate's model is likewise defined with proto; after dslParser parses it, every component is loaded dynamically. **Here we can see that, via reflection, fate provides dynamic model loading.** For an online service, the ability to register models dynamically matters a lot, because it enables hot updates of the service.


public int initModel(Context context, Map<String, byte[]> modelProtoMap) {
    if (modelProtoMap != null) {
        logger.info("start init pipeline,model components {}", modelProtoMap.keySet());
        try {
            Map<String, byte[]> newModelProtoMap = changeModelProto(modelProtoMap);
            logger.info("after parse pipeline {}", newModelProtoMap.keySet());
            Preconditions.checkArgument(newModelProtoMap.get(PIPLELINE_IN_MODEL) != null);
            PipelineProto.Pipeline pipeLineProto = PipelineProto.Pipeline.parseFrom(newModelProtoMap.get(PIPLELINE_IN_MODEL));
            String dsl = pipeLineProto.getInferenceDsl().toStringUtf8();
            dslParser.parseDagFromDSL(dsl);
            ArrayList<String> components = dslParser.getAllComponent();
            HashMap<String, String> componentModuleMap = dslParser.getComponentModuleMap();
            // call initModel on each component
            for (int i = 0; i < components.size(); ++i) {
                String componentName = components.get(i);
                String className = componentModuleMap.get(componentName);
                logger.info("try to get class:{}", className);
                try {
                    // dynamically load the component class
                    Class modelClass = Class.forName(this.modelPackage + "." + className);
                    BaseComponent mlNode = (BaseComponent) modelClass.getConstructor().newInstance();
                    mlNode.setComponentName(componentName);
                    byte[] protoMeta = newModelProtoMap.get(componentName + ".Meta");
                    byte[] protoParam = newModelProtoMap.get(componentName + ".Param");
                    int returnCode = mlNode.initModel(protoMeta, protoParam);
                    if (returnCode == Integer.valueOf(StatusCode.SUCCESS)) {
                        componentMap.put(componentName, mlNode);
                        pipeLineNode.add(mlNode);
                        logger.info(" add class {} to pipeline task list", className);
                    } else {
                        throw new RuntimeException("init model error");
                    }
                } catch (Exception ex) {
                    pipeLineNode.add(null);
                    logger.warn("Can not instance {} class", className);
                }
            }
        } catch (Exception ex) {
            logger.info("initModel error:{}", ex);
            throw new RuntimeException("initModel error");
        }
        logger.info("Finish init Pipeline");
        return Integer.valueOf(StatusCode.SUCCESS);
    } else {
        logger.error("model content is null ");
        throw new RuntimeException("model content is null");
    }
}
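One detail worth noting in the catch branch above: when a component fails to instantiate, a null is still appended to pipeLineNode, presumably to keep the positions in pipeLineNode aligned with the component order parsed from the DSL, which means downstream code has to tolerate null entries in the pipeline.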

Guest inference

PipelineModelProcessor's guestInference is likewise called from the server module; its signature is:
public ReturnResult guestInference(Context context, InferenceRequest inferenceRequest, Map futureMap, long timeout)
The futureMap here does not carry features; it carries the results of the remote inference.
guestInference first runs singleLocalPredict, sequentially invoking the components' LocalInferenceAware methods.
It then merges with the remote inference results, sequentially invoking the components' mergeRemoteInference methods.
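The two-phase flow described above can be summarized with a self-contained sketch. The Component interface below is illustrative (it is not the fate-serving BaseComponent API); it only mirrors the "run local inference in pipeline order, then merge the remote result component by component" structure:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GuestFlowSketch {

    // Illustrative component interface, not the real BaseComponent API.
    interface Component {
        Map<String, Object> localInference(Map<String, Object> input);
        Map<String, Object> mergeRemoteInference(Map<String, Object> local, Map<String, Object> remote);
    }

    static Map<String, Object> guestInference(List<Component> pipeline,
                                              Map<String, Object> guestFeatures,
                                              Map<String, Object> remoteResult) {
        Map<String, Object> data = guestFeatures;
        for (Component c : pipeline) {               // phase 1: local predict, in pipeline order
            data = c.localInference(data);
        }
        for (Component c : pipeline) {               // phase 2: merge with the host's result
            data = c.mergeRemoteInference(data, remoteResult);
        }
        return data;
    }

    public static void main(String[] args) {
        Component merge = new Component() {
            public Map<String, Object> localInference(Map<String, Object> in) { return in; }
            public Map<String, Object> mergeRemoteInference(Map<String, Object> l, Map<String, Object> r) {
                Map<String, Object> out = new HashMap<>(l);
                out.putAll(r);
                return out;
            }
        };
        List<Component> pipeline = new ArrayList<>();
        pipeline.add(merge);
        Map<String, Object> local = new HashMap<>();
        local.put("score_guest", 0.3);
        Map<String, Object> remote = new HashMap<>();
        remote.put("score_host", 0.4);
        System.out.println(guestInference(pipeline, local, remote));
    }
}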

Model services

The controller layer of fate-serving-server defines the HTTP endpoints and grpc.service defines the RPC endpoints; since the controllers just build RPC calls themselves, we will not go into them.

The common Service abstract class and Context

Let's start with the ModelService part. ModelServiceProvider extends AbstractServingServiceProvider, an abstract class that in turn extends AbstractServiceAdaptor.
AbstractServiceAdaptor is the common abstract parent of every Service and ServiceProvider; let's look at the interfaces and shared methods it provides.
[Figure 1]
Shared methods / fields:

  • getFlowCounterManager/setFlowCounterManager: get/set the flowCounterManager, the per-model counter used to collect access statistics
  • getMethodMap/setMethodMap: get/set a string -> Method mapping
  • preChain/postChain: the service's pre-processing and post-processing logic
  • AbstractStub: the gRPC stub; every Service and ServiceProvider maps to one gRPC stub

Interfaces to be implemented by subclasses:

  • doService: the actual implementation of the service
  • transformExceptionInfo

Next let's look at the service method. service receives a Context, which is essentially a k-v store that records information about the current execution.

    @Override
    public OutboundPackage<resp> service(Context context, InboundPackage<req> data) throws RuntimeException {

        OutboundPackage<resp> outboundPackage = new OutboundPackage<resp>();
        // increment requestInProcess by 1
        context.preProcess();
        List<Throwable> exceptions = Lists.newArrayList();
        context.setReturnCode(StatusCode.SUCCESS);
        // when the main method exits, this flag is set to 0 (the service is closed)
        if (!isOpen) {
            return this.serviceFailInner(context, data, new ShowDownRejectException());
        }
        if(data.getBody()!=null) {
            context.putData(Dict.INPUT_DATA, data.getBody());
        }

        try {
            // record the number of service invocations
            requestInHandle.addAndGet(1);
            resp result = null;
            context.setServiceName(this.serviceName);
            try {
                preChain.doPreProcess(context, data, outboundPackage);
                // invoke the subclass implementation
                result = doService(context, data, outboundPackage);
                if (logger.isDebugEnabled()) {
                    logger.debug("do service, router info: {}, service name: {}, result: {}", JsonUtil.object2Json(data.getRouterInfo()), serviceName, result);
                }
            } catch (Throwable e) {
                exceptions.add(e);
                logger.error("do service fail, cause by: {}", e.getMessage());
            }
            outboundPackage.setData(result);
            postChain.doPostProcess(context, data, outboundPackage);

        } catch (Throwable e) {
            exceptions.add(e);
            // ... the rest of the error handling and the serviceFail branch are omitted in this excerpt
        }
        return outboundPackage;
    }

The model service proxy class

ModelService mainly relies on the ModelServiceProvider bean, which acts as the proxy for the model services; let's look at this code.
ModelServiceProvider uses ModelManager, which performs the actual model management and is covered in the next section.
ModelServiceProvider provides the following model services:

  1. Model loading
  2. Online model publishing
  3. Model querying
  4. Model unloading
  5. Model unbinding
  6. Model fetching
  7. Model data fetching

The @FateService annotation configures the preChain and postChain of an AbstractServiceAdaptor:

@FateService(name = "modelService", preChain = {
        "requestOverloadBreaker"
}, postChain = {
})
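Judging from how it is consumed in the Register code below, the annotation definition presumably looks roughly like this (a sketch inferred from usage, not the verbatim source):

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

// Sketch inferred from usage; RUNTIME retention is required for getAnnotation() to see it.
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface FateService {
    String name();                    // key used in serviceAdaptorMap
    String[] preChain() default {};   // bean names of the pre-processing Interceptors
    String[] postChain() default {};  // bean names of the post-processing Interceptors
}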

The chains configured via @FateService are wired up in the Register classes of the admin, server, and proxy modules:


/**
 * onApplicationEvent is invoked once the Spring application has finished starting up
 **/
@Override
public void onApplicationEvent(ApplicationReadyEvent applicationEvent) {
    String[] beans = applicationContext.getBeanNamesForType(AbstractServiceAdaptor.class);
    FlowCounterManager flowCounterManager = applicationContext.getBean(FlowCounterManager.class);
    for (String beanName : beans) {
        AbstractServiceAdaptor serviceAdaptor = applicationContext.getBean(beanName, AbstractServiceAdaptor.class);
        serviceAdaptor.setFlowCounterManager(flowCounterManager);
        // read the @FateService annotation from the bean's class
        FateService proxyService = serviceAdaptor.getClass().getAnnotation(FateService.class);
        Method[] methods = serviceAdaptor.getClass().getMethods();
        for (Method method : methods) {
            
            FateServiceMethod fateServiceMethod = method.getAnnotation(FateServiceMethod.class);
            if (fateServiceMethod != null) {
                String[] names = fateServiceMethod.name();
                for (String name : names) {
                    serviceAdaptor.getMethodMap().put(name, method);
                }
            }
        }
        if (proxyService != null) {
            serviceAdaptor.setServiceName(proxyService.name());
            String[] postChain = proxyService.postChain();
            String[] preChain = proxyService.preChain();
            for (String post : postChain) {
                Interceptor postInterceptor = applicationContext.getBean(post, Interceptor.class);
                serviceAdaptor.addPostProcessor(postInterceptor);
            }
            for (String pre : preChain) {
                Interceptor preInterceptor = applicationContext.getBean(pre, Interceptor.class);
                serviceAdaptor.addPreProcessor(preInterceptor);
            }

            this.serviceAdaptorMap.put(proxyService.name(), serviceAdaptor);
        }
    }

    logger.info("service register info {}", this.serviceAdaptorMap.keySet());
}
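With the methodMap populated this way, the provider side presumably dispatches an incoming call by its name roughly like the following (an illustrative sketch, not the verbatim ModelServiceProvider code):

import java.lang.reflect.Method;
import java.util.Map;

public class MethodDispatchSketch {
    // Look up the handler registered under the @FateServiceMethod name
    // and invoke it reflectively on the service adaptor.
    public static Object dispatch(Object serviceAdaptor,
                                  Map<String, Method> methodMap,
                                  String callName,
                                  Object context,
                                  Object inboundPackage) throws Exception {
        Method method = methodMap.get(callName);
        if (method == null) {
            throw new UnsupportedOperationException("no handler registered for " + callName);
        }
        return method.invoke(serviceAdaptor, context, inboundPackage);
    }
}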

Model management

ModelManager is a very important module: it does the actual work behind the model services. Let's go through the methods mentioned above one by one.

Bind

Binding attaches a service id to an existing model.
A serviceId -> model key mapping is maintained here (the figure below is from the official docs):
[Figure 2]
The model pool is the namespaceMap, which maps a model's name (namespace key) to the loaded Model and its ModelProcessor.
Note that every operation also updates the local cache, which is used for service recovery.

    public synchronized ReturnResult bind(Context context, ModelServiceProto.PublishRequest req) {
        if (logger.isDebugEnabled()) {
            logger.debug("try to bind model, receive request : {}", req);
        }
        ReturnResult returnResult = new ReturnResult();
        String serviceId = req.getServiceId();
        Preconditions.checkArgument(StringUtils.isNotBlank(serviceId), "param service id is blank");
        Preconditions.checkArgument(!StringUtils.containsAny(serviceId, URL_FILTER_CHARACTER), "Service id contains special characters, " + JsonUtil.object2Json(URL_FILTER_CHARACTER));

        returnResult.setRetcode(StatusCode.SUCCESS);
        Model model = this.buildModelForBind(context, req);
        String modelKey = this.getNameSpaceKey(model.getTableName(), model.getNamespace());
        Model loadedModel = this.namespaceMap.get(modelKey);
        if (loadedModel == null) {
            throw new ModelNullException("model " + modelKey + " is not exist ");
        }
        this.serviceIdNamespaceMap.put(serviceId, modelKey);
        if (zookeeperRegistry != null) {
            if (StringUtils.isNotEmpty(serviceId)) {
                zookeeperRegistry.addDynamicEnvironment(serviceId);
            }
            zookeeperRegistry.register(FateServer.guestServiceSets, Lists.newArrayList(serviceId));
        }
        //update cache
        this.store(serviceIdNamespaceMap, serviceIdFile);
        return returnResult;
    }
    private Model buildModelForBind(Context context, ModelServiceProto.PublishRequest req) {
        // read the model info from the request's modelMap;
        // note that everything here comes from data carried in the req
        Model model = new Model();
        String role = req.getLocal().getRole();
        model.setPartId(req.getLocal().getPartyId());
        model.setRole(Dict.GUEST.equals(role) ? Dict.GUEST : Dict.HOST);
        String serviceId = req.getServiceId();
        model.getServiceIds().add(serviceId);
        Map<String, ModelServiceProto.RoleModelInfo> modelMap = req.getModelMap();
        ModelServiceProto.RoleModelInfo roleModelInfo = modelMap.get(model.getRole());
        Map<String, ModelServiceProto.ModelInfo> modelInfoMap = roleModelInfo.getRoleModelInfoMap();
        Map<String, ModelServiceProto.Party> roleMap = req.getRoleMap();
        ModelServiceProto.Party selfParty = roleMap.get(model.getRole());
        String selfPartyId = selfParty.getPartyIdList().get(0);
        ModelServiceProto.ModelInfo selfModelInfo = modelInfoMap.get(selfPartyId);
        String selfNamespace = selfModelInfo.getNamespace();
        String selfTableName = selfModelInfo.getTableName();
        model.setNamespace(selfNamespace);
        model.setTableName(selfTableName);
        return model;
    }

Load

When the data provider (host) loads a model, it records a mapping from the data consumer's (guest) name + namespace to the (host) model, so that the consumer's and the provider's models correspond one-to-one.
partnerModelMap is always empty on the guest side; namespaceMap exists on both host and guest and holds the local model pool mapping.


    public synchronized ReturnResult load(Context context, ModelServiceProto.PublishRequest req) {
        if (logger.isDebugEnabled()) {
            logger.debug("try to load model, receive request : {}", req);
        }
        ReturnResult returnResult = new ReturnResult();
        returnResult.setRetcode(StatusCode.SUCCESS);
        Model model = this.buildModelForLoad(context, req);
        String namespaceKey = this.getNameSpaceKey(model.getTableName(), model.getNamespace());
        ModelLoader.ModelLoaderParam modelLoaderParam = new ModelLoader.ModelLoaderParam();
        String loadType = req.getLoadType();
        if (StringUtils.isNotEmpty(loadType)) {
            modelLoaderParam.setLoadModelType(ModelLoader.LoadModelType.valueOf(loadType));
        } else {
            modelLoaderParam.setLoadModelType(ModelLoader.LoadModelType.FATEFLOW);
        }
        modelLoaderParam.setTableName(model.getTableName());
        modelLoaderParam.setNameSpace(model.getNamespace());
        modelLoaderParam.setFilePath(req.getFilePath());
        ModelLoader modelLoader = this.modelLoaderFactory.getModelLoader(context, modelLoaderParam.getLoadModelType());
        Preconditions.checkArgument(modelLoader != null, "model loader not found");
        ModelProcessor modelProcessor = modelLoader.loadModel(context, modelLoaderParam);
        if (modelProcessor == null) {
            throw new ModelProcessorInitException("model initialization error, please check if the model exists and the configuration of the FATEFLOW load model process is correct.");
        }
        model.setModelProcessor(modelProcessor);
        modelProcessor.setModel(model);
        // local model pool mapping
        this.namespaceMap.put(namespaceKey, model);
        // when the data provider (host) loads a model, record the mapping (guest) name + namespace -> (host) model,
        // so that the consumer's and the provider's models correspond one-to-one
        if (Dict.HOST.equals(model.getRole())) {
            model.getFederationModelMap().values().forEach(remoteModel -> {
                String remoteNamespaceKey = this.getNameSpaceKey(remoteModel.getTableName(), remoteModel.getNamespace());
                this.partnerModelMap.put(remoteNamespaceKey, model);
            });
        }
        /**
         *  host model
         */
        if (Dict.HOST.equals(model.getRole()) && zookeeperRegistry != null) {
            String modelKey = ModelUtil.genModelKey(model.getTableName(), model.getNamespace());
            zookeeperRegistry.addDynamicEnvironment(EncryptUtils.encrypt(modelKey, EncryptMethod.MD5));
            zookeeperRegistry.register(FateServer.hostServiceSets);
        }
        // update cache
        this.store(namespaceMap, namespaceFile);
        return returnResult;

    }

buildModelForLoad builds the Model object used for the load:

private Model buildModelForLoad(Context context, ModelServiceProto.PublishRequest req) {
    Model model = new Model();
    String role = req.getLocal().getRole();
    model.setPartId(req.getLocal().getPartyId());
    model.setRole(Dict.GUEST.equals(role) ? Dict.GUEST : Dict.HOST);
    Map<String, ModelServiceProto.RoleModelInfo> modelMap = req.getModelMap();
    ModelServiceProto.RoleModelInfo roleModelInfo = modelMap.get(model.getRole());
    Map<String, ModelServiceProto.ModelInfo> modelInfoMap = roleModelInfo.getRoleModelInfoMap();
    Map<String, ModelServiceProto.Party> roleMap = req.getRoleMap();
    String remotePartyRole = model.getRole().equals(Dict.GUEST) ? Dict.HOST : Dict.GUEST;
    ModelServiceProto.Party remoteParty = roleMap.get(remotePartyRole);
    List<String> remotePartyIdList = remoteParty.getPartyIdList();
    for (String remotePartyId : remotePartyIdList) {
        ModelServiceProto.RoleModelInfo remoteRoleModelInfo = modelMap.get(remotePartyRole);
        ModelServiceProto.ModelInfo remoteModelInfo = remoteRoleModelInfo.getRoleModelInfoMap().get(remotePartyId);
        Model remoteModel = new Model();
        remoteModel.setPartId(remotePartyId);
        remoteModel.setNamespace(remoteModelInfo.getNamespace());
        remoteModel.setTableName(remoteModelInfo.getTableName());
        remoteModel.setRole(remotePartyRole);
        model.getFederationModelMap().put(remotePartyId, remoteModel);
    }
    ModelServiceProto.Party selfParty = roleMap.get(model.getRole());
    String selfPartyId = selfParty.getPartyIdList().get(0);
    ModelServiceProto.ModelInfo selfModelInfo = modelInfoMap.get(model.getPartId());
    Preconditions.checkArgument(selfModelInfo != null, "model info is invalid");
    String selfNamespace = selfModelInfo.getNamespace();
    String selfTableName = selfModelInfo.getTableName();
    model.setNamespace(selfNamespace);
    model.setTableName(selfTableName);
    // when loading from FATEFLOW, resolve the model's resource address
    if (ModelLoader.LoadModelType.FATEFLOW.name().equals(req.getLoadType())) {
        try {
            ModelLoader.ModelLoaderParam modelLoaderParam = new ModelLoader.ModelLoaderParam();
            modelLoaderParam.setLoadModelType(ModelLoader.LoadModelType.FATEFLOW);
            modelLoaderParam.setTableName(model.getTableName());
            modelLoaderParam.setNameSpace(model.getNamespace());
            modelLoaderParam.setFilePath(req.getFilePath());
            ModelLoader modelLoader = this.modelLoaderFactory.getModelLoader(context, ModelLoader.LoadModelType.FATEFLOW);
            model.setResourceAdress(getAdressForUrl(modelLoader.getResource(context, modelLoaderParam)));
        } catch (Exception e) {
            logger.error("getloadModelUrl error = {}", e);
        }
    }
    return model;
}

Fault recovery
Fault recovery relies on the local cache files written by store(...) in bind and load above: on startup, modelManager.restore(...) (called at the end of afterPropertiesSet, shown below) reads the cached mappings back to restore the model state.
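A minimal sketch of the store/restore idea follows. The serialization format and file handling here are illustrative only; the real ModelManager has its own persistence code for namespaceMap and serviceIdNamespaceMap:

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class CacheStoreSketch {

    // Write the in-memory map to a local cache file; a failed write must not break serving.
    static void store(Map<String, String> map, File file) {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(new ConcurrentHashMap<>(map));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Read the cache file back on restart; fall back to an empty map if it is missing or corrupt.
    @SuppressWarnings("unchecked")
    static Map<String, String> restore(File file) {
        if (!file.exists()) {
            return new ConcurrentHashMap<>();
        }
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(file))) {
            return (Map<String, String>) in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            return new ConcurrentHashMap<>();
        }
    }
}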

Service registration, discovery, and authentication

When services are registered

At load time

At load time, service registration happens only on the data provider (host) side:

       /**
         *  host model
         */
        if (Dict.HOST.equals(model.getRole()) && zookeeperRegistry != null) {
            String modelKey = ModelUtil.genModelKey(model.getTableName(), model.getNamespace());
            zookeeperRegistry.addDynamicEnvironment(EncryptUtils.encrypt(modelKey, EncryptMethod.MD5));
            zookeeperRegistry.register(FateServer.hostServiceSets);
        }

So what is the DynamicEnvironment for? And when does FateServer.hostServiceSets get registered?
Let's first look at the initialization of FateServer.hostServiceSets. Reading the source, the ServingServer bean implements InitializingBean; once initialization finishes, the code below runs. It registers the initial set of services that Fate-Serving needs, and as we will see later, new services are all derived from these initial ones.

@Override
public void afterPropertiesSet() throws Exception {
    logger.info("try to star server ,meta info {}", MetaInfo.toMap());
    Executor executor = new ThreadPoolExecutor(MetaInfo.PROPERTY_SERVING_CORE_POOL_SIZE, MetaInfo.PROPERTY_SERVING_MAX_POOL_SIZE, MetaInfo.PROPERTY_SERVING_POOL_ALIVE_TIME, TimeUnit.MILLISECONDS,
            MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE == 0 ? new SynchronousQueue<Runnable>() :
                    (MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE < 0 ? new LinkedBlockingQueue<Runnable>()
                            : new LinkedBlockingQueue<Runnable>(MetaInfo.PROPERTY_SERVING_POOL_QUEUE_SIZE)), new NamedThreadFactory("ServingServer", true));
    FateServerBuilder serverBuilder = (FateServerBuilder) ServerBuilder.forPort(MetaInfo.PROPERTY_SERVER_PORT);
    serverBuilder.keepAliveTime(100, TimeUnit.MILLISECONDS);
    serverBuilder.executor(executor);
    serverBuilder.addService(ServerInterceptors.intercept(guestInferenceService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), GuestInferenceService.class);
    serverBuilder.addService(ServerInterceptors.intercept(modelService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), ModelService.class);
    serverBuilder.addService(ServerInterceptors.intercept(hostInferenceService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), HostInferenceService.class);
    serverBuilder.addService(ServerInterceptors.intercept(commonService, new ServiceExceptionHandler(), new ServiceOverloadProtectionHandle()), CommonService.class);
    server = serverBuilder.build();
    server.start();
    boolean useRegister = MetaInfo.PROPERTY_USE_REGISTER;
    if (useRegister) {
        logger.info("serving-server is using register center");
        zookeeperRegistry.subProject(Dict.PROPERTY_PROXY_ADDRESS);
        zookeeperRegistry.subProject(Dict.PROPERTY_FLOW_ADDRESS);
        zookeeperRegistry.register(FateServer.serviceSets);
    } else {
        logger.warn("serving-server not use register center");
    }
    modelManager.restore(new BaseContext());
    logger.warn("serving-server start over");
}
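Note the queue configuration in the executor above: a PROPERTY_SERVING_POOL_QUEUE_SIZE of 0 yields a SynchronousQueue (no buffering, tasks are handed off directly to a worker), a negative value yields an unbounded LinkedBlockingQueue, and a positive value yields a bounded LinkedBlockingQueue of that capacity.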

Now for the second question, the role of DynamicEnvironment; here is the register code:

public synchronized void register(Set<RegisterService> sets) {
    if (logger.isDebugEnabled()) {
        logger.debug("prepare to register {}", sets);
    }
    String hostAddress = NetUtils.getLocalIp();
    Preconditions.checkArgument(port != 0);
    Preconditions.checkArgument(StringUtils.isNotEmpty(environment));

    Set<URL> registered = this.getRegistered();
    for (RegisterService service : sets) {
        try {
            URL url = generateUrl(hostAddress, service);
            URL serviceUrl = url.setProject(project);
            // for inference services, useDynamicEnvironment is true
            if (service.useDynamicEnvironment()) {
                if (CollectionUtils.isNotEmpty(dynamicEnvironments)) {
                    dynamicEnvironments.forEach(environment -> {
                        URL newServiceUrl = service.protocol().equals(Dict.HTTP) ? url : serviceUrl.setEnvironment(environment);
                        // use cache service params
                        loadCacheParams(newServiceUrl);
                        // a new service URL is generated for every environment,
                        // so the total registered is (number of environments) * (size of sets)
                        String serviceName = service.serviceName() + environment;
                        if (!registedString.contains(serviceName)) {
                            this.register(newServiceUrl);
                            this.registedString.add(serviceName);
                        } else {
                            logger.info("url {} is already registed, will not do anything ", newServiceUrl);
                        }
                    });
                }
            } else {
                if (!registedString.contains(service.serviceName() + environment)) {
                    URL newServiceUrl = service.protocol().equals(Dict.HTTP) ? url : serviceUrl.setEnvironment(environment);
                    if (logger.isDebugEnabled()) {
                        logger.debug("try to register url {}", newServiceUrl);
                    }
                    // use cache service params
                    loadCacheParams(newServiceUrl);

                    this.register(newServiceUrl);
                    this.registedString.add(service.serviceName() + environment);
                } else {
                    logger.info("url {} is already registed, will not do anything ", service.serviceName());
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
            logger.error("try to register service {} failed", service);
        }
    }

    syncServiceCacheFile();

    if (logger.isDebugEnabled()) {
        logger.debug("registed urls {}", registered);
    }
}

As you can see, generating (number of environments) * (size of sets) services this way keeps the code simple, and registration happens only once on the data provider side (guarded by registedString), which prevents duplicate registration. For example, with 3 dynamic environments and 4 services in the set, at most 12 URLs end up registered.

At bind time

The approach mirrors loading above, except that bind is only invoked by the guest, so there is no need to distinguish guest from host:

if (zookeeperRegistry != null) {
    if (StringUtils.isNotEmpty(serviceId)) {
        zookeeperRegistry.addDynamicEnvironment(serviceId);
    }
    // register a new serviceId-scoped service for every service in guestServiceSets
    zookeeperRegistry.register(FateServer.guestServiceSets, Lists.newArrayList(serviceId));
}

The unload and unregister logic is much the same, so we will not expand on it.

Service registration and service discovery

Next we turn to the register module. There is a lot of code here, because routing, load balancing and other features are also implemented in it.
This section focuses on the registration logic, so only the common and zookeeper folders matter; the discussion follows this deployment diagram from the official docs.
[Figure 3]
First of all, fate-serving does not ship its own zookeeper; the zk cluster has to be deployed by the user.
The main class involved is ZookeeperRegistry, so we start from there.

Creating the zookeeper client

public static ConcurrentMap registeryMap = new ConcurrentHashMap() is a map from URL to a ZookeeperRegistry singleton.
Its initialization flow is as follows:

public static synchronized ZookeeperRegistry getRegistry(String url, String project, String environment, int port) {
    if (url == null) {
        return null;
    }
    URL registryUrl = URL.valueOf(url);
    registryUrl = registryUrl.addParameter(Constants.ENVIRONMENT_KEY, environment);
    registryUrl = registryUrl.addParameter(Constants.SERVER_PORT, port);
    registryUrl = registryUrl.addParameter(Constants.PROJECT_KEY, project);
    List<URL> backups = registryUrl.getBackupUrls();
    if (registeryMap.get(registryUrl) == null) {
        URL finalRegistryUrl = registryUrl;
        registeryMap.computeIfAbsent(registryUrl, n -> {
            CuratorZookeeperTransporter curatorZookeeperTransporter = new CuratorZookeeperTransporter();
            ZookeeperRegistryFactory zookeeperRegistryFactory = new ZookeeperRegistryFactory();
            zookeeperRegistryFactory.setZookeeperTransporter(curatorZookeeperTransporter);
            ZookeeperRegistry zookeeperRegistry = (ZookeeperRegistry) zookeeperRegistryFactory.createRegistry(finalRegistryUrl);
            return zookeeperRegistry;
        });
    }
    return registeryMap.get(registryUrl);
}

Let's first look at CuratorZookeeperTransporter, which maintains a Map zookeeperClientMap holding the URL -> ZookeeperClient relation.
Here is how a ZookeeperClient is initialized:

@Override
public ZookeeperClient connect(URL url) {
    ZookeeperClient zookeeperClient;
    // parse all urls
    List<String> addressList = getURLBackupAddress(url);
    // The field define the zookeeper server , including protocol, host, port, username, password
    // update the url -> zookeeperClient mapping
    if ((zookeeperClient = fetchAndUpdateZookeeperClientCache(addressList)) != null && zookeeperClient.isConnected()) {
        logger.info("find valid zookeeper client from the cache for address: " + url);
        return zookeeperClient;
    }
    // avoid creating too many connections, so add lock
    synchronized (zookeeperClientMap) {
        if ((zookeeperClient = fetchAndUpdateZookeeperClientCache(addressList)) != null && zookeeperClient.isConnected()) {
            logger.info("find valid zookeeper client from the cache for address: " + url);
            return zookeeperClient;
        }

        zookeeperClient = createZookeeperClient(toClientURL(url));
        logger.info("No valid zookeeper client found from cache, therefore create a new client for url. " + url);
        writeToClientMap(addressList, zookeeperClient);
        // this dispatches to the constructor below
    }
    return zookeeperClient;
}

public CuratorZookeeperClient(URL url) {
    super(url);
    try {
        // read the connection timeout from the URL, default 5000 ms
        int timeout = url.getParameter(TIMEOUT_KEY, 5000);

        // build the Curator client via CuratorFrameworkFactory.Builder
        CuratorFrameworkFactory.Builder builder = CuratorFrameworkFactory.builder()
                .connectString(url.getBackupAddress()) // connection address, obtained via getBackupAddress
                .retryPolicy(new RetryNTimes(1, 1000)) // retry policy: retry once, with 1000 ms between retries
                .connectionTimeoutMs(timeout); // connection timeout

        aclEnable = MetaInfo.PROPERTY_ACL_ENABLE;
        if (aclEnable) {
            aclUsername = MetaInfo.PROPERTY_ACL_USERNAME;
            aclPassword = MetaInfo.PROPERTY_ACL_PASSWORD;

            // if ACL is enabled, check that username and password are not blank
            if (StringUtils.isBlank(aclUsername) || StringUtils.isBlank(aclPassword)) {
                aclEnable = false;
                MetaInfo.PROPERTY_ACL_ENABLE = false;
            } else {
                // if both are present, add the authorization info and the ACL rule
                builder.authorization(SCHEME, (aclUsername + ":" + aclPassword).getBytes());

                Id allow = new Id(SCHEME, DigestAuthenticationProvider.generateDigest(aclUsername + ":" + aclPassword));
                // add more
                acls.add(new ACL(ZooDefs.Perms.ALL, allow));
            }
        }

        // build the Curator client from the builder
        client = builder.build();

        // add a connection state listener to handle connection state changes
        client.getConnectionStateListenable().addListener(new ConnectionStateListener() {
            @Override
            public void stateChanged(CuratorFramework client, ConnectionState state) {
                // map the connection state change to the corresponding stateChanged call;
                // only RECONNECTED is actually acted upon further down the chain
                if (state == ConnectionState.LOST) {
                    CuratorZookeeperClient.this.stateChanged(StateListener.DISCONNECTED);
                } else if (state == ConnectionState.CONNECTED) {
                    CuratorZookeeperClient.this.stateChanged(StateListener.CONNECTED);
                } else if (state == ConnectionState.RECONNECTED) {
                    CuratorZookeeperClient.this.stateChanged(StateListener.RECONNECTED);
                }
            }
        });

        // start the Curator client
        client.start();

        // if ACL is enabled, set the ACL on the root node
        if (aclEnable) {
            client.setACL().withACL(acls).forPath("/");
        }
    } catch (Exception e) {
        // wrap any exception in an IllegalStateException
        throw new IllegalStateException(e.getMessage(), e);
    }
}

ZookeeperRegistry: registration

Continuing with ZookeeperRegistry: after the client is initialized, ZookeeperRegistry adds a state listener that restores services after a reconnect.


public ZookeeperRegistry(URL url, ZookeeperTransporter zookeeperTransporter) {
    super(url);
    String group = url.getParameter(ROOT_KEY, Dict.DEFAULT_FATE_ROOT);
    if (!group.startsWith(PATH_SEPARATOR)) {
        group = PATH_SEPARATOR + group;
    }
    this.environment = url.getParameter(ENVIRONMENT_KEY, "online");
    project = url.getParameter(PROJECT_KEY);
    port = url.getParameter(SERVER_PORT) != null ? new Integer(url.getParameter(SERVER_PORT)) : 0;

    this.root = group;
    zkClient = zookeeperTransporter.connect(url);
    zkClient.addStateListener(state -> {
        if (state == StateListener.RECONNECTED) {
            logger.error("state listener reconnected");
            try {
                recover();
            } catch (Exception e) {
                logger.error(e.getMessage(), e);
            }
        }
    });
}
// recover() eventually calls:
public void addFailedRegisterComponentTask(URL url) {
    if(url!=null) {
        String instanceId = AbstractRegistry.INSTANCE_ID;

        FailedRegisterComponentTask oldOne = this.failedRegisterComponent.get(instanceId);
        if (oldOne != null) {
            return;
        }
        // create a new retry task
        FailedRegisterComponentTask newTask = new FailedRegisterComponentTask(url, this);
        oldOne = failedRegisterComponent.putIfAbsent(instanceId, newTask);
        if (oldOne == null) {
            // never has a retry task. then start a new task for retry.
            // schedule the timeout; when it expires, doRegisterComponent() is called
            retryTimer.newTimeout(newTask, retryPeriod, TimeUnit.MILLISECONDS);
        }
    }
}

Service registration eventually ends up in the following client code:

// create an ephemeral node
@Override
public void createEphemeral(String path) {
    try {
        if (logger.isDebugEnabled()) {
            logger.debug("createEphemeral {}", path);
        }

        if (aclEnable) {
            // if ACL is enabled, create the ephemeral node with the configured ACLs
            client.create().withMode(CreateMode.EPHEMERAL).withACL(acls).forPath(path);
        } else {
            // otherwise create the ephemeral node with default permissions
            client.create().withMode(CreateMode.EPHEMERAL).forPath(path);
        }
    } catch (NodeExistsException e) {
    } catch (Exception e) {
        throw new IllegalStateException(e.getMessage(), e);
    }
}
// create a persistent node

@Override
protected void createPersistent(String path, String data) {
    byte[] dataBytes = data.getBytes(CHARSET);
    try {
        if (logger.isDebugEnabled()) {
            logger.debug("createPersistent {} data {}", path, data);
        }
        if (aclEnable) {
            client.create().withACL(acls).forPath(path, dataBytes);
        } else {
            client.create().forPath(path, dataBytes);
        }
    } catch (NodeExistsException e) {
        try {
            if (aclEnable) {
                Stat stat = client.checkExists().forPath(path);
                client.setData().withVersion(stat.getAversion()).forPath(path, dataBytes);
            } else {
                client.setData().forPath(path, dataBytes);
            }
        } catch (Exception e1) {
            throw new IllegalStateException(e.getMessage(), e1);
        }
    } catch (Exception e) {
        throw new IllegalStateException(e.getMessage(), e);
    }
}

ZookeeperRegistry: discovery

subProject implements service discovery; it eventually calls client.getChildren().usingWatcher(listener).forPath(path):

@Override
public void subProject(String project) {
    if (logger.isDebugEnabled()) {
        logger.debug("try to subProject: {}", project);
    }
    super.subProject(project);
    failedSubProject.remove(project);
    try {
        doSubProject(project);
    } catch (Exception e) {
        addFailedSubscribedProjectTask(project);
    }
}
@Override
public void doSubProject(String project) {
    String path = root + Constants.PATH_SEPARATOR + project;
    // listen on root + Constants.PATH_SEPARATOR + project
    List<String> environments = zkClient.addChildListener(path, (parent, childrens) -> {
        if (StringUtils.isNotEmpty(parent)) {
            logger.info("fire environments changes {}", childrens);
            // subscribe to newly appeared children
            subEnvironments(path, project, childrens);
        }
    });

    if (logger.isDebugEnabled()) {
        logger.debug("environments {}", environments);
    }
    if (environments == null) {
        if (logger.isDebugEnabled()) {
            logger.debug("path {} is not exist in zk", path);
        }
        throw new RuntimeException("environment is null");
    }

    subEnvironments(path, project, environments);
}

private void subEnvironments(String path, String project, List<String> environments) {
    if (environments != null) {
        for (String environment : environments) {
            String tempPath = path + Constants.PATH_SEPARATOR + environment;
            // listen on root + Constants.PATH_SEPARATOR + project + Constants.PATH_SEPARATOR + environment
            List<String> services = zkClient.addChildListener(tempPath, (parent, childrens) -> {
                if (StringUtils.isNotEmpty(parent)) {
                    if (logger.isDebugEnabled()) {
                        logger.debug("fire services changes {}", childrens);
                    }
                    subServices(project, environment, childrens);
                }
            });

            subServices(project, environment, services);
        }
    }
}

If a parent node changes, the method below is called to subscribe:

private void subServices(String project, String environment, List<String> services) {
    if (services != null) {
        for (String service : services) {
            String subString = project + Constants.PATH_SEPARATOR + environment + Constants.PATH_SEPARATOR + service;
            if (logger.isDebugEnabled()) {
                logger.debug("subServices sub {}", subString);
            }
            subscribe(URL.valueOf(subString), urls -> {
                if (logger.isDebugEnabled()) {
                    logger.debug("change services urls =" + urls);
                }
            });
        }
    }
}

The zk layout used by fate-serving is:
/FATE-SERVICES/{module}/{ID}/{interface}/provider/{provider info}
As noted earlier, a user's new services are all derived from the fixed initial set, so once a user registers a new service it can also be discovered by the client. The initial services themselves are registered in afterPropertiesSet(), introduced above.

HashedWheelTimer: timed retry tasks

We noticed that retryTimer appears in FailbackRegistry, the base class of ZookeeperRegistry; let's look at its implementation.
In ZookeeperRegistry and FailbackRegistry, when a task fails the code calls retryTimer.newTimeout(newTask, retryPeriod, TimeUnit.MILLISECONDS);
to start a timed retry task, which then executes:

HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline);
timeouts.add(timeout);

This puts the task into a queue; the worker polls the queue and executes the task once its time arrives.
The HashedWheelTimer constructor initializes the worker:

workerThread = threadFactory.newThread(worker);

threadFactory is a thread factory that gives every created thread a name.
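Such a naming ThreadFactory is conceptually tiny; a minimal sketch (illustrative, not the fate-serving NamedThreadFactory verbatim) looks like this:

import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

// Every created thread is named "<prefix>-<seq>" and optionally marked as a daemon thread.
public class NamedThreadFactorySketch implements ThreadFactory {
    private final AtomicInteger seq = new AtomicInteger(1);
    private final String prefix;
    private final boolean daemon;

    public NamedThreadFactorySketch(String prefix, boolean daemon) {
        this.prefix = prefix;
        this.daemon = daemon;
    }

    @Override
    public Thread newThread(Runnable r) {
        Thread t = new Thread(r, prefix + "-" + seq.getAndIncrement());
        t.setDaemon(daemon);
        return t;
    }
}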
Let's continue with the worker's run method:

@Override
public void run() {
    // Initialize the startTime.
    startTime = System.nanoTime();
    if (startTime == 0) {
        // We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized.
        startTime = 1;
    }

    // Notify the other threads waiting for the initialization at start().
    // synchronization between the thread that built the HashedWheelTimer and the worker thread:
    // tasks can only be added after the worker has finished initializing
    startTimeInitialized.countDown();

    do {
        final long deadline = waitForNextTick();
        if (deadline > 0) {
            // timeouts that fall on the same tick go into the same bucket, the so-called HashedWheelBucket
            int idx = (int) (tick & mask);
            processCancelledTasks();
            HashedWheelBucket bucket =
                    wheel[idx];
            transferTimeoutsToBuckets();
            // expire all timeouts in this bucket
            bucket.expireTimeouts(deadline);
            tick++;
        }
    } while (WORKER_STATE_UPDATER.get(HashedWheelTimer.this) == WORKER_STATE_STARTED);

    // Fill the unprocessedTimeouts so we can return them from stop() method.
    for (HashedWheelBucket bucket : wheel) {
        bucket.clearTimeouts(unprocessedTimeouts);
    }
    for (; ; ) {
        // drain all remaining timeouts
        HashedWheelTimeout timeout = timeouts.poll();
        if (timeout == null) {
            break;
        }
        if (!timeout.isCancelled()) {
            unprocessedTimeouts.add(timeout);
        }
    }
    processCancelledTasks();
}

The logic here is straightforward; no min-heap is used, because the number of pending timed tasks is actually small.
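A quick illustration of the tick & mask bucket indexing, assuming (as hashed wheel timers typically do) that the wheel length is rounded up to a power of two so that mask = wheel.length - 1:

public class WheelIndexDemo {
    public static void main(String[] args) {
        int wheelLength = 8;                 // assume a power-of-two wheel size
        int mask = wheelLength - 1;          // 7 == 0b111
        for (long tick = 0; tick < 12; tick++) {
            // tick & mask is equivalent to tick % wheelLength, computed with one bitwise AND
            System.out.println("tick " + tick + " -> bucket " + (tick & mask));
        }
        // prints buckets 0..7 and then wraps around to 0, 1, 2, 3
    }
}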

Routing and load balancing

The routing base class in the proxy module is BaseServingRouter, which has two implementations, ConfigFileBasedServingRouter and ZkServingRouter. They are used in the following places:

  1. HealthCheckEndPointService uses ConfigFileBasedServingRouter
  2. DefaultServingRouter uses ZkServingRouter, which is not implemented yet

The routing base class in the register module is RouterService, used as follows:

  1. FederationRouterInterceptor: not implemented (I could not find where it is used)
  2. DefaultServingRouter: not implemented
  3. RegistedClient: routes to a server
  4. HealthCheckEndPointService: routes to fateflow
  5. RouterService: routes to the resource server
  6. FateFlowModelLoader: routes to flow
  7. DefaultFederatedRpcInvoker: not implemented

Let's look at the register side; since load balancing is mainly used by the routing module, we cover them together.

The register module

We can see that AbstractRouterService uses LoadBalanceModel.random. RandomLoadBalance has a single selection algorithm, under which nodes with larger weights are more likely to be picked.

public class RandomLoadBalance extends AbstractLoadBalancer {

    public static final String NAME = "random";

    @Override
    protected List<URL> doSelect(List<URL> urls) {
        // number of URLs in the list
        int length = urls.size();

        // flag indicating whether all URLs share the same weight
        boolean sameWeight = true;

        // array holding each URL's weight
        int[] weights = new int[length];

        // weight of the first URL, used as the reference for comparison
        int firstWeight = getWeight(urls.get(0));
        weights[0] = firstWeight;

        // total weight, initialized with the first URL's weight
        int totalWeight = firstWeight;
        
        // iterate over the remaining URLs, accumulating the total weight and checking whether all weights are equal
        for (int i = 1; i < length; i++) {
            int weight = getWeight(urls.get(i));

            weights[i] = weight;

            totalWeight += weight;
            
            // if any URL's weight differs from the first one, clear the flag
            if (sameWeight && weight != firstWeight) {
                sameWeight = false;
            }
        }

        // if the total weight is positive and not all weights are equal, do a weighted random pick
        if (totalWeight > 0 && !sameWeight) {

            // draw a random offset within the total weight
            int offset = ThreadLocalRandom.current().nextInt(totalWeight);

            // walk the URL list and pick the URL where the offset lands, so each URL is chosen in proportion to its weight
            for (int i = 0; i < length; i++) {
                offset -= weights[i];
                if (offset < 0) {
                    // wrap the selected URL in a list and return it
                    return Lists.newArrayList(urls.get(i));
                }
            }
        }

        // fallback: if the total weight is non-positive or all weights are equal, return a uniformly random URL
        return Lists.newArrayList(urls.get(ThreadLocalRandom.current().nextInt(length)));
    }
}
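For example, with three URLs of weight 5, 3 and 2, totalWeight is 10. A random offset of 6 becomes 1 after subtracting the first weight (still non-negative) and -2 after subtracting the second, so the second URL is returned; over many calls the three URLs are therefore chosen with probabilities of roughly 50%, 30% and 20%.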

The weight parameter is stored in the URL's private volatile transient Map numbers; there is no code that modifies it, so the default weight is always used.

你可能感兴趣的:(推理引擎)