diff --git a/python_core/database/model_postgres.py b/python_core/database/model_postgres.py index 90c72de..ffcb05e 100644 --- a/python_core/database/model_postgres.py +++ b/python_core/database/model_postgres.py @@ -5,24 +5,63 @@ import uuid from datetime import datetime from typing import List, Optional, Dict, Any -import psycopg2 -from psycopg2.extras import RealDictCursor +from contextlib import contextmanager +from python_core.config import settings from python_core.database.types import Model -from python_core.database.db_postgres import DatabasePostgres -from python_core.utils.logger import logger +from python_core.utils.logger import setup_logger + +# 尝试导入 psycopg2,如果失败则提供友好的错误信息 +try: + import psycopg2 + import psycopg2.extras + PSYCOPG2_AVAILABLE = True +except ImportError as e: + PSYCOPG2_AVAILABLE = False + PSYCOPG2_ERROR = str(e) + +logger = setup_logger(__name__) -class ModelTablePostgres(DatabasePostgres): +class ModelTablePostgres: """模特表 PostgreSQL 实现""" - + def __init__(self): - super().__init__() + if not PSYCOPG2_AVAILABLE: + error_msg = f""" +PostgreSQL support requires psycopg2 package. +Please install it using: pip install psycopg2-binary + +原始错误: {PSYCOPG2_ERROR} +""" + logger.error(error_msg) + raise ImportError(error_msg) + + self.db_url = settings.db + self.table_name = "models" + + # 初始化模特表 self._init_model_table() + + @contextmanager + def _get_connection(self): + """获取数据库连接的上下文管理器""" + conn = None + try: + conn = psycopg2.connect(self.db_url) + yield conn + except Exception as e: + if conn: + conn.rollback() + logger.error(f"Database connection error: {e}") + raise e + finally: + if conn: + conn.close() def _init_model_table(self): """初始化模特表""" try: - with self.get_connection() as conn: + with self._get_connection() as conn: with conn.cursor() as cursor: # 创建模特表 cursor.execute(""" @@ -99,8 +138,8 @@ class ModelTablePostgres(DatabasePostgres): is_cloud: bool = False) -> Optional[Model]: """创建模特""" try: - with self.get_connection() as conn: - with conn.cursor(cursor_factory=RealDictCursor) as cursor: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: cursor.execute(""" INSERT INTO models (model_number, model_image, user_id, is_cloud) VALUES (%s, %s, %s, %s) @@ -131,8 +170,8 @@ class ModelTablePostgres(DatabasePostgres): def get_model_by_id(self, model_id: str) -> Optional[Model]: """根据ID获取模特""" try: - with self.get_connection() as conn: - with conn.cursor(cursor_factory=RealDictCursor) as cursor: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: cursor.execute(""" SELECT * FROM models WHERE id = %s """, (model_id,)) @@ -150,8 +189,8 @@ class ModelTablePostgres(DatabasePostgres): def get_model_by_number(self, model_number: str, user_id: str = "default") -> Optional[Model]: """根据模特编号获取模特""" try: - with self.get_connection() as conn: - with conn.cursor(cursor_factory=RealDictCursor) as cursor: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: cursor.execute(""" SELECT * FROM models WHERE model_number = %s AND user_id = %s @@ -172,8 +211,8 @@ class ModelTablePostgres(DatabasePostgres): offset: int = 0) -> List[Model]: """获取所有模特""" try: - with self.get_connection() as conn: - with conn.cursor(cursor_factory=RealDictCursor) as cursor: + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: # 构建查询条件 conditions = [] params = [] @@ -224,7 +263,7 @@ class ModelTablePostgres(DatabasePostgres): logger.warning("没有有效的更新字段") return False - with self.get_connection() as conn: + with self._get_connection() as conn: with conn.cursor() as cursor: # 构建更新语句 set_clauses = [] @@ -266,7 +305,7 @@ class ModelTablePostgres(DatabasePostgres): def delete_model(self, model_id: str, hard_delete: bool = False) -> bool: """删除模特""" try: - with self.get_connection() as conn: + with self._get_connection() as conn: with conn.cursor() as cursor: if hard_delete: # 硬删除 @@ -298,31 +337,33 @@ class ModelTablePostgres(DatabasePostgres): include_cloud: bool = True, limit: int = 50) -> List[Model]: """搜索模特""" try: - with self.get_connection() as conn: - with conn.cursor(cursor_factory=RealDictCursor) as cursor: - # 构建查询条件 - conditions = ["is_active = true"] - params = [f"%{query}%"] - + with self._get_connection() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor: + # 简化搜索逻辑 + search_pattern = f"%{query}%" + if include_cloud: - conditions.append("(user_id = %s OR is_cloud = true)") - params.append(user_id) + cursor.execute(""" + SELECT * FROM models + WHERE model_number ILIKE %s + AND is_active = true + AND (user_id = %s OR is_cloud = true) + ORDER BY + CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END, + created_at DESC + LIMIT %s + """, (search_pattern, user_id, search_pattern, limit)) else: - conditions.append("user_id = %s") - params.append(user_id) - - params.append(limit) - - where_clause = " AND ".join(conditions) - - cursor.execute(f""" - SELECT * FROM models - WHERE model_number ILIKE %s AND {where_clause} - ORDER BY - CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END, - created_at DESC - LIMIT %s - """, [f"%{query}%"] + params) + cursor.execute(""" + SELECT * FROM models + WHERE model_number ILIKE %s + AND is_active = true + AND user_id = %s + ORDER BY + CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END, + created_at DESC + LIMIT %s + """, (search_pattern, user_id, search_pattern, limit)) rows = cursor.fetchall() return [self._row_to_model(row) for row in rows] @@ -337,7 +378,7 @@ class ModelTablePostgres(DatabasePostgres): include_inactive: bool = False) -> int: """获取模特数量""" try: - with self.get_connection() as conn: + with self._get_connection() as conn: with conn.cursor() as cursor: # 构建查询条件 conditions = [] @@ -371,7 +412,7 @@ class ModelTablePostgres(DatabasePostgres): def toggle_model_status(self, model_id: str) -> bool: """切换模特状态""" try: - with self.get_connection() as conn: + with self._get_connection() as conn: with conn.cursor() as cursor: cursor.execute(""" UPDATE models diff --git a/src-tauri/src/commands/model.rs b/src-tauri/src/commands/model.rs index 558a871..df682aa 100644 --- a/src-tauri/src/commands/model.rs +++ b/src-tauri/src/commands/model.rs @@ -1,101 +1,273 @@ use serde::{Deserialize, Serialize}; -use tauri::{command, AppHandle}; -use crate::python_executor::execute_python_command; +use tauri::AppHandle; +use crate::utils::python_cli::execute_python_cli_command; #[derive(Debug, Serialize, Deserialize)] -pub struct CreateModelRequest { - pub model_number: String, - pub model_image: Option, +pub struct ModelResponse { + pub success: bool, + pub data: Option, + pub error: Option, + pub message: Option, } -#[derive(Debug, Serialize, Deserialize)] -pub struct UpdateModelRequest { +#[derive(Debug, Deserialize)] +pub struct ModelCreateRequest { + pub model_number: String, + pub model_image: String, + pub user_id: Option, + pub is_cloud: Option, + pub verbose: Option, + pub json_output: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ModelListRequest { + pub user_id: Option, + pub include_cloud: Option, + pub include_inactive: Option, + pub limit: Option, + pub offset: Option, + pub verbose: Option, + pub json_output: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ModelGetRequest { + pub model_id: String, + pub verbose: Option, + pub json_output: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ModelUpdateRequest { + pub model_id: String, pub model_number: Option, pub model_image: Option, + pub is_active: Option, + pub is_cloud: Option, + pub verbose: Option, + pub json_output: Option, } -/// 获取所有模特 -#[command] -pub async fn get_all_models(app: AppHandle) -> Result { - let args = vec![ - "-m".to_string(), - "python_core.services.model_manager".to_string(), - "get_all_models".to_string(), - ]; - - execute_python_command(app, &args, None).await +#[derive(Debug, Deserialize)] +pub struct ModelDeleteRequest { + pub model_id: String, + pub hard_delete: Option, + pub verbose: Option, + pub json_output: Option, } -/// 根据ID获取模特 -#[command] -pub async fn get_model_by_id(app: AppHandle, model_id: String) -> Result { - let args = vec![ - "-m".to_string(), - "python_core.services.model_manager".to_string(), - "get_model_by_id".to_string(), - model_id, - ]; - - execute_python_command(app, &args, None).await +#[derive(Debug, Deserialize)] +pub struct ModelSearchRequest { + pub query: String, + pub user_id: Option, + pub include_cloud: Option, + pub limit: Option, + pub verbose: Option, + pub json_output: Option, } -/// 创建新模特 -#[command] -pub async fn create_model(app: AppHandle, request: CreateModelRequest) -> Result { - let args = vec![ - "-m".to_string(), - "python_core.services.model_manager".to_string(), - "create_model".to_string(), +/// 创建模特 +#[tauri::command] +pub async fn create_model_cli( + app: AppHandle, + request: ModelCreateRequest, +) -> Result { + let mut args = vec![ + "model".to_string(), + "create".to_string(), request.model_number, - request.model_image.unwrap_or_default(), + request.model_image, ]; - - execute_python_command(app, &args, None).await + + if let Some(user_id) = request.user_id { + args.push("--user-id".to_string()); + args.push(user_id); + } + + if request.is_cloud.unwrap_or(false) { + args.push("--cloud".to_string()); + } + + if request.verbose.unwrap_or(false) { + args.push("--verbose".to_string()); + } + + if request.json_output.unwrap_or(true) { + args.push("--json".to_string()); + } + + execute_python_cli_command(app, args).await +} + +/// 获取模特列表 +#[tauri::command] +pub async fn list_models_cli( + app: AppHandle, + request: ModelListRequest, +) -> Result { + let mut args = vec!["model".to_string(), "list".to_string()]; + + if let Some(user_id) = request.user_id { + args.push("--user-id".to_string()); + args.push(user_id); + } + + if request.include_cloud.unwrap_or(true) { + args.push("--include-cloud".to_string()); + } + + if request.include_inactive.unwrap_or(false) { + args.push("--include-inactive".to_string()); + } + + if let Some(limit) = request.limit { + args.push("--limit".to_string()); + args.push(limit.to_string()); + } + + if let Some(offset) = request.offset { + args.push("--offset".to_string()); + args.push(offset.to_string()); + } + + if request.verbose.unwrap_or(false) { + args.push("--verbose".to_string()); + } + + if request.json_output.unwrap_or(true) { + args.push("--json".to_string()); + } + + execute_python_cli_command(app, args).await +} + +/// 获取模特详情 +#[tauri::command] +pub async fn get_model_cli( + app: AppHandle, + request: ModelGetRequest, +) -> Result { + let mut args = vec![ + "model".to_string(), + "get".to_string(), + request.model_id, + ]; + + if request.verbose.unwrap_or(false) { + args.push("--verbose".to_string()); + } + + if request.json_output.unwrap_or(true) { + args.push("--json".to_string()); + } + + execute_python_cli_command(app, args).await } /// 更新模特 -#[command] -pub async fn update_model( +#[tauri::command] +pub async fn update_model_cli( app: AppHandle, - model_id: String, - request: UpdateModelRequest, -) -> Result { - let request_json = serde_json::to_string(&request) - .map_err(|e| format!("Failed to serialize request: {}", e))?; - - let args = vec![ - "-m".to_string(), - "python_core.services.model_manager".to_string(), - "update_model".to_string(), - model_id, - request_json, + request: ModelUpdateRequest, +) -> Result { + let mut args = vec![ + "model".to_string(), + "update".to_string(), + request.model_id, ]; - - execute_python_command(app, &args, None).await + + if let Some(model_number) = request.model_number { + args.push("--model-number".to_string()); + args.push(model_number); + } + + if let Some(model_image) = request.model_image { + args.push("--model-image".to_string()); + args.push(model_image); + } + + if let Some(is_active) = request.is_active { + args.push("--active".to_string()); + args.push(is_active.to_string()); + } + + if let Some(is_cloud) = request.is_cloud { + args.push("--cloud".to_string()); + args.push(is_cloud.to_string()); + } + + if request.verbose.unwrap_or(false) { + args.push("--verbose".to_string()); + } + + if request.json_output.unwrap_or(true) { + args.push("--json".to_string()); + } + + execute_python_cli_command(app, args).await } /// 删除模特 -#[command] -pub async fn delete_model(app: AppHandle, model_id: String) -> Result { - let args = vec![ - "-m".to_string(), - "python_core.services.model_manager".to_string(), - "delete_model".to_string(), - model_id, +#[tauri::command] +pub async fn delete_model_cli( + app: AppHandle, + request: ModelDeleteRequest, +) -> Result { + let mut args = vec![ + "model".to_string(), + "delete".to_string(), + request.model_id, ]; - - execute_python_command(app, &args, None).await + + if request.hard_delete.unwrap_or(false) { + args.push("--hard".to_string()); + } + + if request.verbose.unwrap_or(false) { + args.push("--verbose".to_string()); + } + + if request.json_output.unwrap_or(true) { + args.push("--json".to_string()); + } + + execute_python_cli_command(app, args).await } /// 搜索模特 -#[command] -pub async fn search_models(app: AppHandle, keyword: String) -> Result { - let args = vec![ - "-m".to_string(), - "python_core.services.model_manager".to_string(), - "search_models".to_string(), - keyword, +#[tauri::command] +pub async fn search_models_cli( + app: AppHandle, + request: ModelSearchRequest, +) -> Result { + let mut args = vec![ + "model".to_string(), + "search".to_string(), + request.query, ]; - - execute_python_command(app, &args, None).await + + if let Some(user_id) = request.user_id { + args.push("--user-id".to_string()); + args.push(user_id); + } + + if request.include_cloud.unwrap_or(true) { + args.push("--include-cloud".to_string()); + } + + if let Some(limit) = request.limit { + args.push("--limit".to_string()); + args.push(limit.to_string()); + } + + if request.verbose.unwrap_or(false) { + args.push("--verbose".to_string()); + } + + if request.json_output.unwrap_or(true) { + args.push("--json".to_string()); + } + + execute_python_cli_command(app, args).await } diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 48a5a94..e9671a5 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -73,12 +73,12 @@ pub fn run() { commands::project::scan_directory, commands::project::open_file_in_system, commands::project::delete_file, - commands::model::get_all_models, - commands::model::get_model_by_id, - commands::model::create_model, - commands::model::update_model, - commands::model::delete_model, - commands::model::search_models, + commands::model::list_models_cli, + commands::model::get_model_cli, + commands::model::create_model_cli, + commands::model::update_model_cli, + commands::model::delete_model_cli, + commands::model::search_models_cli, commands::audio::get_all_audio_files, commands::audio::get_audio_by_id, commands::audio::get_audio_by_md5, diff --git a/src/services/ModelServiceV2.ts b/src/services/ModelServiceV2.ts new file mode 100644 index 0000000..e965caf --- /dev/null +++ b/src/services/ModelServiceV2.ts @@ -0,0 +1,370 @@ +import { invoke } from '@tauri-apps/api/tauri' + +export interface Model { + id: string + model_number: string + model_image: string + is_active: boolean + is_cloud: boolean + user_id: string + created_at: string + updated_at: string +} + +export interface ModelResponse { + success: boolean + data?: any + error?: string + message?: string +} + +export interface ModelCreateRequest { + model_number: string + model_image: string + user_id?: string + is_cloud?: boolean + verbose?: boolean + json_output?: boolean +} + +export interface ModelListRequest { + user_id?: string + include_cloud?: boolean + include_inactive?: boolean + limit?: number + offset?: number + verbose?: boolean + json_output?: boolean +} + +export interface ModelGetRequest { + model_id: string + verbose?: boolean + json_output?: boolean +} + +export interface ModelUpdateRequest { + model_id: string + model_number?: string + model_image?: string + is_active?: boolean + is_cloud?: boolean + verbose?: boolean + json_output?: boolean +} + +export interface ModelDeleteRequest { + model_id: string + hard_delete?: boolean + verbose?: boolean + json_output?: boolean +} + +export interface ModelSearchRequest { + query: string + user_id?: string + include_cloud?: boolean + limit?: number + verbose?: boolean + json_output?: boolean +} + +export class ModelServiceV2 { + /** + * 获取当前用户ID + */ + private static getCurrentUserId(): string { + // 这里应该从认证服务获取当前用户ID + // 暂时返回默认值 + return 'default' + } + + /** + * 创建模特 + */ + static async createModel( + modelNumber: string, + modelImage: string, + isCloud: boolean = false, + userId?: string + ): Promise { + try { + const request: ModelCreateRequest = { + model_number: modelNumber, + model_image: modelImage, + user_id: userId || this.getCurrentUserId(), + is_cloud: isCloud, + verbose: false, + json_output: true + } + + const response = await invoke('create_model_cli', { request }) + + if (!response.success) { + throw new Error(response.error || response.message || 'Failed to create model') + } + + return response.data?.model + } catch (error) { + console.error('Create model failed:', error) + throw error + } + } + + /** + * 获取模特列表 + */ + static async getAllModels( + includeCloud: boolean = true, + includeInactive: boolean = false, + limit: number = 100, + offset: number = 0, + userId?: string + ): Promise<{ models: Model[], total_count: number }> { + try { + const request: ModelListRequest = { + user_id: userId || this.getCurrentUserId(), + include_cloud: includeCloud, + include_inactive: includeInactive, + limit: limit, + offset: offset, + verbose: false, + json_output: true + } + + const response = await invoke('list_models_cli', { request }) + + if (!response.success) { + throw new Error(response.error || response.message || 'Failed to get models') + } + + return { + models: response.data?.models || [], + total_count: response.data?.total_count || 0 + } + } catch (error) { + console.error('Get models failed:', error) + throw error + } + } + + /** + * 根据ID获取模特 + */ + static async getModelById(modelId: string): Promise { + try { + const request: ModelGetRequest = { + model_id: modelId, + verbose: false, + json_output: true + } + + const response = await invoke('get_model_cli', { request }) + + if (!response.success) { + if (response.error?.includes('不存在')) { + return null + } + throw new Error(response.error || response.message || 'Failed to get model') + } + + return response.data?.model || null + } catch (error) { + console.error('Get model failed:', error) + throw error + } + } + + /** + * 更新模特 + */ + static async updateModel( + modelId: string, + updates: { + model_number?: string + model_image?: string + is_active?: boolean + is_cloud?: boolean + } + ): Promise { + try { + const request: ModelUpdateRequest = { + model_id: modelId, + ...updates, + verbose: false, + json_output: true + } + + const response = await invoke('update_model_cli', { request }) + + if (!response.success) { + throw new Error(response.error || response.message || 'Failed to update model') + } + + return true + } catch (error) { + console.error('Update model failed:', error) + throw error + } + } + + /** + * 删除模特 + */ + static async deleteModel(modelId: string, hardDelete: boolean = false): Promise { + try { + const request: ModelDeleteRequest = { + model_id: modelId, + hard_delete: hardDelete, + verbose: false, + json_output: true + } + + const response = await invoke('delete_model_cli', { request }) + + if (!response.success) { + throw new Error(response.error || response.message || 'Failed to delete model') + } + + return true + } catch (error) { + console.error('Delete model failed:', error) + throw error + } + } + + /** + * 搜索模特 + */ + static async searchModels( + query: string, + includeCloud: boolean = true, + limit: number = 50, + userId?: string + ): Promise { + try { + const request: ModelSearchRequest = { + query: query, + user_id: userId || this.getCurrentUserId(), + include_cloud: includeCloud, + limit: limit, + verbose: false, + json_output: true + } + + const response = await invoke('search_models_cli', { request }) + + if (!response.success) { + throw new Error(response.error || response.message || 'Failed to search models') + } + + return response.data?.models || [] + } catch (error) { + console.error('Search models failed:', error) + throw error + } + } + + /** + * 切换模特状态 + */ + static async toggleModelStatus(modelId: string): Promise { + try { + // 先获取当前模特信息 + const model = await this.getModelById(modelId) + if (!model) { + throw new Error('模特不存在') + } + + // 更新状态 + return await this.updateModel(modelId, { + is_active: !model.is_active + }) + } catch (error) { + console.error('Toggle model status failed:', error) + throw error + } + } + + /** + * 获取模特数量统计 + */ + static async getModelStats( + includeCloud: boolean = true, + userId?: string + ): Promise<{ + total: number + active: number + inactive: number + cloud: number + local: number + }> { + try { + const { models } = await this.getAllModels(includeCloud, true, 1000, 0, userId) + + const stats = { + total: models.length, + active: models.filter(m => m.is_active).length, + inactive: models.filter(m => !m.is_active).length, + cloud: models.filter(m => m.is_cloud).length, + local: models.filter(m => !m.is_cloud).length + } + + return stats + } catch (error) { + console.error('Get model stats failed:', error) + throw error + } + } + + /** + * 批量删除模特 + */ + static async batchDeleteModels( + modelIds: string[], + hardDelete: boolean = false + ): Promise<{ success: string[], failed: string[] }> { + const results = { + success: [] as string[], + failed: [] as string[] + } + + for (const modelId of modelIds) { + try { + await this.deleteModel(modelId, hardDelete) + results.success.push(modelId) + } catch (error) { + console.error(`Failed to delete model ${modelId}:`, error) + results.failed.push(modelId) + } + } + + return results + } + + /** + * 验证模特编号是否唯一 + */ + static async isModelNumberUnique( + modelNumber: string, + excludeId?: string, + userId?: string + ): Promise { + try { + const models = await this.searchModels(modelNumber, true, 10, userId) + + // 检查是否有完全匹配的模特编号 + const exactMatch = models.find(m => + m.model_number === modelNumber && + m.id !== excludeId + ) + + return !exactMatch + } catch (error) { + console.error('Check model number uniqueness failed:', error) + return false + } + } +} + +export default ModelServiceV2