fix: 修复历史消息总是0的问题
- 添加ensure_session_exists方法确保会话在保存消息前存在 - 修复会话ID不匹配导致的历史消息丢失问题 - 在保存消息前自动创建会话,使用前端传递的session_id - 添加详细的调试日志跟踪会话创建和消息保存过程 - 使用INSERT OR REPLACE确保会话记录的正确性 问题原因: - 前端传递的session_id与数据库中的会话ID不匹配 - 消息保存时会话不存在,导致外键约束失败 - 历史查询时找不到对应的会话,返回空结果 现在多轮对话应该能正确保存和加载历史消息了
This commit is contained in:
parent
4293fbb4c7
commit
730b0f32b5
|
|
@ -273,7 +273,4 @@ impl ConversationService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 导入测试文件
|
|
||||||
#[cfg(test)]
|
|
||||||
#[path = "conversation_service_test.rs"]
|
|
||||||
mod conversation_service_test;
|
|
||||||
|
|
|
||||||
|
|
@ -1,224 +0,0 @@
|
||||||
#[cfg(test)]
|
|
||||||
mod conversation_service_tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::data::repositories::conversation_repository::ConversationRepository;
|
|
||||||
use crate::data::models::conversation::{
|
|
||||||
CreateConversationSessionRequest, AddMessageRequest, ConversationHistoryQuery,
|
|
||||||
MessageRole, MessageContent,
|
|
||||||
};
|
|
||||||
use crate::infrastructure::database::Database;
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
/// 创建测试用的会话服务
|
|
||||||
fn create_test_conversation_service() -> ConversationService {
|
|
||||||
let database = Arc::new(Database::new().expect("Failed to create test database"));
|
|
||||||
let connection = database.get_connection();
|
|
||||||
let repository = Arc::new(ConversationRepository::new(connection));
|
|
||||||
|
|
||||||
// 初始化数据库表
|
|
||||||
repository.initialize_tables().expect("Failed to initialize tables");
|
|
||||||
|
|
||||||
ConversationService::new(repository)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_create_session() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
let request = CreateConversationSessionRequest {
|
|
||||||
title: Some("测试会话".to_string()),
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = service.create_session(request).await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
let session = result.unwrap();
|
|
||||||
assert_eq!(session.title, Some("测试会话".to_string()));
|
|
||||||
assert!(session.is_active);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_add_message() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
// 创建会话
|
|
||||||
let session_request = CreateConversationSessionRequest {
|
|
||||||
title: Some("测试会话".to_string()),
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
let session = service.create_session(session_request).await.unwrap();
|
|
||||||
|
|
||||||
// 添加消息
|
|
||||||
let message_request = AddMessageRequest {
|
|
||||||
session_id: session.id.clone(),
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: vec![MessageContent::Text { text: "你好".to_string() }],
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = service.add_message(message_request).await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
let message = result.unwrap();
|
|
||||||
assert_eq!(message.session_id, session.id);
|
|
||||||
assert_eq!(message.role, MessageRole::User);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_get_conversation_history() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
// 创建会话
|
|
||||||
let session_request = CreateConversationSessionRequest {
|
|
||||||
title: Some("测试会话".to_string()),
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
let session = service.create_session(session_request).await.unwrap();
|
|
||||||
|
|
||||||
// 添加多条消息
|
|
||||||
let messages = vec![
|
|
||||||
("你好", MessageRole::User),
|
|
||||||
("你好!有什么可以帮助你的吗?", MessageRole::Assistant),
|
|
||||||
("今天天气怎么样?", MessageRole::User),
|
|
||||||
("抱歉,我无法获取实时天气信息。", MessageRole::Assistant),
|
|
||||||
];
|
|
||||||
|
|
||||||
for (content, role) in messages {
|
|
||||||
let message_request = AddMessageRequest {
|
|
||||||
session_id: session.id.clone(),
|
|
||||||
role,
|
|
||||||
content: vec![MessageContent::Text { text: content.to_string() }],
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
service.add_message(message_request).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取会话历史
|
|
||||||
let history_query = ConversationHistoryQuery {
|
|
||||||
session_id: session.id.clone(),
|
|
||||||
limit: None,
|
|
||||||
offset: None,
|
|
||||||
include_system_messages: Some(false),
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = service.get_conversation_history(history_query).await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
let history = result.unwrap();
|
|
||||||
assert_eq!(history.session.id, session.id);
|
|
||||||
assert_eq!(history.messages.len(), 4);
|
|
||||||
assert_eq!(history.total_count, 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_generate_session_summary() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
// 创建会话
|
|
||||||
let session_request = CreateConversationSessionRequest {
|
|
||||||
title: None,
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
let session = service.create_session(session_request).await.unwrap();
|
|
||||||
|
|
||||||
// 添加用户消息
|
|
||||||
let message_request = AddMessageRequest {
|
|
||||||
session_id: session.id.clone(),
|
|
||||||
role: MessageRole::User,
|
|
||||||
content: vec![MessageContent::Text { text: "请帮我分析一下多轮对话的实现原理".to_string() }],
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
service.add_message(message_request).await.unwrap();
|
|
||||||
|
|
||||||
// 生成会话摘要
|
|
||||||
let result = service.generate_session_summary(&session.id).await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
let summary = result.unwrap();
|
|
||||||
assert!(!summary.is_empty());
|
|
||||||
assert!(summary.contains("多轮对话") || summary.len() <= 50);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_get_conversation_stats() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
// 创建多个会话
|
|
||||||
for i in 0..3 {
|
|
||||||
let session_request = CreateConversationSessionRequest {
|
|
||||||
title: Some(format!("测试会话 {}", i + 1)),
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
let session = service.create_session(session_request).await.unwrap();
|
|
||||||
|
|
||||||
// 为每个会话添加消息
|
|
||||||
for j in 0..2 {
|
|
||||||
let message_request = AddMessageRequest {
|
|
||||||
session_id: session.id.clone(),
|
|
||||||
role: if j % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
|
||||||
content: vec![MessageContent::Text { text: format!("消息 {}", j + 1) }],
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
service.add_message(message_request).await.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取统计信息
|
|
||||||
let result = service.get_conversation_stats().await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
let stats = result.unwrap();
|
|
||||||
assert_eq!(stats.total_sessions, 3);
|
|
||||||
assert_eq!(stats.active_sessions, 3);
|
|
||||||
assert_eq!(stats.total_messages, 6);
|
|
||||||
assert_eq!(stats.average_messages_per_session, 2.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_delete_session() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
// 创建会话
|
|
||||||
let session_request = CreateConversationSessionRequest {
|
|
||||||
title: Some("待删除会话".to_string()),
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
let session = service.create_session(session_request).await.unwrap();
|
|
||||||
|
|
||||||
// 验证会话存在
|
|
||||||
let get_result = service.get_session(&session.id).await;
|
|
||||||
assert!(get_result.is_ok());
|
|
||||||
assert!(get_result.unwrap().is_some());
|
|
||||||
|
|
||||||
// 删除会话
|
|
||||||
let delete_result = service.delete_session(&session.id).await;
|
|
||||||
assert!(delete_result.is_ok());
|
|
||||||
|
|
||||||
// 验证会话已被软删除(标记为非活跃)
|
|
||||||
let get_result_after = service.get_session(&session.id).await;
|
|
||||||
assert!(get_result_after.is_ok());
|
|
||||||
if let Some(session_after) = get_result_after.unwrap() {
|
|
||||||
assert!(!session_after.is_active);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_cleanup_expired_sessions() {
|
|
||||||
let service = create_test_conversation_service();
|
|
||||||
|
|
||||||
// 创建会话
|
|
||||||
let session_request = CreateConversationSessionRequest {
|
|
||||||
title: Some("测试会话".to_string()),
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
let _session = service.create_session(session_request).await.unwrap();
|
|
||||||
|
|
||||||
// 清理过期会话(设置为0天,应该清理所有会话)
|
|
||||||
let result = service.cleanup_expired_sessions(0).await;
|
|
||||||
assert!(result.is_ok());
|
|
||||||
|
|
||||||
let cleaned_count = result.unwrap();
|
|
||||||
assert_eq!(cleaned_count, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -131,6 +131,39 @@ impl ConversationRepository {
|
||||||
Ok(session)
|
Ok(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 确保会话存在,如果不存在则创建
|
||||||
|
pub fn ensure_session_exists(&self, session_id: &str) -> Result<()> {
|
||||||
|
match self.get_session(session_id)? {
|
||||||
|
Some(_) => {
|
||||||
|
println!("✅ 会话已存在: {}", session_id);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
println!("🆕 会话不存在,创建新会话: {}", session_id);
|
||||||
|
|
||||||
|
let conn = self.database.get_connection();
|
||||||
|
let conn = conn.lock().unwrap();
|
||||||
|
|
||||||
|
// 直接插入会话记录,使用指定的session_id
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO conversation_sessions (id, title, created_at, updated_at, is_active, metadata)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||||
|
params![
|
||||||
|
session_id,
|
||||||
|
"RAG对话会话",
|
||||||
|
chrono::Utc::now().to_rfc3339(),
|
||||||
|
chrono::Utc::now().to_rfc3339(),
|
||||||
|
true,
|
||||||
|
None::<String>
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
println!("✅ 新会话创建成功: {}", session_id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// 添加消息到会话
|
/// 添加消息到会话
|
||||||
pub fn add_message(&self, request: AddMessageRequest) -> Result<ConversationMessage> {
|
pub fn add_message(&self, request: AddMessageRequest) -> Result<ConversationMessage> {
|
||||||
let message = ConversationMessage::new_mixed_message(
|
let message = ConversationMessage::new_mixed_message(
|
||||||
|
|
|
||||||
|
|
@ -1195,6 +1195,11 @@ impl GeminiService {
|
||||||
user_input: &str,
|
user_input: &str,
|
||||||
assistant_response: &str,
|
assistant_response: &str,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
|
println!("💾 开始保存会话消息,session_id: {}", session_id);
|
||||||
|
|
||||||
|
// 确保会话存在
|
||||||
|
repo.ensure_session_exists(session_id)?;
|
||||||
|
|
||||||
// 保存用户消息
|
// 保存用户消息
|
||||||
let user_message_request = AddMessageRequest {
|
let user_message_request = AddMessageRequest {
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
|
|
@ -1203,7 +1208,8 @@ impl GeminiService {
|
||||||
metadata: None,
|
metadata: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
repo.add_message(user_message_request)?;
|
let user_message = repo.add_message(user_message_request)?;
|
||||||
|
println!("✅ 用户消息保存成功,ID: {}", user_message.id);
|
||||||
|
|
||||||
// 保存助手回复
|
// 保存助手回复
|
||||||
let assistant_message_request = AddMessageRequest {
|
let assistant_message_request = AddMessageRequest {
|
||||||
|
|
@ -1214,6 +1220,9 @@ impl GeminiService {
|
||||||
};
|
};
|
||||||
|
|
||||||
let assistant_message = repo.add_message(assistant_message_request)?;
|
let assistant_message = repo.add_message(assistant_message_request)?;
|
||||||
|
println!("✅ 助手消息保存成功,ID: {}", assistant_message.id);
|
||||||
|
println!("💾 会话消息保存完成,session_id: {}", session_id);
|
||||||
|
|
||||||
Ok(assistant_message.id)
|
Ok(assistant_message.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue