跳到主要内容
版本:1.0.5

RAG

RAG(检索增强生成)在模型回答之前先从外部知识库检索相关文档,将文档内容注入提示词,使模型能够基于私有数据给出准确、可溯源的答案。

Spring AI 的 RAG 采用模块化流水线设计,六个阶段可独立替换:

查询转换 → 查询扩展 → 文档检索 → 文档合并 → 文档后处理 → 查询增强

1. 最小化配置

只需提供一个 DocumentRetriever 即可启用 RAG,其余阶段使用默认实现。

VectorStore vectorStore = new SimpleVectorStore(embeddingModel);
// 预先写入知识库文档
vectorStore.add(List.of(
new Document("Spring AI 1.0 于 2025 年发布,支持 OpenAI、Ollama 等 17 种模型"),
new Document("ChatClient 是 Spring AI 的核心入口,提供流式 API")
));

DocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.build();

ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.build())
.build();

String response = chatClient.prompt()
.user("Spring AI 是什么时候发布的?")
.call()
.content();

默认行为:

  • 直接用原始查询检索 VectorStore
  • 检索到的文档拼接后放入提示词
  • 未检索到文档时回复"超出知识范围"
  • 检索结果可通过 chatResponse.getMetadata().get("rag_document_context") 获取

2. 文档检索

VectorStoreDocumentRetriever 是内置检索器,委托 VectorStore 执行相似度搜索。

VectorStoreDocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.topK(5) // 返回文档数,默认 4
.similarityThreshold(0.75) // 相似度阈值,低于此值的文档被过滤
.build();

动态过滤

通过 Query.context 在每次请求时传递过滤条件,无需重建 Advisor:

Filter.Expression tenantFilter = new Filter.Expression(
ExpressionType.EQ, new Key("tenant"), new Value("tenant-001"));

chatClient.prompt()
.user("查询订单处理流程")
.advisors(a -> a.param("vector_store_filter_expression", tenantFilter))
.call()
.content();

3. 查询转换

在检索之前对用户查询进行改写,提升检索命中率。支持链式组合。

3.1 对话压缩

将对话历史 + 追问压缩为独立查询,解决多轮对话中省略和指代问题:

QueryTransformer compressor = CompressionQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.build();

3.2 查询改写

将模糊或冗长的查询改写为更适合检索的短句:

QueryTransformer rewriter = RewriteQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.targetSearchSystem("向量数据库")
.build();

3.3 查询翻译

将非英文查询翻译为目标语言,匹配嵌入模型的训练语言:

QueryTransformer translator = TranslationQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.targetLanguage("english")
.build();

链式组合

多个 Transformer 按顺序依次执行:

RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.queryTransformers(List.of(compressor, rewriter))
.build();

4. 查询扩展

将单个查询扩展为多个语义变体,每个变体独立检索,扩大召回范围:

QueryExpander expander = MultiQueryExpander.builder()
.chatClientBuilder(chatClientBuilder)
.numberOfQueries(3) // 生成 3 个变体,默认 3
.includeOriginal(true) // 保留原始查询,默认 true
.build();

RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.queryExpander(expander)
.build();

5. 自定义提示词模板

ContextualQueryAugmenter 负责将检索到的文档和用户查询拼接为最终提示词,支持自定义模板:

PromptTemplate template = new PromptTemplate("""
请根据以下参考资料回答用户问题。

参考资料:
{context}

用户问题:{query}

要求:
1. 基于参考资料回答,不要编造
2. 引用资料中的具体内容
""");

QueryAugmenter augmenter = ContextualQueryAugmenter.builder()
.promptTemplate(template)
.allowEmptyContext(false) // 无文档时用空上下文模板回复
.build();

RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.queryAugmenter(augmenter)
.build();

6. 完整示例

完整示例:RagExample.java
public class RagExample {

public static void main(String[] args) {
// 1. 准备 EmbeddingModel 和 VectorStore
OllamaApi ollamaApi = OllamaApi.builder()
.baseUrl("http://localhost:11434")
.build();
OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model("mxbai-embed-large").build())
.build();
SimpleVectorStore vectorStore = new SimpleVectorStore(embeddingModel);

// 写入知识库
vectorStore.add(List.of(
new Document("Spring AI 1.0 于 2025 年 5 月发布,支持 17 种 AI 模型"),
new Document("ChatClient 提供 prompt().user().call().content() 流式 API"),
new Document("RAG 通过 RetrievalAugmentationAdvisor 实现检索增强生成")
));

// 2. 创建 ChatModel 和 ChatClient
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model("qwen3:8b").build())
.build();

ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.topK(3)
.build())
.build())
.build();

// 3. 基于知识库提问
String response = chatClient.prompt()
.user("Spring AI 什么时候发布的?支持多少种模型?")
.call()
.content();
System.out.println(response);

// 4. 获取引用的文档
ChatResponse chatResponse = chatClient.prompt()
.user("RAG 是通过什么实现的?")
.call()
.chatResponse();
List<Document> docs = (List<Document>) chatResponse.getMetadata()
.get("rag_document_context");
System.out.println("引用文档数: " + (docs != null ? docs.size() : 0));
}
}