Milvus实战:构建Q&A系统及推荐系统

Milvus简介

全民AI的时代已经在趋势之中,各类应用层出不穷,而想要构建一个完善的AI应用/系统,底层存储是不可缺少的一个组件。
与传统数据库或大数据存储不同的是,这种场景下则需要选择向量数据库,是专门用来存储和查询向量的数据库,其存储的向量来自于对文本、语音、图像、视频等的向量化数据,向量数据库不仅能够完成基本的CRUD(添加、读取查询、更新、删除)等操作,还能够对向量数据进行更快速的相似性搜索。

Milvus是众多向量库中的之一,适用于多个场景,如Questions & Answering系统、推荐系统等,单节点 Milvus 可以在秒内完成十亿级的向量搜索,分布式架构亦能满足用户的水平扩展需求。

参考文档:
Milvus官网
为AI而生的数据库:Milvus详解及实战

实践之问答系统&推荐系统Mix-in

元数据定义

不论是问答还是推荐,它们对上层暴露的接口仅仅是predict(...)/search(...)/query(...),模式是相同的,因此可以共用一个基本的Schema,固定基本的字段即可。

public class MilvusMeta {
    @Getter
    private final RecommenderSchema defaultMetricsSchema;
    @Getter
    private final QASchema defaultQASchema;

    public interface Schema {
        String getCollectionName();

        CreateCollectionParam getCreateCollectionParam();

        CreateIndexParam getCreateIndexParam();
    }


    @AllArgsConstructor
    public abstract static class BasicSchema implements Schema {
        public static final String SEARCH_PARAM = "{\"nprobe\":10}";    // Params
        public static final String INDEX_PARAM = "{\"nlist\":1024}";     // ExtraParam
        public static final IndexType INDEX_TYPE_DEFAULT = IndexType.IVF_FLAT;
        public static final MetricType METRIC_TYPE_DEFAULT = MetricType.L2;      // metric type
        public static final String INDEX_NAME_DEFAULT = "ivf_flat";
        public static final String PRIMARY_KEY_FIELD_NAME_DEFAULT = "id";
        public static final String PARTITION_KEY_FIELD_NAME_DEFAULT = "public";
        public static final String FIELD_NAME_DEFAULT = "embeddings";

        @Getter
        protected final CreateCollectionParam createCollectionParam;

        @Getter
        protected final CreateIndexParam createIndexParam;

        @Override
        public String getCollectionName() {
            return createCollectionParam.getCollectionName();
        }
    }

    /**
     * This schema is designed for storing Zen metrics.
     * TODO: Add more fields/features to describe a metric.
     */
    public static class RecommenderSchema extends BasicSchema {

        private RecommenderSchema(CreateCollectionParam collectionParam, CreateIndexParam indexParam) {
            super(collectionParam, indexParam);
        }

        public static RecommenderSchema create(ZenAiConfig.Storages.MilvusConf conf) {
            ZenAiConfig.Storages.Collection collection = conf.getActiveRecommenderCollection();
            return new RecommenderSchema(
                defaultCollectionParam(collection, collection.getEmbeddingsDimension()),
                MilvusUtil.createIndexParam(collection));
        }

        private static CreateCollectionParam defaultCollectionParam(ZenAiConfig.Storages.Collection collection,
            int dimension) {
            FieldType pkType = FieldType.newBuilder()
                .withName(collection.getPrimaryKey())
                .withDataType(DataType.VarChar)
                .withPrimaryKey(true)
                .withMaxLength(100)
                .withAutoID(false)
                .build();
            // 被embedding的字段
            FieldType embeddedFieldType = FieldType.newBuilder()
                .withName(collection.getEmbeddedFieldName())
                .withDataType(DataType.VarChar)
                .withMaxLength(255)
                .build();
            // embedding vector字段
            FieldType embeddingFieldType = FieldType.newBuilder()
                .withName(collection.getFieldName())
                .withDataType(DataType.FloatVector)
                .withDimension(dimension)
                .build();
            // 指定分区键字段,每一个Collection都需要指定一个分区键,除了能够Hive/Spark那样切分数据外,还能够加速相似查询。
            // 虽然Milvus支持多种方案以切分数据,但从管理复杂度、查询效率上来看,一个Collection对应多个数据分区,是最佳的方案。
            FieldType partitionKeyType = FieldType.newBuilder()
                .withName(collection.getPartitionKey())
                .withPartitionKey(true)
                .withDataType(DataType.VarChar)
                .withMaxLength(100)
                .build();
            return CreateCollectionParam.newBuilder()
                .withCollectionName(collection.getName())
                .withDescription(collection.getDescription())
                // .withShardsNum(2)
                .addFieldType(pkType)
                .addFieldType(embeddedFieldType)
                .addFieldType(embeddingFieldType)
                .addFieldType(partitionKeyType)
                // 开启动态字段添加功能
                .withEnableDynamicField(true)
                .build();
        }

    }

