From 0b42ea8dccebd69c33022ad8c206eff10ef1bf4f Mon Sep 17 00:00:00 2001 From: imeepos Date: Tue, 22 Jul 2025 10:56:54 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=B9=E9=80=A0query=5Fllm=5Fwith=5F?= =?UTF-8?q?grounding=E6=94=AF=E6=8C=81=E5=A4=9A=E8=BD=AE=E5=AF=B9=E8=AF=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 扩展RAG Grounding数据模型支持会话管理 - 添加多轮对话的RAG查询方法query_llm_with_grounding_multi_turn - 集成现有会话管理服务,支持历史消息存储和检索 - 更新前端类型定义和服务层,支持多轮对话参数 - 创建多轮RAG对话测试组件和页面 - 支持系统提示词、历史消息数量控制等配置选项 - 保持向后兼容,单轮对话功能不受影响 核心功能: 多轮RAG对话 - 基于检索增强生成的多轮对话 会话历史管理 - 自动保存和加载对话历史 智能检索增强 - 结合Vertex AI Search的知识检索 上下文保持 - 在多轮对话中保持对话上下文 灵活配置 - 支持历史消息数量、系统提示词等配置 来源追踪 - 显示检索来源和相关性信息 遵循promptx/tauri-desktop-app-expert开发规范 --- .../src/infrastructure/gemini_service.rs | 299 ++++++++++++++++ .../commands/rag_grounding_commands.rs | 43 ++- .../src/components/MultiTurnRagChatTest.tsx | 334 ++++++++++++++++++ .../src/pages/MultiTurnRagChatTestPage.tsx | 49 +++ .../src/services/ragGroundingService.ts | 3 + apps/desktop/src/types/ragGrounding.ts | 21 ++ 6 files changed, 741 insertions(+), 8 deletions(-) create mode 100644 apps/desktop/src/components/MultiTurnRagChatTest.tsx create mode 100644 apps/desktop/src/pages/MultiTurnRagChatTestPage.tsx diff --git a/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs b/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs index 850a332..73674ec 100644 --- a/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs +++ b/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs @@ -11,6 +11,13 @@ use reqwest::multipart; use crate::infrastructure::tolerant_json_parser::{TolerantJsonParser, ParserConfig, RecoveryStrategy}; use std::sync::{Arc, Mutex}; +// 导入会话管理相关模块 +use crate::data::models::conversation::{ + ConversationMessage, MessageRole, MessageContent, ConversationHistoryQuery, + CreateConversationSessionRequest, AddMessageRequest, +}; +use crate::data::repositories::conversation_repository::ConversationRepository; + /// Gemini API配置 #[derive(Debug, Clone)] pub struct GeminiConfig { @@ -180,6 +187,9 @@ pub struct RagGroundingRequest { pub user_input: String, pub config: Option, pub session_id: Option, + pub include_history: Option, + pub max_history_messages: Option, + pub system_prompt: Option, } /// RAG Grounding 响应 @@ -189,6 +199,9 @@ pub struct RagGroundingResponse { pub grounding_metadata: Option, pub response_time_ms: u64, pub model_used: String, + pub session_id: Option, + pub message_id: Option, + pub conversation_context: Option, } /// Grounding 元数据 @@ -206,6 +219,14 @@ pub struct GroundingSource { pub content: Option, } +/// 对话上下文 +#[derive(Debug, Serialize, Deserialize)] +pub struct ConversationContext { + pub total_messages: u32, + pub history_included: bool, + pub context_length: u32, +} + /// Vertex AI Search 工具配置 #[derive(Debug, Serialize)] struct VertexAISearchTool { @@ -903,6 +924,17 @@ impl GeminiService { /// RAG Grounding 查询 (参考 RAGUtils.py 中的 query_llm_with_grounding) pub async fn query_llm_with_grounding(&mut self, request: RagGroundingRequest) -> Result { + // 如果请求包含会话管理参数,使用多轮对话版本 + if request.session_id.is_some() || request.include_history.unwrap_or(false) { + return self.query_llm_with_grounding_multi_turn(request, None).await; + } + + // 否则使用原有的单轮对话逻辑 + self.query_llm_with_grounding_single_turn(request).await + } + + /// 单轮RAG Grounding查询(原有逻辑) + async fn query_llm_with_grounding_single_turn(&mut self, request: RagGroundingRequest) -> Result { let start_time = std::time::Instant::now(); println!("🔍 开始RAG Grounding查询: {}", request.user_input); @@ -976,6 +1008,9 @@ impl GeminiService { grounding_metadata: response.grounding_metadata, response_time_ms: elapsed.as_millis() as u64, model_used: rag_config.model_id, + session_id: None, + message_id: None, + conversation_context: None, }); } Err(e) => { @@ -990,6 +1025,267 @@ impl GeminiService { Err(anyhow!("RAG Grounding查询失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap())) } + /// 多轮RAG Grounding查询(支持会话历史) + pub async fn query_llm_with_grounding_multi_turn( + &mut self, + request: RagGroundingRequest, + conversation_repo: Option>, + ) -> Result { + let start_time = std::time::Instant::now(); + println!("🔍 开始多轮RAG Grounding查询: {}", request.user_input); + + // 获取配置 + let rag_config = request.config.unwrap_or_default(); + + // 1. 确定或创建会话 + let session_id = match &request.session_id { + Some(id) => id.clone(), + None => { + // 如果没有提供session_id,生成一个新的 + uuid::Uuid::new_v4().to_string() + } + }; + + // 2. 获取历史消息(如果需要且有会话仓库) + let mut contents = Vec::new(); + let mut conversation_context = ConversationContext { + total_messages: 0, + history_included: false, + context_length: 0, + }; + + if let (Some(repo), true) = (&conversation_repo, request.include_history.unwrap_or(false)) { + match self.get_conversation_history(&repo, &session_id, request.max_history_messages.unwrap_or(10)).await { + Ok(history_messages) => { + conversation_context.total_messages = history_messages.len() as u32; + conversation_context.history_included = true; + + // 添加系统提示(如果有) + if let Some(system_prompt) = &request.system_prompt { + contents.push(ContentPart { + role: "system".to_string(), + parts: vec![Part::Text { text: system_prompt.clone() }], + }); + } else if let Some(default_prompt) = &rag_config.system_prompt { + contents.push(ContentPart { + role: "system".to_string(), + parts: vec![Part::Text { text: default_prompt.clone() }], + }); + } + + // 添加历史消息 + for msg in &history_messages { + let parts = self.convert_message_content_to_parts(&msg.content)?; + contents.push(ContentPart { + role: msg.role.to_string(), + parts, + }); + } + + conversation_context.context_length = contents.len() as u32; + println!("📚 加载了 {} 条历史消息", history_messages.len()); + } + Err(e) => { + println!("⚠️ 获取历史消息失败: {}", e); + } + } + } + + // 3. 添加当前用户消息 + contents.push(ContentPart { + role: "user".to_string(), + parts: vec![Part::Text { text: request.user_input.clone() }], + }); + + // 4. 执行RAG查询 + let response = self.execute_rag_grounding_with_contents(contents, rag_config).await?; + + // 5. 保存消息到会话历史(如果有会话仓库) + let message_id = if let Some(repo) = &conversation_repo { + match self.save_conversation_messages(&repo, &session_id, &request.user_input, &response.answer).await { + Ok(id) => Some(id), + Err(e) => { + println!("⚠️ 保存会话消息失败: {}", e); + None + } + } + } else { + None + }; + + let elapsed = start_time.elapsed(); + println!("✅ 多轮RAG Grounding查询完成,耗时: {:?}", elapsed); + + Ok(RagGroundingResponse { + answer: response.answer, + grounding_metadata: response.grounding_metadata, + response_time_ms: elapsed.as_millis() as u64, + model_used: response.model_used, + session_id: Some(session_id), + message_id, + conversation_context: Some(conversation_context), + }) + } + + /// 获取会话历史消息 + async fn get_conversation_history( + &self, + repo: &Arc, + session_id: &str, + max_messages: u32, + ) -> Result> { + let query = ConversationHistoryQuery { + session_id: session_id.to_string(), + limit: Some(max_messages), + offset: None, + include_system_messages: Some(false), + }; + + match repo.get_conversation_history(query) { + Ok(history) => Ok(history.messages), + Err(e) => Err(anyhow!("获取会话历史失败: {}", e)), + } + } + + /// 保存会话消息 + async fn save_conversation_messages( + &self, + repo: &Arc, + session_id: &str, + user_input: &str, + assistant_response: &str, + ) -> Result { + // 保存用户消息 + let user_message_request = AddMessageRequest { + session_id: session_id.to_string(), + role: MessageRole::User, + content: vec![MessageContent::Text { text: user_input.to_string() }], + metadata: None, + }; + + repo.add_message(user_message_request)?; + + // 保存助手回复 + let assistant_message_request = AddMessageRequest { + session_id: session_id.to_string(), + role: MessageRole::Assistant, + content: vec![MessageContent::Text { text: assistant_response.to_string() }], + metadata: None, + }; + + let assistant_message = repo.add_message(assistant_message_request)?; + Ok(assistant_message.id) + } + + /// 执行带有完整contents的RAG查询 + async fn execute_rag_grounding_with_contents( + &mut self, + contents: Vec, + rag_config: RagGroundingConfig, + ) -> Result { + // 获取访问令牌 + let access_token = self.get_access_token().await?; + + // 创建客户端配置 + let client_config = self.create_gemini_client(&access_token); + + // 构建数据存储路径 + let datastore_path = format!( + "projects/{}/locations/{}/collections/default_collection/dataStores/{}", + rag_config.project_id, + rag_config.location, + rag_config.data_store_id + ); + + // 构建工具配置 (Vertex AI Search) + let tools = vec![VertexAISearchTool { + retrieval: VertexAIRetrieval { + vertex_ai_search: VertexAISearchConfig { + datastore: datastore_path, + }, + }, + }]; + + // 构建生成配置 + let generation_config = GenerationConfig { + temperature: rag_config.temperature, + top_k: 32, + top_p: 1.0, + max_output_tokens: rag_config.max_output_tokens, + }; + + // 准备请求数据 + let request_data = serde_json::json!({ + "contents": contents, + "tools": tools, + "generationConfig": generation_config, + "toolConfig": { + "functionCallingConfig": { + "mode": "ANY", + "allowedFunctionNames": ["retrieval"] + } + } + }); + + // 发送请求到Cloudflare Gateway (使用beta API) + let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, rag_config.model_id); + + // 重试机制 + let mut last_error = None; + for attempt in 0..self.config.max_retries { + match self.send_rag_grounding_request(&generate_url, &client_config, &request_data).await { + Ok(response) => { + return Ok(RagGroundingResponse { + answer: response.answer, + grounding_metadata: response.grounding_metadata, + response_time_ms: 0, // 这里会在调用方设置 + model_used: rag_config.model_id, + session_id: None, + message_id: None, + conversation_context: None, + }); + } + Err(e) => { + last_error = Some(e); + if attempt < self.config.max_retries - 1 { + tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await; + } + } + } + } + + Err(anyhow!("RAG Grounding查询失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap())) + } + + /// 将消息内容转换为Gemini API的Part格式 + fn convert_message_content_to_parts(&self, content: &[MessageContent]) -> Result> { + let mut parts = Vec::new(); + for item in content { + match item { + MessageContent::Text { text } => { + parts.push(Part::Text { text: text.clone() }); + } + MessageContent::File { file_uri, mime_type, .. } => { + parts.push(Part::FileData { + file_data: FileData { + mime_type: mime_type.clone(), + file_uri: file_uri.clone(), + } + }); + } + MessageContent::InlineData { data, mime_type, .. } => { + parts.push(Part::InlineData { + inline_data: InlineData { + mime_type: mime_type.clone(), + data: data.clone(), + } + }); + } + } + } + Ok(parts) + } + /// 发送RAG Grounding请求 async fn send_rag_grounding_request( &self, @@ -1067,6 +1363,9 @@ impl GeminiService { grounding_metadata, response_time_ms: 0, // 将在调用方设置 model_used: "gemini-2.5-flash".to_string(), + session_id: None, + message_id: None, + conversation_context: None, }) } diff --git a/apps/desktop/src-tauri/src/presentation/commands/rag_grounding_commands.rs b/apps/desktop/src-tauri/src/presentation/commands/rag_grounding_commands.rs index 79695b6..ba42046 100644 --- a/apps/desktop/src-tauri/src/presentation/commands/rag_grounding_commands.rs +++ b/apps/desktop/src-tauri/src/presentation/commands/rag_grounding_commands.rs @@ -4,9 +4,10 @@ use crate::app_state::AppState; /// RAG Grounding 查询命令 /// 基于数据存储的检索增强生成,参考 RAGUtils.py 中的 query_llm_with_grounding 实现 +/// 现已支持多轮对话功能 #[command] pub async fn query_rag_grounding( - _state: State<'_, AppState>, + state: State<'_, AppState>, request: RagGroundingRequest, ) -> Result { println!("🔍 收到RAG Grounding查询请求: {}", request.user_input); @@ -16,14 +17,37 @@ pub async fn query_rag_grounding( let mut gemini_service = GeminiService::new(Some(config)) .map_err(|e| format!("创建GeminiService失败: {}", e))?; + // 如果请求包含会话管理参数,获取会话仓库 + let conversation_repo = if request.session_id.is_some() || request.include_history.unwrap_or(false) { + match state.inner().get_conversation_repository() { + Ok(repo) => Some(repo), + Err(e) => { + println!("⚠️ 获取会话仓库失败,将使用单轮对话模式: {}", e); + None + } + } + } else { + None + }; + // 执行RAG Grounding查询 - let response = gemini_service - .query_llm_with_grounding(request) - .await - .map_err(|e| { - eprintln!("RAG Grounding查询失败: {}", e); - format!("RAG Grounding查询失败: {}", e) - })?; + let response = if conversation_repo.is_some() { + gemini_service + .query_llm_with_grounding_multi_turn(request, conversation_repo) + .await + .map_err(|e| { + eprintln!("多轮RAG Grounding查询失败: {}", e); + format!("多轮RAG Grounding查询失败: {}", e) + })? + } else { + gemini_service + .query_llm_with_grounding(request) + .await + .map_err(|e| { + eprintln!("RAG Grounding查询失败: {}", e); + format!("RAG Grounding查询失败: {}", e) + })? + }; println!("✅ RAG Grounding查询成功,响应时间: {}ms", response.response_time_ms); @@ -47,6 +71,9 @@ pub async fn test_rag_grounding_connection( user_input: "测试连接".to_string(), config: None, session_id: Some("test-session".to_string()), + include_history: Some(false), + max_history_messages: None, + system_prompt: None, }; // 执行测试查询 diff --git a/apps/desktop/src/components/MultiTurnRagChatTest.tsx b/apps/desktop/src/components/MultiTurnRagChatTest.tsx new file mode 100644 index 0000000..4d7ec84 --- /dev/null +++ b/apps/desktop/src/components/MultiTurnRagChatTest.tsx @@ -0,0 +1,334 @@ +import React, { useState, useCallback, useRef, useEffect } from 'react'; +import { queryRagGrounding } from '../services/ragGroundingService'; +import { + RagGroundingQueryOptions, + RagGroundingResponse, + GroundingSource, + ConversationContext, +} from '../types/ragGrounding'; + +/** + * 多轮RAG对话测试组件 + * 支持基于检索增强生成的多轮对话功能 + */ + +interface ChatMessage { + id: string; + type: 'user' | 'assistant'; + content: string; + timestamp: Date; + status: 'sending' | 'sent' | 'error'; + metadata?: { + responseTime?: number; + modelUsed?: string; + groundingSources?: GroundingSource[]; + conversationContext?: ConversationContext; + }; +} + +export const MultiTurnRagChatTest: React.FC = () => { + const [messages, setMessages] = useState([]); + const [input, setInput] = useState(''); + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(null); + const [sessionId, setSessionId] = useState(null); + const [showHistory, setShowHistory] = useState(true); + const [maxHistoryMessages, setMaxHistoryMessages] = useState(10); + const [systemPrompt, setSystemPrompt] = useState('你是一个专业的服装搭配顾问,基于检索到的相关信息为用户提供准确、实用的搭配建议。'); + const [showGroundingSources, setShowGroundingSources] = useState(true); + + const messagesEndRef = useRef(null); + + // 自动滚动到底部 + const scrollToBottom = useCallback(() => { + messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); + }, []); + + useEffect(() => { + scrollToBottom(); + }, [messages, scrollToBottom]); + + // 生成消息ID + const generateMessageId = () => { + return `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`; + }; + + // 发送消息 + const handleSendMessage = useCallback(async () => { + if (!input.trim() || isLoading) return; + + const userMessage: ChatMessage = { + id: generateMessageId(), + type: 'user', + content: input.trim(), + timestamp: new Date(), + status: 'sent' + }; + + const assistantMessage: ChatMessage = { + id: generateMessageId(), + type: 'assistant', + content: '', + timestamp: new Date(), + status: 'sending' + }; + + // 添加用户消息和占位助手消息 + setMessages(prev => [...prev, userMessage, assistantMessage]); + setInput(''); + setIsLoading(true); + setError(null); + + try { + // 构建查询选项 + const options: RagGroundingQueryOptions = { + sessionId: sessionId || undefined, + includeHistory: showHistory, + maxHistoryMessages: maxHistoryMessages, + systemPrompt: systemPrompt.trim() || undefined, + includeMetadata: showGroundingSources, + }; + + // 调用RAG Grounding服务 + const result = await queryRagGrounding(userMessage.content, options); + + if (result.success && result.data) { + // 更新会话ID + if (!sessionId && result.data.session_id) { + setSessionId(result.data.session_id); + } + + // 更新助手消息 + setMessages(prev => prev.map(msg => + msg.id === assistantMessage.id + ? { + ...msg, + content: result.data!.answer, + status: 'sent' as const, + metadata: { + responseTime: result.data!.response_time_ms, + modelUsed: result.data!.model_used, + groundingSources: result.data!.grounding_metadata?.sources, + conversationContext: result.data!.conversation_context, + } + } + : msg + )); + } else { + // 处理错误 + setError(result.error || '查询失败'); + setMessages(prev => prev.map(msg => + msg.id === assistantMessage.id + ? { ...msg, content: '抱歉,查询过程中发生了错误', status: 'error' as const } + : msg + )); + } + } catch (err) { + const errorMessage = err instanceof Error ? err.message : '未知错误'; + setError(errorMessage); + setMessages(prev => prev.map(msg => + msg.id === assistantMessage.id + ? { ...msg, content: '抱歉,发生了系统错误', status: 'error' as const } + : msg + )); + } finally { + setIsLoading(false); + } + }, [input, isLoading, sessionId, showHistory, maxHistoryMessages, systemPrompt, showGroundingSources]); + + // 清空对话 + const handleClearChat = useCallback(() => { + setMessages([]); + setSessionId(null); + setError(null); + }, []); + + // 处理键盘事件 + const handleKeyPress = useCallback((e: React.KeyboardEvent) => { + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault(); + handleSendMessage(); + } + }, [handleSendMessage]); + + // 格式化响应时间 + const formatResponseTime = (ms: number) => { + if (ms < 1000) return `${ms}ms`; + return `${(ms / 1000).toFixed(1)}s`; + }; + + return ( +
+ {/* 标题和设置 */} +
+

多轮RAG对话测试

+
+
+
+ +
+
+ + setMaxHistoryMessages(parseInt(e.target.value) || 10)} + min="1" + max="50" + className="w-16 px-2 py-1 border rounded" + /> +
+
+ +
+
+
+
+ 会话ID: {sessionId ? sessionId.substring(0, 8) + '...' : '未创建'} +
+ +
+
+ + {/* 系统提示词设置 */} +
+ +