mxivideo/python_core/database/model_postgres.py

440 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模特表 PostgreSQL 实现
"""
import uuid
from datetime import datetime
from typing import List, Optional, Dict, Any
from contextlib import contextmanager
from python_core.config import settings
from python_core.database.types import Model
from python_core.utils.logger import setup_logger
# 尝试导入 psycopg2如果失败则提供友好的错误信息
# Guarded import: a missing psycopg2 must not break importing this module;
# ModelTablePostgres.__init__ checks PSYCOPG2_AVAILABLE and reports the problem.
try:
    import psycopg2
    import psycopg2.extras
except ImportError as e:
    PSYCOPG2_AVAILABLE = False
    # Preserved so the constructor can show the original import failure.
    PSYCOPG2_ERROR = str(e)
else:
    PSYCOPG2_AVAILABLE = True

logger = setup_logger(__name__)
class ModelTablePostgres:
    """Data access layer for the PostgreSQL ``models`` table.

    Each public method opens a fresh connection through ``_get_connection``
    (``psycopg2.connect`` per call, no pooling), runs its statement(s),
    commits when it writes, and closes the connection. Database errors are
    logged and re-raised; a duplicate ``(model_number, user_id)`` pair is
    reported as ``ValueError``.
    """

    # SQLSTATE of a unique_violation (same value as
    # psycopg2.errorcodes.UNIQUE_VIOLATION). Matching the code is
    # locale-independent, unlike searching the server's possibly translated
    # error text for the word "unique".
    _UNIQUE_VIOLATION = "23505"

    # Columns update_model() may write. Field names are interpolated into the
    # SQL text, so this whitelist is what keeps update_model injection-safe.
    _UPDATABLE_FIELDS = frozenset({"model_number", "model_image", "is_active", "is_cloud"})

    # Secondary indexes created by _init_model_table().
    _INDEX_DDL = (
        "CREATE INDEX IF NOT EXISTS idx_models_user_id ON models(user_id)",
        # NOTE(review): probably redundant with the UNIQUE(model_number, user_id)
        # index, whose leading column is already model_number.
        "CREATE INDEX IF NOT EXISTS idx_models_model_number ON models(model_number)",
        "CREATE INDEX IF NOT EXISTS idx_models_is_active ON models(is_active)",
        "CREATE INDEX IF NOT EXISTS idx_models_is_cloud ON models(is_cloud)",
        "CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at)",
    )

    def __init__(self):
        """Check the driver, read the DSN from ``settings.db`` and ensure the schema exists.

        Raises:
            ImportError: if psycopg2 could not be imported.
        """
        if not PSYCOPG2_AVAILABLE:
            error_msg = f"""