    public static class QASchema extends BasicSchema {
        public static final String ANSWER_FIELD_NAME = "answer";
        public static final String SCORE_FIELD_NAME = "score";
        public static final float SCORE_MAX_DEFAULT = 5.0f;
        public static final float SCORE_MIN_DEFAULT = 0.0f;

        public static final String INTENTION_FIELD_NAME = "intention";
        public static final String QUESTION_OCCURRENCE = "occurrence";

        public QASchema(CreateCollectionParam createCollectionParam, CreateIndexParam createIndexParam) {
            super(createCollectionParam, createIndexParam);
        }

        public static QASchema create(ZenAiConfig.Storages.MilvusConf conf) {
            ZenAiConfig.Storages.Collection collection = conf.getActiveQACollection();
            return new QASchema(
                defaultCollectionParam(collection, collection.getEmbeddingsDimension()),
                MilvusUtil.createIndexParam(collection));
        }

        private static CreateCollectionParam defaultCollectionParam(ZenAiConfig.Storages.Collection collection,
            int dimension) {
            FieldType pkType = FieldType.newBuilder()
                .withName(collection.getPrimaryKey())
                .withDataType(DataType.VarChar)
                .withPrimaryKey(true)
                .withMaxLength(100)
                .withAutoID(false)
                .build();
            FieldType embeddedFieldType = FieldType.newBuilder()
                .withName(collection.getEmbeddedFieldName())
                .withDataType(DataType.VarChar)
                .withMaxLength(65535)
                .build();
            FieldType embeddingFieldType = FieldType.newBuilder()
                .withName(collection.getFieldName())
                .withDataType(DataType.FloatVector)
                .withDimension(dimension)
                .build();
            FieldType partitionKeyType = FieldType.newBuilder()
                .withName(collection.getPartitionKey())
                .withPartitionKey(true)
                .withDataType(DataType.VarChar)
                .withMaxLength(100)
                .build();
            return CreateCollectionParam.newBuilder()
                .withCollectionName(collection.getName())
                .withDescription(collection.getDescription())
                // .withShardsNum(2)
                .addFieldType(pkType)
                .addFieldType(embeddedFieldType)
                .addFieldType(embeddingFieldType)
                .addFieldType(partitionKeyType)
                .withEnableDynamicField(true) // enable to insert new fields without modifying the code
                .build();
        }
    }

Milvus可行的操作接口定义

public interface IMilvusOperations {

    ZenAiConfig.Storages.Collection getCollection();

    MilvusConnection.MultiStatus delete(Filter filter);

    MilvusConnection.MultiStatus create(MilvusMeta.Index index);

    MilvusConnection.MultiStatus drop(String index);

    MilvusConnection.MultiStatus insert(MilvusData.Dataset dataset);

    MilvusConnection.MultiStatus insertAndFlush(MilvusData.Dataset dataset);
        /**
     * Query records by filter on the specified partition, which works like a normal SQL engine.
     *
     * @param partition which partition to query
     * @param filter boolean expression obeys the rules of Milvus
     * @param outputFields if empty, the result will contain all the fields, including the dynamic;
     *                     otherwise the result only contains the specified fields.
     * @return a nonnull instance, size of which is 0 if no matched records, otherwise is positive.
     */
    MilvusData.BasicPredictData queryByPartition(String partition, Filter filter, List<String> outputFields);

