440 lines
17 KiB
Python
440 lines
17 KiB
Python
"""
模特表 PostgreSQL 实现 (PostgreSQL implementation of the models table).
"""
|
||
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import List, Optional, Dict, Any
|
||
from contextlib import contextmanager
|
||
from python_core.config import settings
|
||
from python_core.database.types import Model
|
||
from python_core.utils.logger import setup_logger
|
||
|
||
# psycopg2 is optional at import time. Record whether it loaded (and, if not,
# why) so that ModelTablePostgres can raise a helpful error when instantiated.
try:
    import psycopg2
    import psycopg2.extras
except ImportError as import_exc:
    PSYCOPG2_AVAILABLE = False
    PSYCOPG2_ERROR = str(import_exc)
else:
    PSYCOPG2_AVAILABLE = True
|
||
|
||
# Module-level logger, named after this module.
logger = setup_logger(__name__)
|
||
|
||
|
||
class ModelTablePostgres:
    """PostgreSQL-backed data access object for the ``models`` table.

    Every public method opens its own short-lived connection through
    ``_get_connection``. The table, its indexes and the ``updated_at``
    trigger are created on construction if they do not exist yet.
    """

    # SQLSTATE of PostgreSQL's unique_violation. Matching on the code rather
    # than on the error text keeps duplicate detection working when the
    # server's lc_messages is localized (a translated message does not
    # necessarily contain the word "unique").
    _UNIQUE_VIOLATION = "23505"

    # Columns that update_model() is allowed to change. These names are
    # interpolated into SQL text, so this allow-list is what keeps it safe.
    _UPDATABLE_FIELDS = frozenset({'model_number', 'model_image', 'is_active', 'is_cloud'})

    def __init__(self):
        """Check that psycopg2 is available, read settings and ensure the table exists.

        Raises:
            ImportError: if psycopg2 could not be imported when the module loaded.
            Exception: any database error raised while creating the schema.
        """
        if not PSYCOPG2_AVAILABLE:
            error_msg = f"""
PostgreSQL support requires psycopg2 package.
Please install it using: pip install psycopg2-binary

原始错误: {PSYCOPG2_ERROR}
"""
            logger.error(error_msg)
            raise ImportError(error_msg)

        self.db_url = settings.db
        self.table_name = "models"

        # Create the table, indexes and trigger if they are missing.
        self._init_model_table()

    @contextmanager
    def _get_connection(self):
        """Yield a new psycopg2 connection and always close it afterwards.

        If the managed body raises, the open transaction is rolled back, the
        error is logged and then re-raised unchanged. Callers must call
        ``commit()`` themselves; uncommitted work is discarded on close.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.db_url)
            yield conn
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A failing rollback (e.g. the connection is already
                    # broken) must not replace the original exception.
                    logger.error(f"Database rollback failed: {rollback_error}")
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_model_table(self):
        """Create the models table, its indexes and the updated_at trigger if absent.

        All DDL runs in a single transaction and is committed at the end.

        Raises:
            Exception: any database error; it is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Create the models table.
                    # NOTE(review): gen_random_uuid() is built in from
                    # PostgreSQL 13; older servers need the pgcrypto extension.
                    cursor.execute("""
                        CREATE TABLE IF NOT EXISTS models (
                            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                            model_number VARCHAR(100) NOT NULL,
                            model_image VARCHAR(500) NOT NULL,
                            is_active BOOLEAN DEFAULT true,
                            is_cloud BOOLEAN DEFAULT false,
                            user_id VARCHAR(100) NOT NULL DEFAULT 'default',
                            created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                            UNIQUE(model_number, user_id)
                        )
                    """)

                    # Create indexes.
                    cursor.execute("""
                        CREATE INDEX IF NOT EXISTS idx_models_user_id ON models(user_id)
                    """)
                    cursor.execute("""
                        CREATE INDEX IF NOT EXISTS idx_models_model_number ON models(model_number)
                    """)
                    cursor.execute("""
                        CREATE INDEX IF NOT EXISTS idx_models_is_active ON models(is_active)
                    """)
                    cursor.execute("""
                        CREATE INDEX IF NOT EXISTS idx_models_is_cloud ON models(is_cloud)
                    """)
                    cursor.execute("""
                        CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at)
                    """)

                    # Trigger function that refreshes updated_at on every UPDATE.
                    cursor.execute("""
                        CREATE OR REPLACE FUNCTION update_models_updated_at()
                        RETURNS TRIGGER AS $$
                        BEGIN
                            NEW.updated_at = CURRENT_TIMESTAMP;
                            RETURN NEW;
                        END;
                        $$ language 'plpgsql';
                    """)

                    # Recreate the trigger so repeated initialization is harmless.
                    cursor.execute("""
                        DROP TRIGGER IF EXISTS update_models_updated_at_trigger ON models;
                        CREATE TRIGGER update_models_updated_at_trigger
                            BEFORE UPDATE ON models
                            FOR EACH ROW
                            EXECUTE FUNCTION update_models_updated_at();
                    """)

                    conn.commit()
                    logger.info("模特表初始化成功")

        except Exception as e:
            logger.error(f"初始化模特表失败: {e}")
            raise

    def _row_to_model(self, row: Dict[str, Any]) -> Model:
        """Convert a RealDictCursor row into a Model.

        The id is stringified and timestamps are rendered with isoformat();
        a NULL timestamp becomes an empty string.
        """
        return Model(
            id=str(row['id']),
            model_number=row['model_number'],
            model_image=row['model_image'],
            is_active=row['is_active'],
            is_cloud=row['is_cloud'],
            user_id=row['user_id'],
            created_at=row['created_at'].isoformat() if row['created_at'] else '',
            updated_at=row['updated_at'].isoformat() if row['updated_at'] else ''
        )

    @classmethod
    def _is_unique_violation(cls, error) -> bool:
        """Return True if *error* is a PostgreSQL unique-constraint violation.

        Uses the SQLSTATE code (``pgcode``) when present. Falls back to the
        old message-text heuristic only when no code is available.
        """
        pgcode = getattr(error, "pgcode", None)
        if pgcode is not None:
            return pgcode == cls._UNIQUE_VIOLATION
        return "unique" in str(error).lower()

    @staticmethod
    def _normalize_uuid(value) -> Optional[str]:
        """Return *value* as a canonical UUID string, or None if it is not a UUID.

        ``models.id`` is a UUID column. A malformed id can never match a
        row, and passing it to PostgreSQL would raise
        invalid_text_representation instead of simply matching nothing.
        """
        try:
            return str(uuid.UUID(str(value)))
        except ValueError:
            return None

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE/ILIKE metacharacters so *text* is matched literally.

        PostgreSQL's default LIKE escape character is backslash, so the
        backslash itself is escaped first, then ``%`` and ``_``.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _build_filter(user_id: str, include_cloud: bool, include_inactive: bool):
        """Build the ownership/activity WHERE fragment shared by the list queries.

        Args:
            user_id: owner whose models are selected.
            include_cloud: also select rows with ``is_cloud = true`` from any owner.
            include_inactive: when False, only ``is_active = true`` rows are selected.

        Returns:
            ``(where_clause, params)``. The clause uses ``%s`` placeholders and
            params is a new list that callers may extend.
        """
        if include_cloud:
            conditions = ["(user_id = %s OR is_cloud = true)"]
        else:
            conditions = ["user_id = %s"]
        if not include_inactive:
            conditions.append("is_active = true")
        return " AND ".join(conditions), [user_id]

    def create_model(self, model_number: str, model_image: str, user_id: str = "default",
                     is_cloud: bool = False) -> Optional[Model]:
        """Insert a new model row and return it.

        Returns:
            The created Model, or None if the INSERT returned no row.

        Raises:
            ValueError: if (model_number, user_id) already exists.
            Exception: any other database error (logged and re-raised).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute("""
                        INSERT INTO models (model_number, model_image, user_id, is_cloud)
                        VALUES (%s, %s, %s, %s)
                        RETURNING *
                    """, (model_number, model_image, user_id, is_cloud))

                    row = cursor.fetchone()
                    conn.commit()

                    if row:
                        model = self._row_to_model(row)
                        logger.info(f"创建模特成功: {model.model_number}")
                        return model

        except psycopg2.IntegrityError as e:
            if self._is_unique_violation(e):
                logger.warning(f"模特编号已存在: {model_number}")
                raise ValueError(f"模特编号 '{model_number}' 已存在") from e
            logger.error(f"创建模特失败: {e}")
            raise
        except Exception as e:
            logger.error(f"创建模特失败: {e}")
            raise

        return None

    def get_model_by_id(self, model_id: str) -> Optional[Model]:
        """Return the model with primary key *model_id*.

        Returns None if no such row exists or if *model_id* is not a valid UUID.
        """
        normalized_id = self._normalize_uuid(model_id)
        if normalized_id is None:
            return None

        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute("""
                        SELECT * FROM models WHERE id = %s
                    """, (normalized_id,))

                    row = cursor.fetchone()
                    if row:
                        return self._row_to_model(row)

        except Exception as e:
            logger.error(f"获取模特失败: {e}")
            raise

        return None

    def get_model_by_number(self, model_number: str, user_id: str = "default") -> Optional[Model]:
        """Return *user_id*'s model whose model_number equals *model_number*, or None."""
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute("""
                        SELECT * FROM models
                        WHERE model_number = %s AND user_id = %s
                    """, (model_number, user_id))

                    row = cursor.fetchone()
                    if row:
                        return self._row_to_model(row)

        except Exception as e:
            logger.error(f"获取模特失败: {e}")
            raise

        return None

    def get_all_models(self, user_id: str = "default", include_cloud: bool = True,
                       include_inactive: bool = False, limit: int = 100,
                       offset: int = 0) -> List[Model]:
        """Return a page of models visible to *user_id*, newest first.

        Args:
            user_id: owner whose models are listed.
            include_cloud: also list cloud models (``is_cloud = true``) of any owner.
            include_inactive: also list disabled models.
            limit: maximum number of rows.
            offset: number of rows to skip.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    where_clause, params = self._build_filter(
                        user_id, include_cloud, include_inactive)

                    # where_clause contains only fixed fragments from
                    # _build_filter; all values are bound parameters.
                    cursor.execute(f"""
                        SELECT * FROM models
                        WHERE {where_clause}
                        ORDER BY created_at DESC
                        LIMIT %s OFFSET %s
                    """, params + [limit, offset])

                    rows = cursor.fetchall()
                    return [self._row_to_model(row) for row in rows]

        except Exception as e:
            logger.error(f"获取模特列表失败: {e}")
            raise

    def update_model(self, model_id: str, updates: Dict[str, Any]) -> bool:
        """Update allowed columns of the model with id *model_id*.

        Only keys in ``_UPDATABLE_FIELDS`` are applied; other keys are ignored.

        Returns:
            True if a row was updated or *updates* is empty. False if no
            allowed field was given, the id is not a valid UUID, or no row
            matched.

        Raises:
            ValueError: if the new model_number collides with an existing one.
            Exception: any other database error (logged and re-raised).
        """
        try:
            if not updates:
                return True

            filtered_updates = {k: v for k, v in updates.items() if k in self._UPDATABLE_FIELDS}

            if not filtered_updates:
                logger.warning("没有有效的更新字段")
                return False

            normalized_id = self._normalize_uuid(model_id)
            if normalized_id is None:
                logger.warning(f"模特不存在: {model_id}")
                return False

            # Column names come only from the allow-list, so interpolating
            # them is safe; values are still passed as bound parameters.
            set_clause = ", ".join(f"{field} = %s" for field in filtered_updates)
            params = [*filtered_updates.values(), normalized_id]

            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(f"""
                        UPDATE models
                        SET {set_clause}
                        WHERE id = %s
                    """, params)

                    affected_rows = cursor.rowcount
                    conn.commit()

                    if affected_rows > 0:
                        logger.info(f"更新模特成功: {model_id}")
                        return True
                    logger.warning(f"模特不存在: {model_id}")
                    return False

        except psycopg2.IntegrityError as e:
            if self._is_unique_violation(e):
                logger.warning("模特编号已存在")
                raise ValueError("模特编号已存在") from e
            logger.error(f"更新模特失败: {e}")
            raise
        except Exception as e:
            logger.error(f"更新模特失败: {e}")
            raise

    def delete_model(self, model_id: str, hard_delete: bool = False) -> bool:
        """Delete (hard_delete=True) or disable (default) the model *model_id*.

        Returns:
            True if a row was affected. False if the id is not a valid UUID
            or no row matched.
        """
        normalized_id = self._normalize_uuid(model_id)
        if normalized_id is None:
            logger.warning(f"模特不存在: {model_id}")
            return False

        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    if hard_delete:
                        # Hard delete: remove the row.
                        cursor.execute("""
                            DELETE FROM models WHERE id = %s
                        """, (normalized_id,))
                    else:
                        # Soft delete: only mark the row inactive.
                        cursor.execute("""
                            UPDATE models SET is_active = false WHERE id = %s
                        """, (normalized_id,))

                    affected_rows = cursor.rowcount
                    conn.commit()

                    if affected_rows > 0:
                        action = "删除" if hard_delete else "禁用"
                        logger.info(f"{action}模特成功: {model_id}")
                        return True
                    logger.warning(f"模特不存在: {model_id}")
                    return False

        except Exception as e:
            logger.error(f"删除模特失败: {e}")
            raise

    def search_models(self, query: str, user_id: str = "default",
                      include_cloud: bool = True, limit: int = 50) -> List[Model]:
        """Case-insensitively search active models whose model_number contains *query*.

        *query* is matched literally: ``%``, ``_`` and ``\\`` carry no
        wildcard meaning. Results whose model_number starts with *query*
        come first, then newest first.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    escaped = self._escape_like(str(query))
                    contains_pattern = f"%{escaped}%"
                    prefix_pattern = f"{escaped}%"

                    # Search results never include disabled models.
                    where_clause, params = self._build_filter(
                        user_id, include_cloud, include_inactive=False)

                    cursor.execute(f"""
                        SELECT * FROM models
                        WHERE model_number ILIKE %s
                          AND {where_clause}
                        ORDER BY
                          CASE WHEN model_number ILIKE %s THEN 1 ELSE 2 END,
                          created_at DESC
                        LIMIT %s
                    """, [contains_pattern] + params + [prefix_pattern, limit])

                    rows = cursor.fetchall()
                    return [self._row_to_model(row) for row in rows]

        except Exception as e:
            logger.error(f"搜索模特失败: {e}")
            raise

    def get_model_count(self, user_id: str = "default", include_cloud: bool = True,
                        include_inactive: bool = False) -> int:
        """Count models using the same visibility rules as get_all_models()."""
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    where_clause, params = self._build_filter(
                        user_id, include_cloud, include_inactive)

                    cursor.execute(f"""
                        SELECT COUNT(*) FROM models WHERE {where_clause}
                    """, params)

                    result = cursor.fetchone()
                    return result[0] if result else 0

        except Exception as e:
            logger.error(f"获取模特数量失败: {e}")
            raise

    def toggle_model_status(self, model_id: str) -> bool:
        """Flip is_active for the model *model_id*.

        Returns:
            True if a row was updated. False if the id is not a valid UUID
            or no row matched.
        """
        normalized_id = self._normalize_uuid(model_id)
        if normalized_id is None:
            logger.warning(f"模特不存在: {model_id}")
            return False

        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("""
                        UPDATE models
                        SET is_active = NOT is_active
                        WHERE id = %s
                    """, (normalized_id,))

                    affected_rows = cursor.rowcount
                    conn.commit()

                    if affected_rows > 0:
                        logger.info(f"切换模特状态成功: {model_id}")
                        return True
                    logger.warning(f"模特不存在: {model_id}")
                    return False

        except Exception as e:
            logger.error(f"切换模特状态失败: {e}")
            raise
||
|
||
|
||
# Create the module-level shared instance.
# NOTE: this runs ModelTablePostgres.__init__ at import time, which connects
# to the database and creates the table/indexes/trigger if needed. Any
# connection or DDL error is re-raised, so importing this module fails when
# the database is unreachable.
model_table = ModelTablePostgres()
|