PostgreSQL support requires psycopg2 package.
Please install it using: pip install psycopg2-binary
原始错误: {PSYCOPG2_ERROR}
"""
            logger.error(error_msg)
            raise ImportError(error_msg)
        self.db_url = settings.db
        # NOTE: informational only; the SQL in this class hard-codes "models".
        self.table_name = "models"
        self._init_model_table()

    @contextmanager
    def _get_connection(self):
        """Yield a new connection; roll back on error and always close it."""
        conn = None
        try:
            conn = psycopg2.connect(self.db_url)
            yield conn
        except Exception as e:
            if conn is not None:
                # A failing rollback (e.g. the connection already dropped)
                # must not mask the original error.
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    logger.warning(f"Rollback failed: {rollback_error}")
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_model_table(self):
        """Create the table, its indexes and the updated_at trigger if missing.

        NOTE: ``gen_random_uuid()`` is built in from PostgreSQL 13 (requires
        the pgcrypto extension before that) and ``EXECUTE FUNCTION`` requires
        PostgreSQL 11 or newer.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("""
                        CREATE TABLE IF NOT EXISTS models (
                            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                            model_number VARCHAR(100) NOT NULL,
                            model_image VARCHAR(500) NOT NULL,
                            is_active BOOLEAN DEFAULT true,
                            is_cloud BOOLEAN DEFAULT false,
                            user_id VARCHAR(100) NOT NULL DEFAULT 'default',
                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                            UNIQUE(model_number, user_id)
                        )
                    """)
                    for ddl in self._INDEX_DDL:
                        cursor.execute(ddl)
                    # Trigger that refreshes updated_at on every UPDATE.
                    cursor.execute("""
                        CREATE OR REPLACE FUNCTION update_models_updated_at()
                        RETURNS TRIGGER AS $$
                        BEGIN
                            NEW.updated_at = CURRENT_TIMESTAMP;
                            RETURN NEW;
                        END;
                        $$ language 'plpgsql';
                    """)
                    cursor.execute("""
                        DROP TRIGGER IF EXISTS update_models_updated_at_trigger ON models;
                        CREATE TRIGGER update_models_updated_at_trigger
                            BEFORE UPDATE ON models
                            FOR EACH ROW
                            EXECUTE FUNCTION update_models_updated_at();
                    """)
                conn.commit()
            logger.info("模特表初始化成功")
        except Exception as e:
            logger.error(f"初始化模特表失败: {e}")
            raise

    # ------------------------------------------------------------------ helpers

    def _fetch_one(self, sql: str, params: tuple, commit: bool = False) -> Optional[Dict[str, Any]]:
        """Run ``sql`` with a dict cursor and return the first row (or None); commit if asked."""
        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(sql, params)
                row = cursor.fetchone()
            if commit:
                conn.commit()
        return row

    def _fetch_all(self, sql: str, params: tuple) -> List[Dict[str, Any]]:
        """Run a read-only ``sql`` with a dict cursor and return every row."""
        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(sql, params)
                return cursor.fetchall()

    def _execute_write(self, sql: str, params: tuple) -> int:
        """Run one data-modifying statement, commit it, and return the affected row count."""
        with self._get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(sql, params)
                affected_rows = cursor.rowcount
            conn.commit()
        return affected_rows

    @staticmethod
    def _visibility_filter(user_id: str, include_cloud: bool, include_inactive: bool) -> tuple:
        """Build the WHERE conditions shared by the list/count/search queries.

        Returns:
            ``(conditions, params)``: a list of SQL condition strings (to be
            joined with AND) and the list of values for their placeholders.
        """
        if include_cloud:
            # The user's own models plus every cloud model.
            conditions = ["(user_id = %s OR is_cloud = true)"]
        else:
            conditions = ["user_id = %s"]
        if not include_inactive:
            conditions.append("is_active = true")
        return conditions, [user_id]

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE/ILIKE metacharacters so ``text`` is matched literally.

        Relies on PostgreSQL's default LIKE escape character, the backslash.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _normalize_model_id(model_id: Any) -> Optional[str]:
        """Return ``model_id`` as a canonical UUID string, or None if it is not a UUID.

        Every string PostgreSQL accepts as a ``uuid`` is also accepted here, so
        this only short-circuits ids the database would have rejected.
        """
        try:
            return str(uuid.UUID(str(model_id)))
        except ValueError:
            return None

    @classmethod
    def _is_unique_violation(cls, error: Exception) -> bool:
        """Return True if ``error`` is a unique-constraint violation.

        Uses the SQLSTATE when present; falls back to the old message check.
        """
        pgcode = getattr(error, "pgcode", None)
        if pgcode:
            return pgcode == cls._UNIQUE_VIOLATION
        return "unique" in str(error).lower()

    def _row_to_model(self, row: Dict[str, Any]) -> "Model":
        """Convert a dict-cursor row into a ``Model``; timestamps become ISO strings ('' if NULL)."""
        return Model(
            id=str(row['id']),
            model_number=row['model_number'],
            model_image=row['model_image'],
            is_active=row['is_active'],
            is_cloud=row['is_cloud'],
            user_id=row['user_id'],
            created_at=row['created_at'].isoformat() if row['created_at'] else '',
            updated_at=row['updated_at'].isoformat() if row['updated_at'] else ''
        )

    # ------------------------------------------------------------------ public API

    def create_model(self, model_number: str, model_image: str, user_id: str = "default",
                     is_cloud: bool = False) -> Optional["Model"]:
        """Insert a model and return the stored row.

        Raises:
            ValueError: if ``model_number`` already exists for ``user_id``.
        """
        try:
            row = self._fetch_one(
                """
                INSERT INTO models (model_number, model_image, user_id, is_cloud)
                VALUES (%s, %s, %s, %s)
                RETURNING *
                """,
                (model_number, model_image, user_id, is_cloud),
                commit=True,
            )
            if row:
                model = self._row_to_model(row)
                logger.info(f"创建模特成功: {model.model_number}")
                return model
        except psycopg2.IntegrityError as e:
            if self._is_unique_violation(e):
                logger.warning(f"模特编号已存在: {model_number}")
                raise ValueError(f"模特编号 '{model_number}' 已存在") from e
            logger.error(f"创建模特失败: {e}")
            raise
        except Exception as e:
            logger.error(f"创建模特失败: {e}")
            raise
        return None

    def get_model_by_id(self, model_id: str) -> Optional["Model"]:
        """Return the model with ``model_id``, or None if absent or not a valid UUID."""
        normalized_id = self._normalize_model_id(model_id)
        if normalized_id is None:
            return None
        try:
            row = self._fetch_one("SELECT * FROM models WHERE id = %s", (normalized_id,))
            return self._row_to_model(row) if row else None
        except Exception as e:
            logger.error(f"获取模特失败: {e}")
            raise

    def get_model_by_number(self, model_number: str, user_id: str = "default") -> Optional["Model"]:
        """Return the model with this number owned by ``user_id``, or None."""
        try:
            row = self._fetch_one(
                """
                SELECT * FROM models
                WHERE model_number = %s AND user_id = %s
                """,
                (model_number, user_id),
            )
            return self._row_to_model(row) if row else None
        except Exception as e:
            logger.error(f"获取模特失败: {e}")
            raise

    def get_all_models(self, user_id: str = "default", include_cloud: bool = True,
                       include_inactive: bool = False, limit: int = 100,
                       offset: int = 0) -> List["Model"]:
        """Return one page of visible models, newest first."""
        try:
            conditions, params = self._visibility_filter(user_id, include_cloud, include_inactive)
            # Conditions are fixed strings from _visibility_filter; values stay parameterized.
            where_clause = " AND ".join(conditions)
            rows = self._fetch_all(
                f"""
                SELECT * FROM models
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s
                """,
                (*params, limit, offset),
            )
            return [self._row_to_model(row) for row in rows]
        except Exception as e:
            logger.error(f"获取模特列表失败: {e}")
            raise

    def update_model(self, model_id: str, updates: Dict[str, Any]) -> bool:
        """Apply whitelisted field ``updates`` to one model.

        Returns:
            True if a row was updated (or ``updates`` is empty); False if no
            allowed field was given or the model does not exist.

        Raises:
            ValueError: if the new ``model_number`` collides with an existing one.
        """
        try:
            if not updates:
                return True
            filtered_updates = {k: v for k, v in updates.items() if k in self._UPDATABLE_FIELDS}
            if not filtered_updates:
                logger.warning("没有有效的更新字段")
                return False
            normalized_id = self._normalize_model_id(model_id)
            if normalized_id is None:
                logger.warning(f"模特不存在: {model_id}")
                return False
            # Field names come from the whitelist above, never from raw input.
            set_clause = ", ".join(f"{field} = %s" for field in filtered_updates)
            affected_rows = self._execute_write(
                f"UPDATE models SET {set_clause} WHERE id = %s",
                (*filtered_updates.values(), normalized_id),
            )
            if affected_rows > 0:
                logger.info(f"更新模特成功: {model_id}")
                return True
            logger.warning(f"模特不存在: {model_id}")
            return False
        except psycopg2.IntegrityError as e:
            if self._is_unique_violation(e):
                logger.warning("模特编号已存在")
                raise ValueError("模特编号已存在") from e
            logger.error(f"更新模特失败: {e}")
            raise
        except Exception as e:
            logger.error(f"更新模特失败: {e}")
            raise

    def delete_model(self, model_id: str, hard_delete: bool = False) -> bool:
        """Delete a model (``hard_delete``) or soft-delete it by setting is_active = false.

        Returns:
            True if a row was affected, False if the model does not exist.
        """
        try:
            normalized_id = self._normalize_model_id(model_id)
            if normalized_id is None:
                logger.warning(f"模特不存在: {model_id}")
                return False
            if hard_delete:
                sql = "DELETE FROM models WHERE id = %s"
            else:
                sql = "UPDATE models SET is_active = false WHERE id = %s"
            affected_rows = self._execute_write(sql, (normalized_id,))
            if affected_rows > 0:
                action = "删除" if hard_delete else "禁用"
                logger.info(f"{action}模特成功: {model_id}")
                return True
            logger.warning(f"模特不存在: {model_id}")
            return False
        except Exception as e:
            logger.error(f"删除模特失败: {e}")
            raise

    def search_models(self, query: str, user_id: str = "default",
                      include_cloud: bool = True, limit: int = 50) -> List["Model"]:
        """Case-insensitive substring search on model_number among active models.

        ``%`` and ``_`` in ``query`` are matched literally. Model numbers that
        start with ``query`` are ranked first, then newest first.
        """
        try:
            escaped = self._escape_like(str(query))
            contains_pattern = f"%{escaped}%"
            prefix_pattern = f"{escaped}%"
            conditions, params = self._visibility_filter(user_id, include_cloud, False)
            where_clause = " AND ".join(["model_number ILIKE %s", *conditions])
            rows = self._fetch_all(
                f"""
                SELECT * FROM models
                WHERE {where_clause}
                ORDER BY
                    CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END,
                    created_at DESC
                LIMIT %s
                """,
                (contains_pattern, *params, prefix_pattern, limit),
            )
            return [self._row_to_model(row) for row in rows]
        except Exception as e:
            logger.error(f"搜索模特失败: {e}")
            raise

    def get_model_count(self, user_id: str = "default", include_cloud: bool = True,
                        include_inactive: bool = False) -> int:
        """Count models visible under the same filter as get_all_models()."""
        try:
            conditions, params = self._visibility_filter(user_id, include_cloud, include_inactive)
            where_clause = " AND ".join(conditions)
            row = self._fetch_one(
                f"SELECT COUNT(*) AS total FROM models WHERE {where_clause}",
                tuple(params),
            )
            return row["total"] if row else 0
        except Exception as e:
            logger.error(f"获取模特数量失败: {e}")
            raise

    def toggle_model_status(self, model_id: str) -> bool:
        """Flip is_active for one model; return False if it does not exist."""
        try:
            normalized_id = self._normalize_model_id(model_id)
            if normalized_id is None:
                logger.warning(f"模特不存在: {model_id}")
                return False
            affected_rows = self._execute_write(
                "UPDATE models SET is_active = NOT is_active WHERE id = %s",
                (normalized_id,),
            )
            if affected_rows > 0:
                logger.info(f"切换模特状态成功: {model_id}")
                return True
            logger.warning(f"模特不存在: {model_id}")
            return False
        except Exception as e:
            logger.error(f"切换模特状态失败: {e}")
            raise
# Module-level shared instance.
# NOTE: constructing it here means importing this module reads settings.db,
# connects to the database and runs the CREATE TABLE/INDEX/TRIGGER statements
# in _init_model_table(); if psycopg2 is missing, the import itself raises
# ImportError.
model_table: ModelTablePostgres = ModelTablePostgres()