    List<MilvusData.BasicPredictData> search(List<List<Float>> vectors, Filter filter, int topK,
        List<String> outputFields);
 }

抽象系统接口定义

/**
 * 每个系统可能有不同的embedding的实现,因此需要定义一个接口。
 */
public interface IEmbedding {
    ImmutableList<List<Float>> getEmbeddings(List<String> messages);
}

/**
 * 通用接口定义,供应用层使用,可以基于sentence返回Milvus相似性结果集。
 */
public interface INlpSystem extends IMilvusOperations, IDataset, IEmbedding {

    default MilvusData.BasicPredictData predict(String sentence) {
        return predict(sentence, Filter.TRUE);
    }

    default MilvusData.BasicPredictData predict(String sentence, Filter filter) {
        return predict(sentence, filter, getCollection().getOutputFields());
    }

    default MilvusData.BasicPredictData predict(String sentence, Filter filter,
        List<String> outputFields) {
        ImmutableList<List<Float>> vectors = getEmbeddings(Lists.newArrayList(sentence));
        if (vectors.isEmpty()) {
            return MilvusData.BasicPredictData.EMPTY;
        }
        List<String> mergedOutputFields = Sets.union(
                ImmutableSet.copyOf(outputFields),
                ImmutableSet.copyOf(getCollection().getOutputFields()))
            .immutableCopy().asList();
        List<MilvusData.BasicPredictData> res = search(vectors, filter, getCollection().getTopk(), mergedOutputFields);
        return res.isEmpty() ? MilvusData.BasicPredictData.EMPTY : res.get(0);
    }

    default MilvusData.BasicPredictData predictByPartition(String partition, String sentence) {
        return predictByPartition(partition, sentence, Filter.TRUE, getCollection().getOutputFields());
    }

    default MilvusData.BasicPredictData predictByPartition(String partition, String sentence,
        Filter filter, List<String> outputFields) {
        ImmutableList<List<Float>> vectors = getEmbeddings(Lists.newArrayList(sentence));
        if (vectors.isEmpty()) {
            return MilvusData.BasicPredictData.EMPTY;
        }
        List<String> mergedOutputFields = Sets.union(
                ImmutableSet.copyOf(outputFields),
                ImmutableSet.copyOf(getCollection().getOutputFields()))
            .immutableCopy().asList();
        return searchByPartition(partition, vectors.get(0), filter, mergedOutputFields);
    }

}

/**
 * Q & A系统接口。
 */
public interface IQuestionAnswering extends INlpSystem {
}

/**
 * 推荐系统接口。
 */
public interface IRecommender extends INlpSystem, ISyncer {
}

插入数据集定义

列式格式构建插入Milvus的数据集,需要注意的是,Milvus JAVA SDK 2.3.1版本并不支持列式导致dynamic fields,因此我对源码进行了改造,以支持列式插入动态字段。
这个问题,已经反馈给了社区,并且已经在v2.3.2版本中支持。

public interface MilvusData {
    interface BasicData {
    	/**
    	 * Return a list view of the splitted data, to avoid copy.
    	 */
        BasicData[] split(int splitSize);

        /**
         * Return a view of the range [start, end) data, to avoid copy.
         */
        BasicData subData(int groupId, int start, int end);

        int size();
    }

    interface EmbeddingsProducer extends Function<List<String>, ImmutableList<List<Float>>> {
    }

    @Getter
    @Setter
    @AllArgsConstructor
    @NoArgsConstructor
    class Dataset {
        private List<BasicInsertData> inserts;
    }

    abstract class GroupedBasicData implements BasicData {
        @Getter
        private final int groupId;
        @Getter
        @Setter
        @Accessors(chain = true)
        private GroupedBasicData parent;

        protected GroupedBasicData(int groupId) {
            this.groupId = groupId;
        }

        /**
         * Split the data into more more sub-dataset.
         *
         * @param groups the number of expected groups
         * @return an array of sub-dataset views from the original dataset
         */
        public abstract BasicData[] grouped(int groups);
		
		/**
		 *  每一个切分或是extract的子数据集,都应该拥有一个可以唯一标识它的ID
		 */
        public String fullGroupId() {
            if (parent == null) {
                return String.valueOf(groupId);
            }
            return parent.fullGroupId() + "-" + groupId;
        }
    }

    @Getter
    abstract class PartitionedBasicData<T> extends GroupedBasicData {
        // 每一个系统都需要指定一个分区键,因此为了能够最小化存储,这里使用一个变量
        // 保存整个数据集应该插入
        private final T partition;

