fix: model management

root 2025-07-13 10:56:28 +08:00
parent 59a99f21c9
commit e69cb49e67
4 changed files with 704 additions and 121 deletions

View File

@@ -5,24 +5,63 @@
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Any
import psycopg2
from psycopg2.extras import RealDictCursor
from contextlib import contextmanager
from python_core.config import settings
from python_core.database.types import Model
from python_core.database.db_postgres import DatabasePostgres
from python_core.utils.logger import logger
from python_core.utils.logger import setup_logger
# Try to import psycopg2; provide a friendly error message if it fails
try:
import psycopg2
import psycopg2.extras
PSYCOPG2_AVAILABLE = True
except ImportError as e:
PSYCOPG2_AVAILABLE = False
PSYCOPG2_ERROR = str(e)
logger = setup_logger(__name__)
class ModelTablePostgres(DatabasePostgres):
class ModelTablePostgres:
"""模特表 PostgreSQL 实现"""
def __init__(self):
super().__init__()
if not PSYCOPG2_AVAILABLE:
error_msg = f"""
PostgreSQL support requires the psycopg2 package.
Please install it using: pip install psycopg2-binary
Original error: {PSYCOPG2_ERROR}
"""
logger.error(error_msg)
raise ImportError(error_msg)
self.db_url = settings.db
self.table_name = "models"
# Initialize the model table
self._init_model_table()
@contextmanager
def _get_connection(self):
"""获取数据库连接的上下文管理器"""
conn = None
try:
conn = psycopg2.connect(self.db_url)
yield conn
except Exception as e:
if conn:
conn.rollback()
logger.error(f"Database connection error: {e}")
raise e
finally:
if conn:
conn.close()
def _init_model_table(self):
"""初始化模特表"""
try:
with self.get_connection() as conn:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Create the models table
cursor.execute("""
@@ -99,8 +138,8 @@ class ModelTablePostgres(DatabasePostgres):
is_cloud: bool = False) -> Optional[Model]:
"""创建模特"""
try:
with self.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
cursor.execute("""
INSERT INTO models (model_number, model_image, user_id, is_cloud)
VALUES (%s, %s, %s, %s)
@@ -131,8 +170,8 @@ class ModelTablePostgres(DatabasePostgres):
def get_model_by_id(self, model_id: str) -> Optional[Model]:
"""根据ID获取模特"""
try:
with self.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
cursor.execute("""
SELECT * FROM models WHERE id = %s
""", (model_id,))
@@ -150,8 +189,8 @@ class ModelTablePostgres(DatabasePostgres):
def get_model_by_number(self, model_number: str, user_id: str = "default") -> Optional[Model]:
"""根据模特编号获取模特"""
try:
with self.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
cursor.execute("""
SELECT * FROM models
WHERE model_number = %s AND user_id = %s
@@ -172,8 +211,8 @@ class ModelTablePostgres(DatabasePostgres):
offset: int = 0) -> List[Model]:
"""获取所有模特"""
try:
with self.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
# Build query conditions
conditions = []
params = []
@@ -224,7 +263,7 @@ class ModelTablePostgres(DatabasePostgres):
logger.warning("没有有效的更新字段")
return False
with self.get_connection() as conn:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Build the UPDATE statement
set_clauses = []
@@ -266,7 +305,7 @@ class ModelTablePostgres(DatabasePostgres):
def delete_model(self, model_id: str, hard_delete: bool = False) -> bool:
"""删除模特"""
try:
with self.get_connection() as conn:
with self._get_connection() as conn:
with conn.cursor() as cursor:
if hard_delete:
# Hard delete
@@ -298,31 +337,33 @@ class ModelTablePostgres(DatabasePostgres):
include_cloud: bool = True, limit: int = 50) -> List[Model]:
"""搜索模特"""
try:
with self.get_connection() as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# Build query conditions
conditions = ["is_active = true"]
params = [f"%{query}%"]
with self._get_connection() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
# Simplified search logic
search_pattern = f"%{query}%"
if include_cloud:
conditions.append("(user_id = %s OR is_cloud = true)")
params.append(user_id)
cursor.execute("""
SELECT * FROM models
WHERE model_number ILIKE %s
AND is_active = true
AND (user_id = %s OR is_cloud = true)
ORDER BY
CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END,
created_at DESC
LIMIT %s
""", (search_pattern, user_id, search_pattern, limit))
else:
conditions.append("user_id = %s")
params.append(user_id)
params.append(limit)
where_clause = " AND ".join(conditions)
cursor.execute(f"""
SELECT * FROM models
WHERE model_number ILIKE %s AND {where_clause}
ORDER BY
CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END,
created_at DESC
LIMIT %s
""", [f"%{query}%"] + params)
cursor.execute("""
SELECT * FROM models
WHERE model_number ILIKE %s
AND is_active = true
AND user_id = %s
ORDER BY
CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END,
created_at DESC
LIMIT %s
""", (search_pattern, user_id, search_pattern, limit))
rows = cursor.fetchall()
return [self._row_to_model(row) for row in rows]
@@ -337,7 +378,7 @@ class ModelTablePostgres(DatabasePostgres):
include_inactive: bool = False) -> int:
"""获取模特数量"""
try:
with self.get_connection() as conn:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Build query conditions
conditions = []
@@ -371,7 +412,7 @@ class ModelTablePostgres(DatabasePostgres):
def toggle_model_status(self, model_id: str) -> bool:
"""切换模特状态"""
try:
with self.get_connection() as conn:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute("""
UPDATE models

View File

@@ -1,101 +1,273 @@
use serde::{Deserialize, Serialize};
use tauri::{command, AppHandle};
use crate::python_executor::execute_python_command;
use tauri::AppHandle;
use crate::utils::python_cli::execute_python_cli_command;
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateModelRequest {
pub model_number: String,
pub model_image: Option<String>,
pub struct ModelResponse {
pub success: bool,
pub data: Option<serde_json::Value>,
pub error: Option<String>,
pub message: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateModelRequest {
#[derive(Debug, Deserialize)]
pub struct ModelCreateRequest {
pub model_number: String,
pub model_image: String,
pub user_id: Option<String>,
pub is_cloud: Option<bool>,
pub verbose: Option<bool>,
pub json_output: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct ModelListRequest {
pub user_id: Option<String>,
pub include_cloud: Option<bool>,
pub include_inactive: Option<bool>,
pub limit: Option<i32>,
pub offset: Option<i32>,
pub verbose: Option<bool>,
pub json_output: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct ModelGetRequest {
pub model_id: String,
pub verbose: Option<bool>,
pub json_output: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct ModelUpdateRequest {
pub model_id: String,
pub model_number: Option<String>,
pub model_image: Option<String>,
pub is_active: Option<bool>,
pub is_cloud: Option<bool>,
pub verbose: Option<bool>,
pub json_output: Option<bool>,
}
/// Get all models
#[command]
pub async fn get_all_models(app: AppHandle) -> Result<String, String> {
let args = vec![
"-m".to_string(),
"python_core.services.model_manager".to_string(),
"get_all_models".to_string(),
];
execute_python_command(app, &args, None).await
#[derive(Debug, Deserialize)]
pub struct ModelDeleteRequest {
pub model_id: String,
pub hard_delete: Option<bool>,
pub verbose: Option<bool>,
pub json_output: Option<bool>,
}
/// Get a model by ID
#[command]
pub async fn get_model_by_id(app: AppHandle, model_id: String) -> Result<String, String> {
let args = vec![
"-m".to_string(),
"python_core.services.model_manager".to_string(),
"get_model_by_id".to_string(),
model_id,
];
execute_python_command(app, &args, None).await
#[derive(Debug, Deserialize)]
pub struct ModelSearchRequest {
pub query: String,
pub user_id: Option<String>,
pub include_cloud: Option<bool>,
pub limit: Option<i32>,
pub verbose: Option<bool>,
pub json_output: Option<bool>,
}
/// Create a new model
#[command]
pub async fn create_model(app: AppHandle, request: CreateModelRequest) -> Result<String, String> {
let args = vec![
"-m".to_string(),
"python_core.services.model_manager".to_string(),
"create_model".to_string(),
/// Create a model
#[tauri::command]
pub async fn create_model_cli(
app: AppHandle,
request: ModelCreateRequest,
) -> Result<ModelResponse, String> {
let mut args = vec![
"model".to_string(),
"create".to_string(),
request.model_number,
request.model_image.unwrap_or_default(),
request.model_image,
];
execute_python_command(app, &args, None).await
if let Some(user_id) = request.user_id {
args.push("--user-id".to_string());
args.push(user_id);
}
if request.is_cloud.unwrap_or(false) {
args.push("--cloud".to_string());
}
if request.verbose.unwrap_or(false) {
args.push("--verbose".to_string());
}
if request.json_output.unwrap_or(true) {
args.push("--json".to_string());
}
execute_python_cli_command(app, args).await
}
/// List models
#[tauri::command]
pub async fn list_models_cli(
app: AppHandle,
request: ModelListRequest,
) -> Result<ModelResponse, String> {
let mut args = vec!["model".to_string(), "list".to_string()];
if let Some(user_id) = request.user_id {
args.push("--user-id".to_string());
args.push(user_id);
}
if request.include_cloud.unwrap_or(true) {
args.push("--include-cloud".to_string());
}
if request.include_inactive.unwrap_or(false) {
args.push("--include-inactive".to_string());
}
if let Some(limit) = request.limit {
args.push("--limit".to_string());
args.push(limit.to_string());
}
if let Some(offset) = request.offset {
args.push("--offset".to_string());
args.push(offset.to_string());
}
if request.verbose.unwrap_or(false) {
args.push("--verbose".to_string());
}
if request.json_output.unwrap_or(true) {
args.push("--json".to_string());
}
execute_python_cli_command(app, args).await
}
/// Get model details
#[tauri::command]
pub async fn get_model_cli(
app: AppHandle,
request: ModelGetRequest,
) -> Result<ModelResponse, String> {
let mut args = vec![
"model".to_string(),
"get".to_string(),
request.model_id,
];
if request.verbose.unwrap_or(false) {
args.push("--verbose".to_string());
}
if request.json_output.unwrap_or(true) {
args.push("--json".to_string());
}
execute_python_cli_command(app, args).await
}
/// Update a model
#[command]
pub async fn update_model(
#[tauri::command]
pub async fn update_model_cli(
app: AppHandle,
model_id: String,
request: UpdateModelRequest,
) -> Result<String, String> {
let request_json = serde_json::to_string(&request)
.map_err(|e| format!("Failed to serialize request: {}", e))?;
let args = vec![
"-m".to_string(),
"python_core.services.model_manager".to_string(),
"update_model".to_string(),
model_id,
request_json,
request: ModelUpdateRequest,
) -> Result<ModelResponse, String> {
let mut args = vec![
"model".to_string(),
"update".to_string(),
request.model_id,
];
execute_python_command(app, &args, None).await
if let Some(model_number) = request.model_number {
args.push("--model-number".to_string());
args.push(model_number);
}
if let Some(model_image) = request.model_image {
args.push("--model-image".to_string());
args.push(model_image);
}
if let Some(is_active) = request.is_active {
args.push("--active".to_string());
args.push(is_active.to_string());
}
if let Some(is_cloud) = request.is_cloud {
args.push("--cloud".to_string());
args.push(is_cloud.to_string());
}
if request.verbose.unwrap_or(false) {
args.push("--verbose".to_string());
}
if request.json_output.unwrap_or(true) {
args.push("--json".to_string());
}
execute_python_cli_command(app, args).await
}
/// Delete a model
#[command]
pub async fn delete_model(app: AppHandle, model_id: String) -> Result<String, String> {
let args = vec![
"-m".to_string(),
"python_core.services.model_manager".to_string(),
"delete_model".to_string(),
model_id,
#[tauri::command]
pub async fn delete_model_cli(
app: AppHandle,
request: ModelDeleteRequest,
) -> Result<ModelResponse, String> {
let mut args = vec![
"model".to_string(),
"delete".to_string(),
request.model_id,
];
execute_python_command(app, &args, None).await
if request.hard_delete.unwrap_or(false) {
args.push("--hard".to_string());
}
if request.verbose.unwrap_or(false) {
args.push("--verbose".to_string());
}
if request.json_output.unwrap_or(true) {
args.push("--json".to_string());
}
execute_python_cli_command(app, args).await
}
/// Search models
#[command]
pub async fn search_models(app: AppHandle, keyword: String) -> Result<String, String> {
let args = vec![
"-m".to_string(),
"python_core.services.model_manager".to_string(),
"search_models".to_string(),
keyword,
#[tauri::command]
pub async fn search_models_cli(
app: AppHandle,
request: ModelSearchRequest,
) -> Result<ModelResponse, String> {
let mut args = vec![
"model".to_string(),
"search".to_string(),
request.query,
];
execute_python_command(app, &args, None).await
if let Some(user_id) = request.user_id {
args.push("--user-id".to_string());
args.push(user_id);
}
if request.include_cloud.unwrap_or(true) {
args.push("--include-cloud".to_string());
}
if let Some(limit) = request.limit {
args.push("--limit".to_string());
args.push(limit.to_string());
}
if request.verbose.unwrap_or(false) {
args.push("--verbose".to_string());
}
if request.json_output.unwrap_or(true) {
args.push("--json".to_string());
}
execute_python_cli_command(app, args).await
}

View File

@@ -73,12 +73,12 @@ pub fn run() {
commands::project::scan_directory,
commands::project::open_file_in_system,
commands::project::delete_file,
commands::model::get_all_models,
commands::model::get_model_by_id,
commands::model::create_model,
commands::model::update_model,
commands::model::delete_model,
commands::model::search_models,
commands::model::list_models_cli,
commands::model::get_model_cli,
commands::model::create_model_cli,
commands::model::update_model_cli,
commands::model::delete_model_cli,
commands::model::search_models_cli,
commands::audio::get_all_audio_files,
commands::audio::get_audio_by_id,
commands::audio::get_audio_by_md5,

View File

@@ -0,0 +1,370 @@
import { invoke } from '@tauri-apps/api/tauri'
export interface Model {
id: string
model_number: string
model_image: string
is_active: boolean
is_cloud: boolean
user_id: string
created_at: string
updated_at: string
}
export interface ModelResponse {
success: boolean
data?: any
error?: string
message?: string
}
export interface ModelCreateRequest {
model_number: string
model_image: string
user_id?: string
is_cloud?: boolean
verbose?: boolean
json_output?: boolean
}
export interface ModelListRequest {
user_id?: string
include_cloud?: boolean
include_inactive?: boolean
limit?: number
offset?: number
verbose?: boolean
json_output?: boolean
}
export interface ModelGetRequest {
model_id: string
verbose?: boolean
json_output?: boolean
}
export interface ModelUpdateRequest {
model_id: string
model_number?: string
model_image?: string
is_active?: boolean
is_cloud?: boolean
verbose?: boolean
json_output?: boolean
}
export interface ModelDeleteRequest {
model_id: string
hard_delete?: boolean
verbose?: boolean
json_output?: boolean
}
export interface ModelSearchRequest {
query: string
user_id?: string
include_cloud?: boolean
limit?: number
verbose?: boolean
json_output?: boolean
}
export class ModelServiceV2 {
/**
* Get the current user ID
*/
private static getCurrentUserId(): string {
// This should fetch the current user ID from the auth service
// Return the default value for now
return 'default'
}
/**
* Create a model
*/
static async createModel(
modelNumber: string,
modelImage: string,
isCloud: boolean = false,
userId?: string
): Promise<Model> {
try {
const request: ModelCreateRequest = {
model_number: modelNumber,
model_image: modelImage,
user_id: userId || this.getCurrentUserId(),
is_cloud: isCloud,
verbose: false,
json_output: true
}
const response = await invoke<ModelResponse>('create_model_cli', { request })
if (!response.success) {
throw new Error(response.error || response.message || 'Failed to create model')
}
return response.data?.model
} catch (error) {
console.error('Create model failed:', error)
throw error
}
}
/**
* Get all models
*/
static async getAllModels(
includeCloud: boolean = true,
includeInactive: boolean = false,
limit: number = 100,
offset: number = 0,
userId?: string
): Promise<{ models: Model[], total_count: number }> {
try {
const request: ModelListRequest = {
user_id: userId || this.getCurrentUserId(),
include_cloud: includeCloud,
include_inactive: includeInactive,
limit: limit,
offset: offset,
verbose: false,
json_output: true
}
const response = await invoke<ModelResponse>('list_models_cli', { request })
if (!response.success) {
throw new Error(response.error || response.message || 'Failed to get models')
}
return {
models: response.data?.models || [],
total_count: response.data?.total_count || 0
}
} catch (error) {
console.error('Get models failed:', error)
throw error
}
}
/**
* Get a model by ID
*/
static async getModelById(modelId: string): Promise<Model | null> {
try {
const request: ModelGetRequest = {
model_id: modelId,
verbose: false,
json_output: true
}
const response = await invoke<ModelResponse>('get_model_cli', { request })
if (!response.success) {
if (response.error?.includes('不存在')) {
return null
}
throw new Error(response.error || response.message || 'Failed to get model')
}
return response.data?.model || null
} catch (error) {
console.error('Get model failed:', error)
throw error
}
}
/**
* Update a model
*/
static async updateModel(
modelId: string,
updates: {
model_number?: string
model_image?: string
is_active?: boolean
is_cloud?: boolean
}
): Promise<boolean> {
try {
const request: ModelUpdateRequest = {
model_id: modelId,
...updates,
verbose: false,
json_output: true
}
const response = await invoke<ModelResponse>('update_model_cli', { request })
if (!response.success) {
throw new Error(response.error || response.message || 'Failed to update model')
}
return true
} catch (error) {
console.error('Update model failed:', error)
throw error
}
}
/**
* Delete a model
*/
static async deleteModel(modelId: string, hardDelete: boolean = false): Promise<boolean> {
try {
const request: ModelDeleteRequest = {
model_id: modelId,
hard_delete: hardDelete,
verbose: false,
json_output: true
}
const response = await invoke<ModelResponse>('delete_model_cli', { request })
if (!response.success) {
throw new Error(response.error || response.message || 'Failed to delete model')
}
return true
} catch (error) {
console.error('Delete model failed:', error)
throw error
}
}
/**
* Search models
*/
static async searchModels(
query: string,
includeCloud: boolean = true,
limit: number = 50,
userId?: string
): Promise<Model[]> {
try {
const request: ModelSearchRequest = {
query: query,
user_id: userId || this.getCurrentUserId(),
include_cloud: includeCloud,
limit: limit,
verbose: false,
json_output: true
}
const response = await invoke<ModelResponse>('search_models_cli', { request })
if (!response.success) {
throw new Error(response.error || response.message || 'Failed to search models')
}
return response.data?.models || []
} catch (error) {
console.error('Search models failed:', error)
throw error
}
}
/**
* Toggle model status
*/
static async toggleModelStatus(modelId: string): Promise<boolean> {
try {
// Fetch the current model info first
const model = await this.getModelById(modelId)
if (!model) {
throw new Error('模特不存在')
}
// Update the status
return await this.updateModel(modelId, {
is_active: !model.is_active
})
} catch (error) {
console.error('Toggle model status failed:', error)
throw error
}
}
/**
* Get model statistics
*/
static async getModelStats(
includeCloud: boolean = true,
userId?: string
): Promise<{
total: number
active: number
inactive: number
cloud: number
local: number
}> {
try {
const { models } = await this.getAllModels(includeCloud, true, 1000, 0, userId)
const stats = {
total: models.length,
active: models.filter(m => m.is_active).length,
inactive: models.filter(m => !m.is_active).length,
cloud: models.filter(m => m.is_cloud).length,
local: models.filter(m => !m.is_cloud).length
}
return stats
} catch (error) {
console.error('Get model stats failed:', error)
throw error
}
}
/**
* Batch delete models
*/
static async batchDeleteModels(
modelIds: string[],
hardDelete: boolean = false
): Promise<{ success: string[], failed: string[] }> {
const results = {
success: [] as string[],
failed: [] as string[]
}
for (const modelId of modelIds) {
try {
await this.deleteModel(modelId, hardDelete)
results.success.push(modelId)
} catch (error) {
console.error(`Failed to delete model ${modelId}:`, error)
results.failed.push(modelId)
}
}
return results
}
/**
* Check whether a model number is unique
*/
static async isModelNumberUnique(
modelNumber: string,
excludeId?: string,
userId?: string
): Promise<boolean> {
try {
const models = await this.searchModels(modelNumber, true, 10, userId)
// Check whether there is an exact match on model_number
const exactMatch = models.find(m =>
m.model_number === modelNumber &&
m.id !== excludeId
)
return !exactMatch
} catch (error) {
console.error('Check model number uniqueness failed:', error)
return false
}
}
}
export default ModelServiceV2
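
For reference, a minimal usage sketch of the new frontend service; the import path, the sample model number, and the image path below are illustrative assumptions, not part of this commit:

import ModelServiceV2 from './services/modelServiceV2'  // hypothetical path

async function demoModelWorkflow(): Promise<void> {
  // Create a local model, then list models with cloud models included.
  const created = await ModelServiceV2.createModel('M-001', '/images/m-001.png', false)
  const { models, total_count } = await ModelServiceV2.getAllModels(true, false, 20, 0)
  console.log(`created ${created?.id}, total models: ${total_count}`, models)

  // Toggle the created model's active flag and read aggregate statistics.
  if (created) {
    await ModelServiceV2.toggleModelStatus(created.id)
  }
  const stats = await ModelServiceV2.getModelStats()
  console.log(stats)
}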