# User table - PostgreSQL version
|
||
|
||
import hashlib
import hmac
import uuid
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, List, Any, Optional

from python_core.config import settings
from python_core.utils.logger import setup_logger
|
||
|
||
# Try to import psycopg2. On failure, remember the error so that
# UserTablePostgres.__init__ can raise a friendly installation hint instead
# of the module failing at import time.
try:
    import psycopg2
    import psycopg2.extras
except ImportError as e:
    PSYCOPG2_AVAILABLE = False
    PSYCOPG2_ERROR = str(e)
else:
    PSYCOPG2_AVAILABLE = True
    # Always bound, so readers of PSYCOPG2_ERROR never hit a NameError.
    PSYCOPG2_ERROR = ""
|
||
|
||
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
|
||
|
||
class UserTablePostgres:
    """
    User table - PostgreSQL version.

    Provides user management (create, look up, authenticate, update, delete,
    list and search) on top of a PostgreSQL ``users`` table accessed through
    psycopg2. Every operation opens its own connection via
    ``_get_connection`` and closes it when done.
    """

    # Fields update_user() silently drops from the requested changes.
    _PROTECTED_FIELDS = frozenset({"id", "created_at", "password_hash"})

    # Columns update_user() is allowed to write. Column names cannot be sent
    # as query parameters, so they are checked against this fixed set before
    # being placed into the SQL text.
    _UPDATABLE_FIELDS = frozenset({
        "username",
        "email",
        "display_name",
        "avatar_url",
        "is_active",
        "last_login",
        "updated_at",
    })

    # Fixed single-user lookup statements, keyed by the column searched.
    _LOOKUP_QUERIES = {
        "id": "SELECT * FROM users WHERE id = %s",
        "username": "SELECT * FROM users WHERE username = %s",
        "email": "SELECT * FROM users WHERE email = %s",
    }

    def __init__(self):
        """
        Store the connection settings and make sure the users table exists.

        Raises:
            ImportError: if psycopg2 could not be imported.
            Exception: whatever ``_init_user_table`` raises when the table
                cannot be created or migrated.
        """
        # Fail early with installation instructions when the driver is missing.
        if not PSYCOPG2_AVAILABLE:
            error_msg = "\n".join([
                "",
                "PostgreSQL 驱动 psycopg2 未安装。请安装:",
                "",
                "方案1(推荐):",
                "pip install psycopg2-binary",
                "",
                "方案2:",
                "# Ubuntu/Debian",
                "sudo apt-get install postgresql libpq-dev python3-dev",
                "pip install psycopg2",
                "",
                "# CentOS/RHEL",
                "sudo yum install postgresql postgresql-devel python3-devel",
                "pip install psycopg2",
                "",
                "# macOS",
                "brew install postgresql",
                "pip install psycopg2",
                "",
                f"原始错误: {PSYCOPG2_ERROR}",
                "",
            ])
            logger.error(error_msg)
            raise ImportError(error_msg)

        self.db_url = settings.db
        self.table_name = "users"

        # Create / migrate the users table.
        self._init_user_table()

    @contextmanager
    def _get_connection(self):
        """
        Context manager yielding a new psycopg2 connection.

        The connection is always closed on exit. If connecting or the
        ``with`` body raises, the transaction is rolled back, the error is
        logged and the exception propagates. Committing is left to the caller.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.db_url)
            yield conn
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A failed rollback must not replace the original error.
                    logger.error(f"Rollback failed: {rollback_error}")
            logger.error(f"Database connection error: {e}")
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_user_table(self):
        """
        Create the users table and its indexes if missing, then migrate it.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Base table without the later-added columns; those are
                    # handled by _migrate_user_table().
                    cursor.execute("""
                        CREATE TABLE IF NOT EXISTS users (
                            id VARCHAR(36) PRIMARY KEY,
                            username VARCHAR(50) UNIQUE NOT NULL,
                            email VARCHAR(100) UNIQUE NOT NULL,
                            password_hash VARCHAR(64) NOT NULL,
                            is_active BOOLEAN DEFAULT TRUE,
                            last_login TIMESTAMP,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        );
                    """)

                    # Add any columns missing from an older table.
                    self._migrate_user_table(cursor)

                    # NOTE(review): username and email already get indexes
                    # from their UNIQUE constraints, so the first two are
                    # likely redundant; kept so existing deployments are
                    # unchanged.
                    index_statements = (
                        "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);",
                        "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);",
                        "CREATE INDEX IF NOT EXISTS idx_users_is_active ON users(is_active);",
                        "CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at);",
                    )
                    for statement in index_statements:
                        cursor.execute(statement)

                    conn.commit()
                    logger.info("User table initialized")

        except Exception as e:
            logger.error(f"Failed to initialize user table: {e}")
            raise

    @staticmethod
    def _column_exists(cursor, column_name: str) -> bool:
        """Return True if the users table in the current schema has the column."""
        cursor.execute(
            """
            SELECT 1
            FROM information_schema.columns
            WHERE table_schema = current_schema()
              AND table_name = 'users'
              AND column_name = %s
            """,
            (column_name,),
        )
        return cursor.fetchone() is not None

    def _migrate_user_table(self, cursor):
        """
        Add columns that are missing from an existing users table.

        Args:
            cursor: open cursor; the caller owns the transaction and commits.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        try:
            if not self._column_exists(cursor, "display_name"):
                logger.info("Adding display_name column to users table")
                cursor.execute("""
                    ALTER TABLE users
                    ADD COLUMN display_name VARCHAR(100) DEFAULT ''
                """)
                # Backfill existing rows with the username ...
                cursor.execute("""
                    UPDATE users
                    SET display_name = username
                    WHERE display_name = '' OR display_name IS NULL
                """)
                # ... so the column can then be made NOT NULL.
                cursor.execute("""
                    ALTER TABLE users
                    ALTER COLUMN display_name SET NOT NULL
                """)

            if not self._column_exists(cursor, "avatar_url"):
                logger.info("Adding avatar_url column to users table")
                cursor.execute("""
                    ALTER TABLE users
                    ADD COLUMN avatar_url TEXT DEFAULT ''
                """)

            if not self._column_exists(cursor, "last_login"):
                logger.info("Adding last_login column to users table")
                cursor.execute("""
                    ALTER TABLE users
                    ADD COLUMN last_login TIMESTAMP
                """)

        except Exception as e:
            logger.error(f"Failed to migrate user table: {e}")
            raise

    def _hash_password(self, password: str) -> str:
        """
        Return the hex SHA-256 digest of the password.

        NOTE(review): this is an unsalted, fast hash. A salted KDF would be
        stronger, but would change the stored format (the column is
        VARCHAR(64)) and invalidate existing hashes, so it is left as is.
        """
        return hashlib.sha256(password.encode()).hexdigest()

    def _verify_password(self, password: str, password_hash: str) -> bool:
        """Return True if the password hashes to ``password_hash``."""
        if not isinstance(password_hash, str):
            return False
        # Constant-time comparison; bytes avoid compare_digest's ASCII-only
        # restriction on str arguments.
        return hmac.compare_digest(
            self._hash_password(password).encode(),
            password_hash.encode(),
        )

    @staticmethod
    def _row_to_user(row, include_password_hash: bool = False) -> Dict[str, Any]:
        """
        Convert a dict-like users row into the dict returned to callers.

        Timestamps are rendered with ``isoformat()`` (or None when unset).
        ``password_hash`` is only included when explicitly requested.
        """
        def iso(value):
            return value.isoformat() if value else None

        user = {
            'id': row['id'],
            'username': row['username'],
            'email': row['email'],
        }
        if include_password_hash:
            user['password_hash'] = row['password_hash']
        user['display_name'] = row.get('display_name', row['username'])
        user['avatar_url'] = row.get('avatar_url', '')
        user['is_active'] = row['is_active']
        user['last_login'] = iso(row.get('last_login'))
        user['created_at'] = iso(row.get('created_at'))
        user['updated_at'] = iso(row.get('updated_at'))
        return user

    def _fetch_user(self, lookup: str, value: str) -> Optional[Dict[str, Any]]:
        """
        Fetch one user by a key of ``_LOOKUP_QUERIES``; errors propagate.

        Returns:
            The user dict including ``password_hash``, or None if no row matches.
        """
        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                cursor.execute(self._LOOKUP_QUERIES[lookup], (value,))
                row = cursor.fetchone()
        if not row:
            return None
        return self._row_to_user(row, include_password_hash=True)

    @classmethod
    def _split_updates(cls, updates: Dict[str, Any]) -> tuple:
        """
        Split requested updates into ``(accepted, rejected)``.

        Protected fields are dropped silently; ``accepted`` maps updatable
        columns to their new values; ``rejected`` lists every other key.
        """
        accepted = {}
        rejected = []
        for field, value in updates.items():
            if field in cls._PROTECTED_FIELDS:
                continue
            if field in cls._UPDATABLE_FIELDS:
                accepted[field] = value
            else:
                rejected.append(field)
        return accepted, rejected

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE/ILIKE metacharacters so ``text`` matches literally."""
        # Backslash is PostgreSQL's default LIKE escape character, so it has
        # to be escaped first.
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def create_user(self, username: str, email: str, password: str, display_name: Optional[str] = None) -> str:
        """
        Create a user.

        Args:
            username: user name (must be unique)
            email: email address (must be unique)
            password: plain-text password; only its hash is stored
            display_name: display name; falls back to ``username`` when empty

        Returns:
            The new user's ID (a UUID4 string).

        Raises:
            ValueError: if the username or email already exists.
            Exception: any database error is logged and re-raised.
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    # Reject duplicates with a readable error before inserting.
                    cursor.execute("SELECT id FROM users WHERE username = %s", (username,))
                    if cursor.fetchone():
                        raise ValueError(f"Username '{username}' already exists")

                    cursor.execute("SELECT id FROM users WHERE email = %s", (email,))
                    if cursor.fetchone():
                        raise ValueError(f"Email '{email}' already exists")

                    user_id = str(uuid.uuid4())
                    now = datetime.now()
                    cursor.execute(
                        """
                        INSERT INTO users (id, username, email, password_hash, display_name, avatar_url, is_active, created_at, updated_at)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        """,
                        (
                            user_id,
                            username,
                            email,
                            self._hash_password(password),
                            display_name or username,
                            '',
                            True,
                            now,
                            now,
                        ),
                    )

                    conn.commit()
                    logger.info(f"Created user: {username} (ID: {user_id})")
                    return user_id

        except Exception as e:
            logger.error(f"Failed to create user '{username}': {e}")
            raise

    def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a user by ID.

        Args:
            user_id: user ID

        Returns:
            The user dict (including ``password_hash``), or None if the user
            does not exist or the lookup failed (failures are logged).
        """
        try:
            return self._fetch_user("id", user_id)
        except Exception as e:
            logger.error(f"Failed to get user by ID '{user_id}': {e}")
            return None

    def get_user_by_username(self, username: str) -> Optional[Dict[str, Any]]:
        """
        Get a user by user name.

        Args:
            username: user name

        Returns:
            The user dict (including ``password_hash``), or None if the user
            does not exist or the lookup failed (failures are logged).
        """
        try:
            return self._fetch_user("username", username)
        except Exception as e:
            logger.error(f"Failed to get user by username '{username}': {e}")
            return None

    def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """
        Get a user by email address.

        Args:
            email: email address

        Returns:
            The user dict (including ``password_hash``), or None if the user
            does not exist or the lookup failed (failures are logged).
        """
        try:
            return self._fetch_user("email", email)
        except Exception as e:
            logger.error(f"Failed to get user by email '{email}': {e}")
            return None

    def authenticate_user(self, username_or_email: str, password: str) -> Optional[Dict[str, Any]]:
        """
        Authenticate a user.

        Args:
            username_or_email: user name or email address
            password: plain-text password

        Returns:
            The user dict without ``password_hash`` on success, otherwise
            None. The dict is the one read before ``last_login`` is updated.
        """
        try:
            # Look up by user name first, then fall back to email.
            user = self.get_user_by_username(username_or_email)
            if not user:
                user = self.get_user_by_email(username_or_email)

            # NOTE(review): is_active is not checked here, so deactivated
            # users still authenticate; confirm whether callers rely on the
            # returned is_active flag instead.
            if user and self._verify_password(password, user["password_hash"]):
                self.update_last_login(user["id"])

                # Never hand the password hash back to the caller.
                user_info = {k: v for k, v in user.items() if k != "password_hash"}

                logger.info(f"User authenticated: {user['username']}")
                return user_info

            logger.warning(f"Authentication failed for: {username_or_email}")
            return None

        except Exception as e:
            logger.error(f"Authentication error for '{username_or_email}': {e}")
            return None

    def update_user(self, user_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update user fields.

        ``id``, ``created_at`` and ``password_hash`` are ignored. Any key
        that is not a known updatable column makes the whole update fail,
        because column names are placed into the SQL text and must never
        come from unchecked input. ``updated_at`` is always set to the
        current time.

        Args:
            user_id: user ID
            updates: mapping of column name to new value

        Returns:
            True if a row was updated; False if nothing valid was requested,
            the user does not exist, or an error occurred (logged).
        """
        try:
            changes, rejected = self._split_updates(updates)
            if rejected:
                logger.warning(f"Refusing to update unknown user fields: {rejected!r}")
                return False
            if not changes:
                logger.warning("No valid fields to update")
                return False

            changes["updated_at"] = datetime.now()

            # Only allow-listed column names reach the SQL text; all values
            # are passed as parameters.
            assignments = ", ".join(f"{column} = %s" for column in changes)
            update_sql = f"UPDATE users SET {assignments} WHERE id = %s"
            params = [*changes.values(), user_id]

            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(update_sql, params)
                    conn.commit()
                    updated = cursor.rowcount > 0

            if updated:
                logger.info(f"Updated user: {user_id}")
                return True

            # No matching row means the user does not exist.
            logger.warning(f"User not found: {user_id}")
            return False

        except Exception as e:
            logger.error(f"Failed to update user '{user_id}': {e}")
            return False

    def update_password(self, user_id: str, new_password: str) -> bool:
        """
        Update a user's password.

        Args:
            user_id: user ID
            new_password: new plain-text password; only its hash is stored

        Returns:
            True if the password was updated; False if the user does not
            exist or an error occurred (logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute(
                        "UPDATE users SET password_hash = %s, updated_at = %s WHERE id = %s",
                        (self._hash_password(new_password), datetime.now(), user_id),
                    )
                    conn.commit()
                    updated = cursor.rowcount > 0

            if updated:
                logger.info(f"Password updated for user: {user_id}")
                return True

            # No matching row means the user does not exist.
            logger.warning(f"User not found: {user_id}")
            return False

        except Exception as e:
            logger.error(f"Failed to update password for user '{user_id}': {e}")
            return False

    def update_last_login(self, user_id: str) -> bool:
        """
        Set the user's last login time to now.

        Args:
            user_id: user ID

        Returns:
            True on success (failures are logged by ``update_user``).
        """
        return self.update_user(user_id, {"last_login": datetime.now()})

    def deactivate_user(self, user_id: str) -> bool:
        """
        Deactivate a user.

        Args:
            user_id: user ID

        Returns:
            True on success (failures are logged by ``update_user``).
        """
        return self.update_user(user_id, {"is_active": False})

    def activate_user(self, user_id: str) -> bool:
        """
        Activate a user.

        Args:
            user_id: user ID

        Returns:
            True on success (failures are logged by ``update_user``).
        """
        return self.update_user(user_id, {"is_active": True})

    def delete_user(self, user_id: str) -> bool:
        """
        Delete a user.

        Args:
            user_id: user ID

        Returns:
            True if a row was deleted; False if no such user exists or an
            error occurred (logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
                    conn.commit()
                    deleted = cursor.rowcount > 0

            if deleted:
                logger.info(f"Deleted user: {user_id}")
                return True

            logger.warning(f"No user found to delete: {user_id}")
            return False

        except Exception as e:
            logger.error(f"Failed to delete user '{user_id}': {e}")
            return False

    def get_all_users(self, include_inactive: bool = False, limit: int = 100) -> List[Dict[str, Any]]:
        """
        List users, newest first.

        Args:
            include_inactive: whether to include inactive users
            limit: maximum number of users returned

        Returns:
            User dicts without ``password_hash``; an empty list on error
            (logged).
        """
        try:
            if include_inactive:
                sql = "SELECT * FROM users ORDER BY created_at DESC LIMIT %s"
            else:
                sql = "SELECT * FROM users WHERE is_active = TRUE ORDER BY created_at DESC LIMIT %s"

            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute(sql, (limit,))
                    rows = cursor.fetchall()

            # The password hash is deliberately left out.
            return [self._row_to_user(row) for row in rows]

        except Exception as e:
            logger.error(f"Failed to get all users: {e}")
            return []

    def get_user_count(self, include_inactive: bool = False) -> int:
        """
        Count users.

        Args:
            include_inactive: whether to include inactive users

        Returns:
            Number of users; 0 on error (logged).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cursor:
                    if include_inactive:
                        cursor.execute("SELECT COUNT(*) FROM users")
                    else:
                        cursor.execute("SELECT COUNT(*) FROM users WHERE is_active = TRUE")

                    result = cursor.fetchone()
                    return result[0] if result else 0

        except Exception as e:
            logger.error(f"Failed to get user count: {e}")
            return 0

    def search_users(self, query: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        Search users, newest first.

        Args:
            query: keyword matched case-insensitively as a literal substring
                of user name, email or display name
            limit: maximum number of users returned

        Returns:
            Matching user dicts without ``password_hash``; an empty list on
            error (logged).
        """
        try:
            # Escape % and _ so they match literally instead of as wildcards.
            search_pattern = f"%{self._escape_like(query)}%"
            sql = """
                SELECT * FROM users
                WHERE (username ILIKE %s OR email ILIKE %s OR display_name ILIKE %s)
                ORDER BY created_at DESC
                LIMIT %s
            """

            with self._get_connection() as conn:
                with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
                    cursor.execute(sql, (search_pattern, search_pattern, search_pattern, limit))
                    rows = cursor.fetchall()

            # The password hash is deliberately left out.
            return [self._row_to_user(row) for row in rows]

        except Exception as e:
            logger.error(f"Failed to search users with query '{query}': {e}")
            return []
|
||
|
||
|
||
# Create the global user table instance
|
||
# NOTE: constructing the instance here runs __init__ at import time, which
# calls _init_user_table() and therefore connects to the database.
user_table = UserTablePostgres()
|