294 lines
8.4 KiB
Python
294 lines
8.4 KiB
Python
"""
|
||
用户认证服务
|
||
基于UserTable实现的用户认证功能
|
||
"""
|
||
|
||
import re
|
||
from typing import Optional, Dict, Any
|
||
from dataclasses import dataclass
|
||
|
||
from python_core.database.user import user_table
|
||
from python_core.utils.jwt_auth import generate_access_token, verify_access_token
|
||
from python_core.utils.logger import setup_logger
|
||
|
||
logger = setup_logger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class LoginRequest:
|
||
"""登录请求数据"""
|
||
username_or_email: str
|
||
password: str
|
||
|
||
|
||
@dataclass
|
||
class RegisterRequest:
|
||
"""注册请求数据"""
|
||
username: str
|
||
email: str
|
||
password: str
|
||
display_name: str = None
|
||
|
||
|
||
@dataclass
|
||
class AuthResponse:
|
||
"""认证响应数据"""
|
||
success: bool
|
||
message: str = ""
|
||
user: Dict[str, Any] = None
|
||
token: str = None
|
||
expires_at: str = None
|
||
|
||
|
||
|
||
|
||
|
||
class AuthService:
|
||
"""用户认证服务"""
|
||
|
||
def __init__(self):
|
||
self.user_table = user_table
|
||
logger.info("AuthService initialized")
|
||
|
||
def register(self, request: RegisterRequest) -> AuthResponse:
|
||
"""
|
||
用户注册
|
||
|
||
Args:
|
||
request: 注册请求数据
|
||
|
||
Returns:
|
||
AuthResponse: 注册响应
|
||
"""
|
||
try:
|
||
# 验证输入数据
|
||
validation_error = self._validate_register_request(request)
|
||
if validation_error:
|
||
return AuthResponse(
|
||
success=False,
|
||
message=validation_error
|
||
)
|
||
|
||
# 创建用户
|
||
user_id = self.user_table.create_user(
|
||
username=request.username,
|
||
email=request.email,
|
||
password=request.password,
|
||
display_name=request.display_name
|
||
)
|
||
|
||
# 获取创建的用户信息
|
||
user = self.user_table.get_user_by_id(user_id)
|
||
if not user:
|
||
return AuthResponse(
|
||
success=False,
|
||
message="用户创建失败"
|
||
)
|
||
|
||
# 生成JWT token
|
||
token_info = generate_access_token(
|
||
user_id=user["id"],
|
||
username=user["username"],
|
||
email=user["email"]
|
||
)
|
||
|
||
logger.info(f"User registered successfully: {user['username']}")
|
||
|
||
# 准备安全的用户信息(不包含密码哈希)
|
||
safe_user = {k: v for k, v in user.items() if k != "password_hash"}
|
||
|
||
return AuthResponse(
|
||
success=True,
|
||
message="注册成功",
|
||
user=safe_user,
|
||
token=token_info["token"],
|
||
expires_at=token_info["expires_at"]
|
||
)
|
||
|
||
except ValueError as e:
|
||
logger.warning(f"Registration failed: {e}")
|
||
return AuthResponse(
|
||
success=False,
|
||
message=str(e)
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Registration error: {e}")
|
||
return AuthResponse(
|
||
success=False,
|
||
message="注册失败,请稍后重试"
|
||
)
|
||
|
||
def login(self, request: LoginRequest) -> AuthResponse:
|
||
"""
|
||
用户登录
|
||
|
||
Args:
|
||
request: 登录请求数据
|
||
|
||
Returns:
|
||
AuthResponse: 登录响应
|
||
"""
|
||
try:
|
||
# 验证输入数据
|
||
validation_error = self._validate_login_request(request)
|
||
if validation_error:
|
||
return AuthResponse(
|
||
success=False,
|
||
message=validation_error
|
||
)
|
||
|
||
# 使用UserTable的认证方法
|
||
user = self.user_table.authenticate_user(request.username_or_email, request.password)
|
||
if not user:
|
||
return AuthResponse(
|
||
success=False,
|
||
message="用户名或密码错误"
|
||
)
|
||
|
||
# 检查用户状态
|
||
if not user.get("is_active", True):
|
||
return AuthResponse(
|
||
success=False,
|
||
message="账户已被禁用"
|
||
)
|
||
|
||
# 生成JWT token
|
||
token_info = generate_access_token(
|
||
user_id=user["id"],
|
||
username=user["username"],
|
||
email=user["email"]
|
||
)
|
||
|
||
logger.info(f"User logged in successfully: {user['username']}")
|
||
|
||
# 准备安全的用户信息(不包含密码哈希)
|
||
safe_user = {k: v for k, v in user.items() if k != "password_hash"}
|
||
|
||
return AuthResponse(
|
||
success=True,
|
||
message="登录成功",
|
||
user=safe_user,
|
||
token=token_info["token"],
|
||
expires_at=token_info["expires_at"]
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Login error: {e}")
|
||
return AuthResponse(
|
||
success=False,
|
||
message="登录失败,请稍后重试"
|
||
)
|
||
|
||
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
验证JWT token
|
||
|
||
Args:
|
||
token: JWT token字符串
|
||
|
||
Returns:
|
||
Dict: 用户信息,验证失败返回None
|
||
"""
|
||
try:
|
||
# 验证token
|
||
payload = verify_access_token(token)
|
||
if not payload:
|
||
return None
|
||
|
||
# 获取用户信息
|
||
user = self.user_table.get_user_by_id(payload["user_id"])
|
||
if not user or not user.get("is_active", True):
|
||
return None
|
||
|
||
return {
|
||
"user_id": user["id"],
|
||
"username": user["username"],
|
||
"email": user["email"],
|
||
"display_name": user["display_name"],
|
||
"avatar_url": user.get("avatar_url", ""),
|
||
"last_login": user.get("last_login")
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"Token verification error: {e}")
|
||
return None
|
||
|
||
def get_current_user(self, token: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
根据token获取当前用户
|
||
|
||
Args:
|
||
token: JWT token字符串
|
||
|
||
Returns:
|
||
Dict: 用户信息字典,失败返回None
|
||
"""
|
||
try:
|
||
payload = verify_access_token(token)
|
||
if not payload:
|
||
return None
|
||
|
||
user = self.user_table.get_user_by_id(payload["user_id"])
|
||
if not user or not user.get("is_active", True):
|
||
return None
|
||
|
||
# 返回安全的用户信息(不包含密码哈希)
|
||
safe_user = {k: v for k, v in user.items() if k != "password_hash"}
|
||
return safe_user
|
||
|
||
except Exception as e:
|
||
logger.error(f"Get current user error: {e}")
|
||
return None
|
||
|
||
def _validate_register_request(self, request: RegisterRequest) -> Optional[str]:
|
||
"""验证注册请求数据"""
|
||
|
||
# 验证用户名
|
||
if not request.username or len(request.username.strip()) < 3:
|
||
return "用户名至少需要3个字符"
|
||
|
||
if len(request.username) > 50:
|
||
return "用户名不能超过50个字符"
|
||
|
||
if not re.match(r'^[a-zA-Z0-9_]+$', request.username):
|
||
return "用户名只能包含字母、数字和下划线"
|
||
|
||
# 验证邮箱
|
||
if not request.email:
|
||
return "邮箱地址不能为空"
|
||
|
||
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||
if not re.match(email_pattern, request.email):
|
||
return "邮箱地址格式不正确"
|
||
|
||
# 验证密码
|
||
if not request.password:
|
||
return "密码不能为空"
|
||
|
||
if len(request.password) < 6:
|
||
return "密码至少需要6个字符"
|
||
|
||
if len(request.password) > 100:
|
||
return "密码不能超过100个字符"
|
||
|
||
# 验证显示名称
|
||
if request.display_name and len(request.display_name) > 100:
|
||
return "显示名称不能超过100个字符"
|
||
|
||
return None
|
||
|
||
def _validate_login_request(self, request: LoginRequest) -> Optional[str]:
|
||
"""验证登录请求数据"""
|
||
|
||
if not request.username_or_email:
|
||
return "用户名或邮箱不能为空"
|
||
|
||
if not request.password:
|
||
return "密码不能为空"
|
||
|
||
return None
|
||
|
||
|
||
# 创建全局认证服务实例
|
||
auth_service = AuthService()
|