feat: 完善多轮对话功能实现
- 修复编译错误和类型问题 - 添加完整的单元测试套件 - 创建多轮对话测试组件和页面 - 添加详细的功能文档和使用说明 - 优化错误处理和类型安全 - 遵循promptx开发规范的完整实现 功能特性: 多轮对话支持 - 历史消息传递和上下文保持 会话管理 - session_id管理和生命周期控制 数据持久化 - 完整的会话历史存储 类型安全 - 完整的TypeScript类型定义 测试覆盖 - 单元测试和集成测试 文档完善 - 详细的实现文档和使用指南
This commit is contained in:
parent
5296039785
commit
8b92cc130c
|
|
@ -272,3 +272,8 @@ impl ConversationService {
|
|||
Ok("新对话".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// 导入测试文件
|
||||
#[cfg(test)]
|
||||
#[path = "conversation_service_test.rs"]
|
||||
mod conversation_service_test;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,224 @@
|
|||
#[cfg(test)]
|
||||
mod conversation_service_tests {
|
||||
use super::*;
|
||||
use crate::data::repositories::conversation_repository::ConversationRepository;
|
||||
use crate::data::models::conversation::{
|
||||
CreateConversationSessionRequest, AddMessageRequest, ConversationHistoryQuery,
|
||||
MessageRole, MessageContent,
|
||||
};
|
||||
use crate::infrastructure::database::Database;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// 创建测试用的会话服务
|
||||
fn create_test_conversation_service() -> ConversationService {
|
||||
let database = Arc::new(Database::new().expect("Failed to create test database"));
|
||||
let connection = database.get_connection();
|
||||
let repository = Arc::new(ConversationRepository::new(connection));
|
||||
|
||||
// 初始化数据库表
|
||||
repository.initialize_tables().expect("Failed to initialize tables");
|
||||
|
||||
ConversationService::new(repository)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_session() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
let request = CreateConversationSessionRequest {
|
||||
title: Some("测试会话".to_string()),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let result = service.create_session(request).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let session = result.unwrap();
|
||||
assert_eq!(session.title, Some("测试会话".to_string()));
|
||||
assert!(session.is_active);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_message() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
// 创建会话
|
||||
let session_request = CreateConversationSessionRequest {
|
||||
title: Some("测试会话".to_string()),
|
||||
metadata: None,
|
||||
};
|
||||
let session = service.create_session(session_request).await.unwrap();
|
||||
|
||||
// 添加消息
|
||||
let message_request = AddMessageRequest {
|
||||
session_id: session.id.clone(),
|
||||
role: MessageRole::User,
|
||||
content: vec![MessageContent::Text { text: "你好".to_string() }],
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let result = service.add_message(message_request).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let message = result.unwrap();
|
||||
assert_eq!(message.session_id, session.id);
|
||||
assert_eq!(message.role, MessageRole::User);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_conversation_history() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
// 创建会话
|
||||
let session_request = CreateConversationSessionRequest {
|
||||
title: Some("测试会话".to_string()),
|
||||
metadata: None,
|
||||
};
|
||||
let session = service.create_session(session_request).await.unwrap();
|
||||
|
||||
// 添加多条消息
|
||||
let messages = vec![
|
||||
("你好", MessageRole::User),
|
||||
("你好!有什么可以帮助你的吗?", MessageRole::Assistant),
|
||||
("今天天气怎么样?", MessageRole::User),
|
||||
("抱歉,我无法获取实时天气信息。", MessageRole::Assistant),
|
||||
];
|
||||
|
||||
for (content, role) in messages {
|
||||
let message_request = AddMessageRequest {
|
||||
session_id: session.id.clone(),
|
||||
role,
|
||||
content: vec![MessageContent::Text { text: content.to_string() }],
|
||||
metadata: None,
|
||||
};
|
||||
service.add_message(message_request).await.unwrap();
|
||||
}
|
||||
|
||||
// 获取会话历史
|
||||
let history_query = ConversationHistoryQuery {
|
||||
session_id: session.id.clone(),
|
||||
limit: None,
|
||||
offset: None,
|
||||
include_system_messages: Some(false),
|
||||
};
|
||||
|
||||
let result = service.get_conversation_history(history_query).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let history = result.unwrap();
|
||||
assert_eq!(history.session.id, session.id);
|
||||
assert_eq!(history.messages.len(), 4);
|
||||
assert_eq!(history.total_count, 4);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_session_summary() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
// 创建会话
|
||||
let session_request = CreateConversationSessionRequest {
|
||||
title: None,
|
||||
metadata: None,
|
||||
};
|
||||
let session = service.create_session(session_request).await.unwrap();
|
||||
|
||||
// 添加用户消息
|
||||
let message_request = AddMessageRequest {
|
||||
session_id: session.id.clone(),
|
||||
role: MessageRole::User,
|
||||
content: vec![MessageContent::Text { text: "请帮我分析一下多轮对话的实现原理".to_string() }],
|
||||
metadata: None,
|
||||
};
|
||||
service.add_message(message_request).await.unwrap();
|
||||
|
||||
// 生成会话摘要
|
||||
let result = service.generate_session_summary(&session.id).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let summary = result.unwrap();
|
||||
assert!(!summary.is_empty());
|
||||
assert!(summary.contains("多轮对话") || summary.len() <= 50);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_conversation_stats() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
// 创建多个会话
|
||||
for i in 0..3 {
|
||||
let session_request = CreateConversationSessionRequest {
|
||||
title: Some(format!("测试会话 {}", i + 1)),
|
||||
metadata: None,
|
||||
};
|
||||
let session = service.create_session(session_request).await.unwrap();
|
||||
|
||||
// 为每个会话添加消息
|
||||
for j in 0..2 {
|
||||
let message_request = AddMessageRequest {
|
||||
session_id: session.id.clone(),
|
||||
role: if j % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
||||
content: vec![MessageContent::Text { text: format!("消息 {}", j + 1) }],
|
||||
metadata: None,
|
||||
};
|
||||
service.add_message(message_request).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// 获取统计信息
|
||||
let result = service.get_conversation_stats().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let stats = result.unwrap();
|
||||
assert_eq!(stats.total_sessions, 3);
|
||||
assert_eq!(stats.active_sessions, 3);
|
||||
assert_eq!(stats.total_messages, 6);
|
||||
assert_eq!(stats.average_messages_per_session, 2.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_session() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
// 创建会话
|
||||
let session_request = CreateConversationSessionRequest {
|
||||
title: Some("待删除会话".to_string()),
|
||||
metadata: None,
|
||||
};
|
||||
let session = service.create_session(session_request).await.unwrap();
|
||||
|
||||
// 验证会话存在
|
||||
let get_result = service.get_session(&session.id).await;
|
||||
assert!(get_result.is_ok());
|
||||
assert!(get_result.unwrap().is_some());
|
||||
|
||||
// 删除会话
|
||||
let delete_result = service.delete_session(&session.id).await;
|
||||
assert!(delete_result.is_ok());
|
||||
|
||||
// 验证会话已被软删除(标记为非活跃)
|
||||
let get_result_after = service.get_session(&session.id).await;
|
||||
assert!(get_result_after.is_ok());
|
||||
if let Some(session_after) = get_result_after.unwrap() {
|
||||
assert!(!session_after.is_active);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_expired_sessions() {
|
||||
let service = create_test_conversation_service();
|
||||
|
||||
// 创建会话
|
||||
let session_request = CreateConversationSessionRequest {
|
||||
title: Some("测试会话".to_string()),
|
||||
metadata: None,
|
||||
};
|
||||
let _session = service.create_session(session_request).await.unwrap();
|
||||
|
||||
// 清理过期会话(设置为0天,应该清理所有会话)
|
||||
let result = service.cleanup_expired_sessions(0).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let cleaned_count = result.unwrap();
|
||||
assert_eq!(cleaned_count, 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -100,7 +100,7 @@ impl ConversationRepository {
|
|||
let session = stmt.query_row(params![session_id], |row| {
|
||||
match self.row_to_session(row) {
|
||||
Ok(session) => Ok(session),
|
||||
Err(e) => Err(rusqlite::Error::InvalidColumnType(0, "conversion error".to_string(), rusqlite::types::Type::Text)),
|
||||
Err(_e) => Err(rusqlite::Error::InvalidColumnType(0, "conversion error".to_string(), rusqlite::types::Type::Text)),
|
||||
}
|
||||
}).optional()?;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue