feat: extend query_llm_with_grounding to support multi-turn conversations

- Extend the RAG Grounding data models with session-management fields
- Add a multi-turn RAG query method, query_llm_with_grounding_multi_turn
- Integrate the existing conversation-management service for storing and loading history messages
- Update the frontend type definitions and service layer with multi-turn parameters
- Add a multi-turn RAG chat test component and test page
- Support configuration options such as the system prompt and the history-message limit
- Keep backward compatibility; single-turn behavior is unchanged

Core features:
- Multi-turn RAG chat - retrieval-augmented generation across multiple turns
- Session history management - conversation history is saved and loaded automatically
- Retrieval augmentation - knowledge retrieval via Vertex AI Search
- Context preservation - conversation context is carried across turns
- Flexible configuration - history-message limit, system prompt, and other options
- Source tracking - retrieval sources and relevance information are surfaced

Follows the promptx/tauri-desktop-app-expert development guidelines
commit 0b42ea8dcc (parent 8b92cc130c)
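For orientation before the diff itself, here is a minimal sketch of how the frontend can drive a two-turn exchange through the updated service layer. It uses the queryRagGrounding helper and the field names added in this commit, and assumes the result shape (success/data/error) used by the test component below; the questions, import path, and option values are illustrative only.

import { queryRagGrounding } from '../services/ragGroundingService';

async function twoTurnDemo(): Promise<void> {
  // Turn 1: no session yet, so the backend creates one and returns its id.
  const first = await queryRagGrounding('推荐一套秋季通勤穿搭', {
    includeHistory: true,
    maxHistoryMessages: 10,
  });

  // Turn 2: pass the returned session_id back so the stored history is loaded as context.
  const second = await queryRagGrounding('换成更休闲的风格呢?', {
    sessionId: first.data?.session_id,
    includeHistory: true,
  });

  console.log(second.data?.conversation_context);
}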
@@ -11,6 +11,13 @@ use reqwest::multipart;
use crate::infrastructure::tolerant_json_parser::{TolerantJsonParser, ParserConfig, RecoveryStrategy};
use std::sync::{Arc, Mutex};

// Conversation-management modules
use crate::data::models::conversation::{
    ConversationMessage, MessageRole, MessageContent, ConversationHistoryQuery,
    CreateConversationSessionRequest, AddMessageRequest,
};
use crate::data::repositories::conversation_repository::ConversationRepository;

/// Gemini API configuration
#[derive(Debug, Clone)]
pub struct GeminiConfig {
@@ -180,6 +187,9 @@ pub struct RagGroundingRequest {
    pub user_input: String,
    pub config: Option<RagGroundingConfig>,
    pub session_id: Option<String>,
    pub include_history: Option<bool>,
    pub max_history_messages: Option<u32>,
    pub system_prompt: Option<String>,
}

/// RAG Grounding response
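Over the Tauri bridge this struct arrives as plain snake_case JSON. A hedged sketch of a multi-turn request body, using the TypeScript RagGroundingRequest interface added later in this commit; the text and id values are made up.

import type { RagGroundingRequest } from '../types/ragGrounding';

const request: RagGroundingRequest = {
  user_input: '有哪些适合雨天通勤的外套?',
  session_id: '4f7c0b2e-example',   // omit on the first turn; the backend generates one
  include_history: true,
  max_history_messages: 10,
  system_prompt: '你是一个专业的服装搭配顾问。',
};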
@@ -189,6 +199,9 @@ pub struct RagGroundingResponse {
    pub grounding_metadata: Option<GroundingMetadata>,
    pub response_time_ms: u64,
    pub model_used: String,
    pub session_id: Option<String>,
    pub message_id: Option<String>,
    pub conversation_context: Option<ConversationContext>,
}

/// Grounding metadata
@@ -206,6 +219,14 @@ pub struct GroundingSource {
    pub content: Option<serde_json::Value>,
}

/// Conversation context
#[derive(Debug, Serialize, Deserialize)]
pub struct ConversationContext {
    pub total_messages: u32,
    pub history_included: bool,
    pub context_length: u32,
}

/// Vertex AI Search tool configuration
#[derive(Debug, Serialize)]
struct VertexAISearchTool {
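A small sketch of how a caller can read the new response fields; it mirrors what the test component later in this diff does, and assumes grounding_metadata carries a sources array as that component expects.

import type { RagGroundingResponse } from '../types/ragGrounding';

function summarizeResponse(resp: RagGroundingResponse): string {
  const ctx = resp.conversation_context;
  const sourceCount = resp.grounding_metadata?.sources?.length ?? 0;
  if (!ctx) {
    return `single-turn answer in ${resp.response_time_ms}ms, ${sourceCount} sources`;
  }
  return `session ${resp.session_id}: ${ctx.total_messages} stored messages, ` +
    `${ctx.context_length} context entries sent, ${sourceCount} sources`;
}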
@@ -903,6 +924,17 @@ impl GeminiService {

    /// RAG Grounding query (modeled on query_llm_with_grounding in RAGUtils.py)
    pub async fn query_llm_with_grounding(&mut self, request: RagGroundingRequest) -> Result<RagGroundingResponse> {
        // If the request carries session-management parameters, use the multi-turn variant
        if request.session_id.is_some() || request.include_history.unwrap_or(false) {
            return self.query_llm_with_grounding_multi_turn(request, None).await;
        }

        // Otherwise fall back to the original single-turn logic
        self.query_llm_with_grounding_single_turn(request).await
    }

    /// Single-turn RAG Grounding query (original logic)
    async fn query_llm_with_grounding_single_turn(&mut self, request: RagGroundingRequest) -> Result<RagGroundingResponse> {
        let start_time = std::time::Instant::now();
        println!("🔍 开始RAG Grounding查询: {}", request.user_input);
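The routing rule above is the whole backward-compatibility story: plain requests keep taking the single-turn path. A tiny TypeScript mirror of that condition, for callers who want to predict which path a request takes (an illustrative helper, not part of the diff):

import type { RagGroundingRequest } from '../types/ragGrounding';

// Mirrors the Rust check: session_id.is_some() || include_history.unwrap_or(false)
function isMultiTurnRequest(req: RagGroundingRequest): boolean {
  return req.session_id !== undefined || (req.include_history ?? false);
}

isMultiTurnRequest({ user_input: '什么是胶囊衣橱?' });                               // false -> single-turn
isMultiTurnRequest({ user_input: '那我现有的外套怎么搭?', include_history: true }); // true  -> multi-turn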
@@ -976,6 +1008,9 @@ impl GeminiService {
                        grounding_metadata: response.grounding_metadata,
                        response_time_ms: elapsed.as_millis() as u64,
                        model_used: rag_config.model_id,
                        session_id: None,
                        message_id: None,
                        conversation_context: None,
                    });
                }
                Err(e) => {
@@ -990,6 +1025,267 @@ impl GeminiService {
        Err(anyhow!("RAG Grounding查询失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap()))
    }

    /// Multi-turn RAG Grounding query (with conversation-history support)
    pub async fn query_llm_with_grounding_multi_turn(
        &mut self,
        request: RagGroundingRequest,
        conversation_repo: Option<Arc<ConversationRepository>>,
    ) -> Result<RagGroundingResponse> {
        let start_time = std::time::Instant::now();
        println!("🔍 开始多轮RAG Grounding查询: {}", request.user_input);

        // Read the configuration
        let rag_config = request.config.unwrap_or_default();

        // 1. Resolve or create the session
        let session_id = match &request.session_id {
            Some(id) => id.clone(),
            None => {
                // No session_id supplied: generate a fresh one
                uuid::Uuid::new_v4().to_string()
            }
        };

        // 2. Load history messages (if requested and a repository is available)
        let mut contents = Vec::new();
        let mut conversation_context = ConversationContext {
            total_messages: 0,
            history_included: false,
            context_length: 0,
        };

        if let (Some(repo), true) = (&conversation_repo, request.include_history.unwrap_or(false)) {
            match self.get_conversation_history(&repo, &session_id, request.max_history_messages.unwrap_or(10)).await {
                Ok(history_messages) => {
                    conversation_context.total_messages = history_messages.len() as u32;
                    conversation_context.history_included = true;

                    // Add the system prompt (if any)
                    if let Some(system_prompt) = &request.system_prompt {
                        contents.push(ContentPart {
                            role: "system".to_string(),
                            parts: vec![Part::Text { text: system_prompt.clone() }],
                        });
                    } else if let Some(default_prompt) = &rag_config.system_prompt {
                        contents.push(ContentPart {
                            role: "system".to_string(),
                            parts: vec![Part::Text { text: default_prompt.clone() }],
                        });
                    }

                    // Add the history messages
                    for msg in &history_messages {
                        let parts = self.convert_message_content_to_parts(&msg.content)?;
                        contents.push(ContentPart {
                            role: msg.role.to_string(),
                            parts,
                        });
                    }

                    conversation_context.context_length = contents.len() as u32;
                    println!("📚 加载了 {} 条历史消息", history_messages.len());
                }
                Err(e) => {
                    println!("⚠️ 获取历史消息失败: {}", e);
                }
            }
        }

        // 3. Add the current user message
        contents.push(ContentPart {
            role: "user".to_string(),
            parts: vec![Part::Text { text: request.user_input.clone() }],
        });

        // 4. Run the RAG query
        let response = self.execute_rag_grounding_with_contents(contents, rag_config).await?;

        // 5. Persist the messages to the session history (if a repository is available)
        let message_id = if let Some(repo) = &conversation_repo {
            match self.save_conversation_messages(&repo, &session_id, &request.user_input, &response.answer).await {
                Ok(id) => Some(id),
                Err(e) => {
                    println!("⚠️ 保存会话消息失败: {}", e);
                    None
                }
            }
        } else {
            None
        };

        let elapsed = start_time.elapsed();
        println!("✅ 多轮RAG Grounding查询完成,耗时: {:?}", elapsed);

        Ok(RagGroundingResponse {
            answer: response.answer,
            grounding_metadata: response.grounding_metadata,
            response_time_ms: elapsed.as_millis() as u64,
            model_used: response.model_used,
            session_id: Some(session_id),
            message_id,
            conversation_context: Some(conversation_context),
        })
    }
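To make the assembly order explicit: the contents array sent to the model is system prompt (if any) → stored history → current user input. A rough sketch of the resulting JSON follows; the "system" and "user" role strings come from the code above, the history role string depends on MessageRole::to_string (shown here as 'assistant'), and the texts are invented.

const contents = [
  { role: 'system', parts: [{ text: '你是一个专业的服装搭配顾问。' }] },      // optional system prompt
  { role: 'user', parts: [{ text: '推荐一套秋季通勤穿搭' }] },               // loaded history
  { role: 'assistant', parts: [{ text: '可以试试米色风衣搭配九分西裤。' }] }, // loaded history
  { role: 'user', parts: [{ text: '换成更休闲的风格呢?' }] },               // current turn
];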

    /// Load conversation history messages
    async fn get_conversation_history(
        &self,
        repo: &Arc<ConversationRepository>,
        session_id: &str,
        max_messages: u32,
    ) -> Result<Vec<ConversationMessage>> {
        let query = ConversationHistoryQuery {
            session_id: session_id.to_string(),
            limit: Some(max_messages),
            offset: None,
            include_system_messages: Some(false),
        };

        match repo.get_conversation_history(query) {
            Ok(history) => Ok(history.messages),
            Err(e) => Err(anyhow!("获取会话历史失败: {}", e)),
        }
    }

    /// Save conversation messages
    async fn save_conversation_messages(
        &self,
        repo: &Arc<ConversationRepository>,
        session_id: &str,
        user_input: &str,
        assistant_response: &str,
    ) -> Result<String> {
        // Save the user message
        let user_message_request = AddMessageRequest {
            session_id: session_id.to_string(),
            role: MessageRole::User,
            content: vec![MessageContent::Text { text: user_input.to_string() }],
            metadata: None,
        };

        repo.add_message(user_message_request)?;

        // Save the assistant reply
        let assistant_message_request = AddMessageRequest {
            session_id: session_id.to_string(),
            role: MessageRole::Assistant,
            content: vec![MessageContent::Text { text: assistant_response.to_string() }],
            metadata: None,
        };

        let assistant_message = repo.add_message(assistant_message_request)?;
        Ok(assistant_message.id)
    }
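Each completed turn therefore appends two records to the session, and the id of the assistant record is what the multi-turn method reports back as message_id. A rough sketch of the stored rows after one turn; the field names follow AddMessageRequest, but the serialized shape of MessageContent, the ids, and the texts are assumptions.

const storedAfterOneTurn = [
  { session_id: 's-20240601-01', role: 'user',      content: [{ text: '推荐一套秋季通勤穿搭' }] },
  { session_id: 's-20240601-01', role: 'assistant', content: [{ text: '可以试试米色风衣搭配九分西裤。' }] }, // this record's id is returned as message_id
];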

    /// Run a RAG query with a fully assembled contents array
    async fn execute_rag_grounding_with_contents(
        &mut self,
        contents: Vec<ContentPart>,
        rag_config: RagGroundingConfig,
    ) -> Result<RagGroundingResponse> {
        // Obtain an access token
        let access_token = self.get_access_token().await?;

        // Create the client configuration
        let client_config = self.create_gemini_client(&access_token);

        // Build the data store path
        let datastore_path = format!(
            "projects/{}/locations/{}/collections/default_collection/dataStores/{}",
            rag_config.project_id,
            rag_config.location,
            rag_config.data_store_id
        );

        // Build the tool configuration (Vertex AI Search)
        let tools = vec![VertexAISearchTool {
            retrieval: VertexAIRetrieval {
                vertex_ai_search: VertexAISearchConfig {
                    datastore: datastore_path,
                },
            },
        }];

        // Build the generation configuration
        let generation_config = GenerationConfig {
            temperature: rag_config.temperature,
            top_k: 32,
            top_p: 1.0,
            max_output_tokens: rag_config.max_output_tokens,
        };

        // Prepare the request payload
        let request_data = serde_json::json!({
            "contents": contents,
            "tools": tools,
            "generationConfig": generation_config,
            "toolConfig": {
                "functionCallingConfig": {
                    "mode": "ANY",
                    "allowedFunctionNames": ["retrieval"]
                }
            }
        });

        // Send the request to the Cloudflare Gateway (beta API)
        let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, rag_config.model_id);

        // Retry loop
        let mut last_error = None;
        for attempt in 0..self.config.max_retries {
            match self.send_rag_grounding_request(&generate_url, &client_config, &request_data).await {
                Ok(response) => {
                    return Ok(RagGroundingResponse {
                        answer: response.answer,
                        grounding_metadata: response.grounding_metadata,
                        response_time_ms: 0, // set by the caller
                        model_used: rag_config.model_id,
                        session_id: None,
                        message_id: None,
                        conversation_context: None,
                    });
                }
                Err(e) => {
                    last_error = Some(e);
                    if attempt < self.config.max_retries - 1 {
                        tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await;
                    }
                }
            }
        }

        Err(anyhow!("RAG Grounding查询失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap()))
    }
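For reference, the serde_json::json! block above produces a generateContent body roughly like the following. Field casing follows the Rust struct definitions (any serde rename attributes are not visible in this hunk), and the project, location, data store, and generation values are placeholders.

const generateContentBody = {
  contents: [/* system prompt + history + current user message, as assembled above */],
  tools: [
    {
      retrieval: {
        vertex_ai_search: {
          datastore: 'projects/my-project/locations/global/collections/default_collection/dataStores/my-data-store',
        },
      },
    },
  ],
  generationConfig: { temperature: 0.2, top_k: 32, top_p: 1.0, max_output_tokens: 2048 },
  toolConfig: {
    functionCallingConfig: {
      mode: 'ANY',
      allowedFunctionNames: ['retrieval'],
    },
  },
};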

    /// Convert stored message content into Gemini API Part values
    fn convert_message_content_to_parts(&self, content: &[MessageContent]) -> Result<Vec<Part>> {
        let mut parts = Vec::new();
        for item in content {
            match item {
                MessageContent::Text { text } => {
                    parts.push(Part::Text { text: text.clone() });
                }
                MessageContent::File { file_uri, mime_type, .. } => {
                    parts.push(Part::FileData {
                        file_data: FileData {
                            mime_type: mime_type.clone(),
                            file_uri: file_uri.clone(),
                        }
                    });
                }
                MessageContent::InlineData { data, mime_type, .. } => {
                    parts.push(Part::InlineData {
                        inline_data: InlineData {
                            mime_type: mime_type.clone(),
                            data: data.clone(),
                        }
                    });
                }
            }
        }
        Ok(parts)
    }

    /// Send the RAG Grounding request
    async fn send_rag_grounding_request(
        &self,
@@ -1067,6 +1363,9 @@ impl GeminiService {
            grounding_metadata,
            response_time_ms: 0, // set by the caller
            model_used: "gemini-2.5-flash".to_string(),
            session_id: None,
            message_id: None,
            conversation_context: None,
        })
    }

@@ -4,9 +4,10 @@ use crate::app_state::AppState;

/// RAG Grounding query command
/// Data-store-backed retrieval-augmented generation, modeled on query_llm_with_grounding in RAGUtils.py
/// Now also supports multi-turn conversations
#[command]
pub async fn query_rag_grounding(
-    _state: State<'_, AppState>,
+    state: State<'_, AppState>,
    request: RagGroundingRequest,
) -> Result<RagGroundingResponse, String> {
    println!("🔍 收到RAG Grounding查询请求: {}", request.user_input);
@@ -16,14 +17,37 @@ pub async fn query_rag_grounding(
    let mut gemini_service = GeminiService::new(Some(config))
        .map_err(|e| format!("创建GeminiService失败: {}", e))?;

+    // If the request carries session-management parameters, fetch the conversation repository
+    let conversation_repo = if request.session_id.is_some() || request.include_history.unwrap_or(false) {
+        match state.inner().get_conversation_repository() {
+            Ok(repo) => Some(repo),
+            Err(e) => {
+                println!("⚠️ 获取会话仓库失败,将使用单轮对话模式: {}", e);
+                None
+            }
+        }
+    } else {
+        None
+    };
+
    // Run the RAG Grounding query
-    let response = gemini_service
-        .query_llm_with_grounding(request)
-        .await
-        .map_err(|e| {
-            eprintln!("RAG Grounding查询失败: {}", e);
-            format!("RAG Grounding查询失败: {}", e)
-        })?;
+    let response = if conversation_repo.is_some() {
+        gemini_service
+            .query_llm_with_grounding_multi_turn(request, conversation_repo)
+            .await
+            .map_err(|e| {
+                eprintln!("多轮RAG Grounding查询失败: {}", e);
+                format!("多轮RAG Grounding查询失败: {}", e)
+            })?
+    } else {
+        gemini_service
+            .query_llm_with_grounding(request)
+            .await
+            .map_err(|e| {
+                eprintln!("RAG Grounding查询失败: {}", e);
+                format!("RAG Grounding查询失败: {}", e)
+            })?
+    };

    println!("✅ RAG Grounding查询成功,响应时间: {}ms", response.response_time_ms);
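Note the graceful degradation in the command: if the conversation repository cannot be obtained, it logs a warning and falls back to the single-turn call. From the frontend, whether history actually made it into the prompt can be checked via conversation_context.history_included; a small sketch, assuming the result shape used by the test component:

import { queryRagGrounding } from '../services/ragGroundingService';

async function checkHistoryUsage(): Promise<void> {
  const result = await queryRagGrounding('继续上次的话题', { includeHistory: true });
  if (result.success && result.data?.conversation_context?.history_included === false) {
    console.warn('History was requested but not included (repository unavailable or the history lookup failed).');
  }
}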
@@ -47,6 +71,9 @@ pub async fn test_rag_grounding_connection(
        user_input: "测试连接".to_string(),
        config: None,
        session_id: Some("test-session".to_string()),
        include_history: Some(false),
        max_history_messages: None,
        system_prompt: None,
    };

    // Run the test query
@@ -0,0 +1,334 @@
import React, { useState, useCallback, useRef, useEffect } from 'react';
import { queryRagGrounding } from '../services/ragGroundingService';
import {
  RagGroundingQueryOptions,
  RagGroundingResponse,
  GroundingSource,
  ConversationContext,
} from '../types/ragGrounding';

/**
 * Multi-turn RAG chat test component
 * Exercises retrieval-augmented, multi-turn conversations
 */

interface ChatMessage {
  id: string;
  type: 'user' | 'assistant';
  content: string;
  timestamp: Date;
  status: 'sending' | 'sent' | 'error';
  metadata?: {
    responseTime?: number;
    modelUsed?: string;
    groundingSources?: GroundingSource[];
    conversationContext?: ConversationContext;
  };
}

export const MultiTurnRagChatTest: React.FC = () => {
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const [input, setInput] = useState('');
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [sessionId, setSessionId] = useState<string | null>(null);
  const [showHistory, setShowHistory] = useState(true);
  const [maxHistoryMessages, setMaxHistoryMessages] = useState(10);
  const [systemPrompt, setSystemPrompt] = useState('你是一个专业的服装搭配顾问,基于检索到的相关信息为用户提供准确、实用的搭配建议。');
  const [showGroundingSources, setShowGroundingSources] = useState(true);

  const messagesEndRef = useRef<HTMLDivElement>(null);

  // Auto-scroll to the bottom
  const scrollToBottom = useCallback(() => {
    messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
  }, []);

  useEffect(() => {
    scrollToBottom();
  }, [messages, scrollToBottom]);

  // Generate a message id
  const generateMessageId = () => {
    return `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`;
  };

  // Send a message
  const handleSendMessage = useCallback(async () => {
    if (!input.trim() || isLoading) return;

    const userMessage: ChatMessage = {
      id: generateMessageId(),
      type: 'user',
      content: input.trim(),
      timestamp: new Date(),
      status: 'sent'
    };

    const assistantMessage: ChatMessage = {
      id: generateMessageId(),
      type: 'assistant',
      content: '',
      timestamp: new Date(),
      status: 'sending'
    };

    // Append the user message and a placeholder assistant message
    setMessages(prev => [...prev, userMessage, assistantMessage]);
    setInput('');
    setIsLoading(true);
    setError(null);

    try {
      // Build the query options
      const options: RagGroundingQueryOptions = {
        sessionId: sessionId || undefined,
        includeHistory: showHistory,
        maxHistoryMessages: maxHistoryMessages,
        systemPrompt: systemPrompt.trim() || undefined,
        includeMetadata: showGroundingSources,
      };

      // Call the RAG Grounding service
      const result = await queryRagGrounding(userMessage.content, options);

      if (result.success && result.data) {
        // Update the session id
        if (!sessionId && result.data.session_id) {
          setSessionId(result.data.session_id);
        }

        // Update the assistant message
        setMessages(prev => prev.map(msg =>
          msg.id === assistantMessage.id
            ? {
                ...msg,
                content: result.data!.answer,
                status: 'sent' as const,
                metadata: {
                  responseTime: result.data!.response_time_ms,
                  modelUsed: result.data!.model_used,
                  groundingSources: result.data!.grounding_metadata?.sources,
                  conversationContext: result.data!.conversation_context,
                }
              }
            : msg
        ));
      } else {
        // Handle service errors
        setError(result.error || '查询失败');
        setMessages(prev => prev.map(msg =>
          msg.id === assistantMessage.id
            ? { ...msg, content: '抱歉,查询过程中发生了错误', status: 'error' as const }
            : msg
        ));
      }
    } catch (err) {
      const errorMessage = err instanceof Error ? err.message : '未知错误';
      setError(errorMessage);
      setMessages(prev => prev.map(msg =>
        msg.id === assistantMessage.id
          ? { ...msg, content: '抱歉,发生了系统错误', status: 'error' as const }
          : msg
      ));
    } finally {
      setIsLoading(false);
    }
  }, [input, isLoading, sessionId, showHistory, maxHistoryMessages, systemPrompt, showGroundingSources]);

  // Clear the conversation
  const handleClearChat = useCallback(() => {
    setMessages([]);
    setSessionId(null);
    setError(null);
  }, []);

  // Handle keyboard events
  const handleKeyPress = useCallback((e: React.KeyboardEvent) => {
    if (e.key === 'Enter' && !e.shiftKey) {
      e.preventDefault();
      handleSendMessage();
    }
  }, [handleSendMessage]);

  // Format the response time
  const formatResponseTime = (ms: number) => {
    if (ms < 1000) return `${ms}ms`;
    return `${(ms / 1000).toFixed(1)}s`;
  };

  return (
    <div className="flex flex-col h-full max-w-6xl mx-auto p-4">
      {/* Title and settings */}
      <div className="mb-4">
        <h2 className="text-2xl font-bold mb-2">多轮RAG对话测试</h2>
        <div className="grid grid-cols-1 md:grid-cols-2 gap-4 text-sm">
          <div className="space-y-2">
            <div className="flex items-center gap-2">
              <label className="flex items-center gap-1">
                <input
                  type="checkbox"
                  checked={showHistory}
                  onChange={(e) => setShowHistory(e.target.checked)}
                  className="rounded"
                />
                包含历史消息
              </label>
            </div>
            <div className="flex items-center gap-2">
              <label>最大历史消息数:</label>
              <input
                type="number"
                value={maxHistoryMessages}
                onChange={(e) => setMaxHistoryMessages(parseInt(e.target.value) || 10)}
                min="1"
                max="50"
                className="w-16 px-2 py-1 border rounded"
              />
            </div>
            <div className="flex items-center gap-2">
              <label className="flex items-center gap-1">
                <input
                  type="checkbox"
                  checked={showGroundingSources}
                  onChange={(e) => setShowGroundingSources(e.target.checked)}
                  className="rounded"
                />
                显示检索来源
              </label>
            </div>
          </div>
          <div className="space-y-2">
            <div className="text-gray-600">
              会话ID: {sessionId ? sessionId.substring(0, 8) + '...' : '未创建'}
            </div>
            <button
              onClick={handleClearChat}
              className="px-3 py-1 bg-gray-500 text-white rounded hover:bg-gray-600"
            >
              清空对话
            </button>
          </div>
        </div>

        {/* System prompt settings */}
        <div className="mt-4">
          <label className="block text-sm font-medium mb-1">系统提示词:</label>
          <textarea
            value={systemPrompt}
            onChange={(e) => setSystemPrompt(e.target.value)}
            placeholder="设置系统提示词..."
            className="w-full px-3 py-2 border rounded-lg resize-none"
            rows={2}
          />
        </div>
      </div>

      {/* Error banner */}
      {error && (
        <div className="mb-4 p-3 bg-red-100 border border-red-400 text-red-700 rounded">
          错误: {error}
        </div>
      )}

      {/* Message list */}
      <div className="flex-1 overflow-y-auto border rounded-lg p-4 mb-4 bg-gray-50">
        {messages.length === 0 ? (
          <div className="text-center text-gray-500 py-8">
            开始RAG对话吧!这是一个基于检索增强生成的多轮对话测试界面。
          </div>
        ) : (
          <div className="space-y-4">
            {messages.map((message) => (
              <div
                key={message.id}
                className={`flex ${message.type === 'user' ? 'justify-end' : 'justify-start'}`}
              >
                <div
                  className={`max-w-2xl px-4 py-2 rounded-lg ${
                    message.type === 'user'
                      ? 'bg-blue-500 text-white'
                      : message.status === 'error'
                        ? 'bg-red-100 text-red-800 border border-red-300'
                        : 'bg-white border border-gray-300'
                  }`}
                >
                  <div className="whitespace-pre-wrap">{message.content}</div>

                  {/* Message metadata */}
                  <div className="text-xs mt-2 opacity-70">
                    {message.timestamp.toLocaleTimeString()}
                    {message.metadata?.responseTime && (
                      <span className="ml-2">
                        ({formatResponseTime(message.metadata.responseTime)})
                      </span>
                    )}
                    {message.status === 'sending' && (
                      <span className="ml-2">发送中...</span>
                    )}
                  </div>

                  {/* Retrieval sources */}
                  {message.metadata?.groundingSources && message.metadata.groundingSources.length > 0 && (
                    <div className="mt-2 pt-2 border-t border-gray-200">
                      <div className="text-xs font-medium mb-1">检索来源:</div>
                      <div className="space-y-1">
                        {message.metadata.groundingSources.map((source, index) => (
                          <div key={index} className="text-xs bg-gray-100 p-2 rounded">
                            <div className="font-medium">{source.title}</div>
                            {source.uri && (
                              <div className="text-blue-600 truncate">{source.uri}</div>
                            )}
                          </div>
                        ))}
                      </div>
                    </div>
                  )}

                  {/* Conversation context info */}
                  {message.metadata?.conversationContext && (
                    <div className="mt-2 pt-2 border-t border-gray-200 text-xs text-gray-600">
                      上下文: {message.metadata.conversationContext.total_messages} 条消息
                      {message.metadata.conversationContext.history_included &&
                        ` (包含 ${message.metadata.conversationContext.context_length} 条历史)`
                      }
                    </div>
                  )}
                </div>
              </div>
            ))}
          </div>
        )}
        <div ref={messagesEndRef} />
      </div>

      {/* Input area */}
      <div className="flex gap-2">
        <textarea
          value={input}
          onChange={(e) => setInput(e.target.value)}
          onKeyPress={handleKeyPress}
          placeholder="输入消息... (Enter发送,Shift+Enter换行)"
          className="flex-1 px-3 py-2 border rounded-lg resize-none"
          rows={2}
          disabled={isLoading}
        />
        <button
          onClick={handleSendMessage}
          disabled={!input.trim() || isLoading}
          className="px-6 py-2 bg-blue-500 text-white rounded-lg hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed"
        >
          {isLoading ? '查询中...' : '发送'}
        </button>
      </div>

      {/* Stats footer */}
      <div className="mt-2 text-xs text-gray-500 text-center">
        消息数: {messages.length} |
        历史消息: {showHistory ? '开启' : '关闭'} |
        最大历史: {maxHistoryMessages}条 |
        检索来源: {showGroundingSources ? '显示' : '隐藏'}
      </div>
    </div>
  );
};

@@ -0,0 +1,49 @@
import React from 'react';
import { MultiTurnRagChatTest } from '../components/MultiTurnRagChatTest';

/**
 * Multi-turn RAG chat test page
 * Test UI for the retrieval-augmented multi-turn conversation feature
 */
export const MultiTurnRagChatTestPage: React.FC = () => {
  return (
    <div className="h-screen flex flex-col">
      {/* Page header */}
      <header className="bg-white border-b border-gray-200 px-6 py-4">
        <div className="flex items-center justify-between">
          <div>
            <h1 className="text-2xl font-bold text-gray-900">多轮RAG对话功能测试</h1>
            <p className="text-sm text-gray-600 mt-1">
              基于检索增强生成(RAG)的多轮对话系统,支持会话历史管理和智能检索
            </p>
          </div>
          <div className="text-sm text-gray-500">
            <div>版本: v0.2.2</div>
            <div>模型: Gemini 2.5 Flash</div>
            <div>检索: Vertex AI Search</div>
          </div>
        </div>
      </header>

      {/* Main content area */}
      <main className="flex-1 overflow-hidden">
        <MultiTurnRagChatTest />
      </main>

      {/* Page footer */}
      <footer className="bg-gray-50 border-t border-gray-200 px-6 py-3">
        <div className="flex items-center justify-between text-sm text-gray-600">
          <div>
            基于 Tauri + React + TypeScript 构建的RAG多轮对话系统
          </div>
          <div className="flex items-center gap-4">
            <span>🔍 智能检索</span>
            <span>💬 多轮对话</span>
            <span>📚 知识增强</span>
            <span>🔄 上下文保持</span>
          </div>
        </div>
      </footer>
    </div>
  );
};

@@ -60,6 +60,9 @@ export class RagGroundingService {
      user_input: userInput,
      config,
      session_id: options.sessionId,
      include_history: options.includeHistory,
      max_history_messages: options.maxHistoryMessages,
      system_prompt: options.systemPrompt,
    };

    console.log('🔍 发起RAG Grounding查询:', { userInput, options });
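The mapping above is the only change the service layer needs: the camelCase options are copied into the snake_case request fields, and every new field is optional, so existing single-turn callers are untouched. A short sketch (values illustrative):

import { queryRagGrounding } from '../services/ragGroundingService';

async function optionMappingDemo(): Promise<void> {
  // Existing single-turn usage keeps working; the new request fields simply stay undefined.
  await queryRagGrounding('今天适合穿什么?', {});

  // Opting in to multi-turn: these options become session_id / include_history /
  // max_history_messages / system_prompt on the request shown above.
  await queryRagGrounding('那明天下雨的话呢?', {
    includeHistory: true,
    maxHistoryMessages: 5,
    systemPrompt: '回答尽量简短。',
  });
}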
@@ -23,6 +23,9 @@ export interface RagGroundingRequest {
  user_input: string;
  config?: RagGroundingConfig;
  session_id?: string;
  include_history?: boolean;
  max_history_messages?: number;
  system_prompt?: string;
}

/**
@@ -33,6 +36,9 @@ export interface RagGroundingResponse {
  grounding_metadata?: GroundingMetadata;
  response_time_ms: number;
  model_used: string;
  session_id?: string;
  message_id?: string;
  conversation_context?: ConversationContext;
}

/**
@@ -52,6 +58,15 @@ export interface GroundingSource {
  content?: any; // JSON content
}

/**
 * Conversation context
 */
export interface ConversationContext {
  total_messages: number;
  history_included: boolean;
  context_length: number;
}

/**
 * RAG Grounding configuration
 */
@@ -93,6 +108,12 @@ export interface RagGroundingQueryOptions {
  customConfig?: Partial<RagGroundingConfig>;
  /** Timeout in milliseconds */
  timeout?: number;
  /** Whether to include history messages */
  includeHistory?: boolean;
  /** Maximum number of history messages */
  maxHistoryMessages?: number;
  /** System prompt */
  systemPrompt?: string;
}

/**