        protected PartitionedBasicData(T partition, int groupId) {
            super(groupId);
            this.partition = partition;
        }

        public abstract List<T> getPartitions();
    }
    /**
     * 以列式的形式构建插入数据集,并完成数据导入到Milvus。
     * Milvus
     */
    class BasicInsertData extends PartitionedBasicData<String> {
        private final String collection;
        private final ImmutableList<String> ids;
        private final ImmutableList<String> embeddingsInput;
        private final Supplier<ImmutableList<List<Float>>> vectorsSupplier;
        private final EmbeddingsProducer embeddingsProducer;
        private ImmutableMap<String, List<?>> dynamicFields;

        private final AtomicBoolean vectorsInitialized = new AtomicBoolean(false);
}

结果集定义

public interface MilvusData {
    /**
     * 一个通用的数据集,可以保存search/query的结果,行式数据结构。
     */
    @Getter
    class BasicPredictData extends GroupedBasicData {

        @Getter
        @Builder
        @AllArgsConstructor
        public static class Row {
            @JsonProperty
            private String id;
            private String embeddingsInput;
            @JsonProperty
            private Map<String, Object> extensions;
            @JsonProperty
            private float distance;
            @JsonIgnore
            private List<Float> vector;

            public <T> T getAs(String key, Class<T> clazz) {
                return getAs(key, clazz, null);
            }

            public <T> T getAs(String key, Class<T> clazz, T defaultValue) {
                return clazz.cast(extensions.getOrDefault(key, defaultValue));
            }
        }
    }
}

数据插入实例:列式插入

这个代码示例展示了如何构建列式数据集,并将其插入Milvus的流程。

注意到这里特别演示了使用了多线程并行 插入的功能,其原因有二:

  1. 一个批次的数据集过大,Milvus无法一次快速且稳定地完成插入动作,因此需要将原始数据集进行分组,例如这里分成3个组;
  2. 通常LLM(Large language Model)的一次API调用,只能支持生成16个向量数组,因此这里又对每一个分组后的子数据集进行横向切分,产生多个Batch,每个Batch包含一条记录。
@Test
void testSyncCollections() throws ExecutionException, InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        // 一组唯一值,用于区别每一条数据记录
        ImmutableList<String> ids = ImmutableList.of("1", "2", "3", "4", "5");
        // 一组指标名,这些指标就是待检索的合法指标集。
        ImmutableList<String> metrics = ImmutableList.of("m1", "m2", "m3", "m4", "m5");
        // 生成一组包含5个向量的列表,对应于每一个指标名
        ImmutableList<List<Float>> vectors = generateVectors(5, TEST_COLLECTION_DIMENSION);
        // 构建插入数据集
        MilvusData.BasicInsertData data = new MilvusData.BasicInsertData(
            config.getStorages().getMilvusConf().getActiveRecommenderCollection().getName(), ids, metrics,
            //这里使用Java中的Provider接口,提供执行插入数据任务时,对指标名列表向量化,由于这里事先生成了向量数组,因此直接从索引构建数据
            messages -> messages.stream().map(metrics::indexOf).map(vectors::get).collect(toImmutableList()));

        ImmutableList<Double> randoms = ImmutableList.of(1.0d, 2.0d, 3.0d, 4.0d, 5.0d);
        // 添加动态字段及相应的数据
        data.updateDynamicFields(ImmutableMap.of("random", randoms));
        // 构建并行插入数据任务
        // 3 groups:  ([1, 2]), ([3, 4]), ([5])
        // 1 batche: ([1],[2]),([3],[4]),([5])
        CompletableFuture<Integer>[] futures = milvusService.syncMetrics(data, 3, 1, executorService);
        assertEquals(3, futures.length);
        CompletableFuture.allOf(futures).join();
        assertEquals(2, futures[0].get());
        assertEquals(2, futures[1].get());
        assertEquals(1, futures[2].get());
}

相似性检索实例:指标推荐

用户输入一个指标(Metric)名,或是包含指标名的语句,可以通过Milvus的Search接口,找到最相近的TOP K指标,前提是需要对输入指标名进行向量化,然后以此向量来r从Milvus库中既存的指标集中计算找到最相似的。

