RAG
RAG(检索增强生成)在模型回答之前先从外部知识库检索相关文档,将文档内容注入提示词,使模型能够基于私有数据给出准确、可溯源的答案。
Spring AI 的 RAG 采用模块化流水线设计,六个阶段可独立替换:
查询转换 → 查询扩展 → 文档检索 → 文档合并 → 文档后处理 → 查询增强
1. 最小化配置
只需提供一个 DocumentRetriever 即可启用 RAG,其余阶段使用默认实现。
VectorStore vectorStore = new SimpleVectorStore(embeddingModel);
// 预先写入知识库文档
vectorStore.add(List.of(
new Document("Spring AI 1.0 于 2025 年发布,支持 OpenAI、Ollama 等 17 种模型"),
new Document("ChatClient 是 Spring AI 的核心入口,提供流式 API")
));
DocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.build();
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.build())
.build();
String response = chatClient.prompt()
.user("Spring AI 是什么时候发布的?")
.call()
.content();
默认行为:
- 直接用原始查询检索 VectorStore
- 检索到的文档拼接后放入提示词
- 未检索到文档时回复"超出知识范围"
- 检索结果可通过
chatResponse.getMetadata().get("rag_document_context")获取
2. 文档检索
VectorStoreDocumentRetriever 是内置检索器,委托 VectorStore 执行相似度搜索。
VectorStoreDocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.topK(5) // 返回文档数,默认 4
.similarityThreshold(0.75) // 相似度阈值,低于此值的文档被过滤
.build();
动态过滤
通过 Query.context 在每次请求时传递过滤条件,无需重建 Advisor:
Filter.Expression tenantFilter = new Filter.Expression(
ExpressionType.EQ, new Key("tenant"), new Value("tenant-001"));
chatClient.prompt()
.user("查询订单处理流程")
.advisors(a -> a.param("vector_store_filter_expression", tenantFilter))
.call()
.content();
3. 查询转换
在检索之前对用户查询进行改写,提升检索命中率。支持链式组合。
3.1 对话压缩
将对话历史 + 追问压缩为独立查询,解决多轮对话中省略和指代问题:
QueryTransformer compressor = CompressionQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.build();
3.2 查询改写
将模糊或冗长的查询改写为更适合检索的短句:
QueryTransformer rewriter = RewriteQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.targetSearchSystem("向量数据库")
.build();
3.3 查询翻译
将非英文查询翻译为目标语言,匹配嵌入模型的训练语言:
QueryTransformer translator = TranslationQueryTransformer.builder()
.chatClientBuilder(chatClientBuilder)
.targetLanguage("english")
.build();
链式组合
多个 Transformer 按顺序依次执行:
RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.queryTransformers(List.of(compressor, rewriter))
.build();
4. 查询扩展
将单个查询扩展为多个语义变体,每个变体独立检索,扩大召回范围:
QueryExpander expander = MultiQueryExpander.builder()
.chatClientBuilder(chatClientBuilder)
.numberOfQueries(3) // 生成 3 个变体,默认 3
.includeOriginal(true) // 保留原始查询,默认 true
.build();
RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.queryExpander(expander)
.build();
5. 自定义提示词模板
ContextualQueryAugmenter 负责将检索到的文档和用户查询拼接为最终提示词,支持自定义模板:
PromptTemplate template = new PromptTemplate("""
请根据以下参考资料回答用户问题。
参考资料:
{context}
用户问题:{query}
要求:
1. 基于参考资料回答,不要编造
2. 引用资料中的具体内容
""");
QueryAugmenter augmenter = ContextualQueryAugmenter.builder()
.promptTemplate(template)
.allowEmptyContext(false) // 无文档时用空上下文模板回复
.build();
RetrievalAugmentationAdvisor advisor = RetrievalAugmentationAdvisor.builder()
.documentRetriever(retriever)
.queryAugmenter(augmenter)
.build();
6. 完整示例
完整示例:RagExample.java
public class RagExample {
public static void main(String[] args) {
// 1. 准备 EmbeddingModel 和 VectorStore
OllamaApi ollamaApi = OllamaApi.builder()
.baseUrl("http://localhost:11434")
.build();
OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model("mxbai-embed-large").build())
.build();
SimpleVectorStore vectorStore = new SimpleVectorStore(embeddingModel);
// 写入知识库
vectorStore.add(List.of(
new Document("Spring AI 1.0 于 2025 年 5 月发布,支持 17 种 AI 模型"),
new Document("ChatClient 提供 prompt().user().call().content() 流式 API"),
new Document("RAG 通过 RetrievalAugmentationAdvisor 实现检索增强生成")
));
// 2. 创建 ChatModel 和 ChatClient
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(ollamaApi)
.defaultOptions(OllamaOptions.builder().model("qwen3:8b").build())
.build();
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(RetrievalAugmentationAdvisor.builder()
.documentRetriever(VectorStoreDocumentRetriever.builder()
.vectorStore(vectorStore)
.topK(3)
.build())
.build())
.build();
// 3. 基于知识库提问
String response = chatClient.prompt()
.user("Spring AI 什么时候发布的?支持多少种模型?")
.call()
.content();
System.out.println(response);
// 4. 获取引用的文档
ChatResponse chatResponse = chatClient.prompt()
.user("RAG 是通过什么实现的?")
.call()
.chatResponse();
List<Document> docs = (List<Document>) chatResponse.getMetadata()
.get("rag_document_context");
System.out.println("引用文档数: " + (docs != null ? docs.size() : 0));
}
}