    @Test
    void testLookingForMetric() {
        int topK = config..getMilvusConf().getActiveRecommenderCollection().getTopk();
        Optional<MilvusData.BasicPredictData> metrics = milvusService.getActiveRecommenderSys().map(system -> system.predict("销售总额"));
        assertTrue(metrics.isPresent());
        assertEquals(metrics.get().getRows().size(), topK);

        metrics = milvusService.getActiveRecommenderSys().map(system -> system.predict("销售总额", lt("random", 0f)));
        assertFalse(metrics.isPresent());
    }

相似性检索实例:问答

用户输入一段描述,可以通过Milvus提供的Search接口,找到历史相关的问题,并返回与此问题相关的上下文,并辅助回答AI模型回答用户的当前问题。

    @Test
    void testSearchWithSimilarityOfMultiVectors() {
        ImmutableList<String> testQuestions = ImmutableList.of("用柱形图展示2019年12月的总销售额", "用拆线图展示2020年12月的总净利润");
        ImmutableList<String> testIds = testQuestions.stream()
            .map(DefaultQuestionAnswering::encodeQuestion)
            .collect(ImmutableList.toImmutableList());
        IEmbedding embeddingSvc = aiService.getMilvusService().get().getActiveQuestionAnswering().get();
        DefaultQuestionAnswering.QAInsertData insertData = system.getInsertDataBuilder()
            .ids(testIds)
            .questions(testQuestions)
            .answers(ImmutableList.of("很好", "不错"))
            // 用户对于此问题返回结果的评价
            .scores(ImmutableList.of(1.0f, 1.0f))
            // 定义embeddings生成器,在插入时才会计算embeddings
            .embeddingsProducer(
                questions -> {
                    ImmutableList<List<Float>> qvectors = embeddingSvc.getEmbeddings(questions);
                    ImmutableList<List<Float>> mvectors = embeddingSvc.getEmbeddings(ImmutableList.of("总销售额", "总净利润"));
                    return ImmutableList.of(merge(qvectors.get(0), mvectors.get(0)), merge(qvectors.get(1), mvectors.get(1)));
                })
            .build();
        system.insert(new MilvusData.Dataset(ImmutableList.of(insertData)));

        // Case 1:
        // 用柱形图展示2019年12月的总销售额: 52.0181
        // 用拆线图展示2020年12月的总净利润: 147.0664
        verifySearch(system, "用拆线图展示2020年12月的总销售额", "总销售额", 0, "销售额", this::merge);

        // Case 2:
        // 用柱形图展示2019年12月的总销售额: 313.70105
        // 用拆线图展示2020年12月的总净利润: 181.3783
        verifySearch(system, "今年5月的净利润详情", "净利润", 0, "利润", this::merge);

        // Case 3:
        // 用柱形图展示2019年12月的总销售额: 160.30568
        // 用拆线图展示2020年12月的总净利润: 357.7008
        verifySearch(system, "今年5月的销售额详情", "销售额", 0, "销售额", this::merge);
    }

总结

Milvus对上层提供了与传统数据库相似的接口,以管理Milvus数据,同时提供了带有过滤功能的数据检索接口,使得上层应用能够很方便地利用传统数据库思维,来设计 和实现自己的系统。
但在使用中也感受到一些局限性或可能提升的点:

  1. 库中的一行记录只能对应一个embedding vector:只能使用相同模型生成的vector才能更好地检索向量,如果想一处持编码的文本对应多个vectors是不可能的,用户不得不创建新的Collection存储相同文本的不同向量。
  2. 用户显示Flush/Load Collection:每一次更新数据集,客户端必须要显示地load collection的操作,才能将新的数据加载到Server结点的内存中,同时第一次加载Collection必须是全量。
  3. 粗糙的表达式字符串:对于API接口的使用,缺少便利的表达式类定义,只能传递字符串,很容易出错,只能在运行时才知道哪些出错了。
  4. 缓存特性的支持:通常Milvus被用作Cache角色被引入系统中,但Milvus缺少一些缓存特性,如过期自动清理、partial dataset的load/unload功能等。

你可能感兴趣的:(大语言模型,数据存储,milvus,java,相似性搜索)