283 lines
9.6 KiB
Python
283 lines
9.6 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
JSON-RPC Server Implementation
|
||
JSON-RPC 服务器实现
|
||
|
||
Provides HTTP and WebSocket JSON-RPC server implementations.
|
||
"""
|
||
|
||
import json
|
||
import asyncio
|
||
from typing import Any, Dict, Optional, Callable, List
|
||
from dataclasses import dataclass
|
||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||
import threading
|
||
import logging
|
||
|
||
try:
|
||
from jsonrpc import JSONRPCResponseManager, dispatcher
|
||
from jsonrpc.exceptions import JSONRPCError
|
||
JSON_RPC_AVAILABLE = True
|
||
except ImportError:
|
||
JSON_RPC_AVAILABLE = False
|
||
|
||
from .jsonrpc_enhanced import JSONRPCMethodRegistry, EnhancedProgressReporter
|
||
|
||
|
||
@dataclass
class ServerConfig:
    """Settings shared by the HTTP and WebSocket JSON-RPC servers."""

    # Host name / interface the listening socket is bound to.
    host: str = "localhost"
    # TCP port to listen on.
    port: int = 8080
    # Enables request logging and exception tracebacks in the HTTP handler.
    debug: bool = False
    # Whether HTTP responses carry permissive (Allow-Origin: *) CORS headers.
    cors_enabled: bool = True
    # Largest HTTP request body accepted, in bytes (1 MiB).
    max_request_size: int = 1 << 20
|
||
|
||
|
||
class JSONRPCHTTPHandler(BaseHTTPRequestHandler):
    """HTTP request handler that forwards JSON-RPC POST bodies to a method registry.

    The registry and server config are bound in by a factory, because
    ``http.server`` instantiates one handler per incoming request.
    """

    def __init__(self, method_registry: "JSONRPCMethodRegistry", config: "ServerConfig", *args, **kwargs):
        # Assign before calling super().__init__(): BaseHTTPRequestHandler processes
        # the whole request from inside its constructor.
        self.method_registry = method_registry
        self.config = config
        super().__init__(*args, **kwargs)

    def do_POST(self):
        """Read a JSON-RPC request from the POST body and write the registry's reply.

        Client mistakes (wrong media type, bad length, non-UTF-8 body) get 4xx
        responses; any other exception is reported as a 500.
        """
        try:
            # Media types are case-insensitive, so compare in lower case.
            content_type = self.headers.get('Content-Type', '')
            if 'application/json' not in content_type.lower():
                self._send_error(400, "Content-Type must be application/json")
                return

            content_length = self._parse_content_length()
            if content_length is None:
                # A negative length would make rfile.read() wait for EOF on the socket.
                self._send_error(400, "Invalid Content-Length header")
                return
            if content_length > self.config.max_request_size:
                self._send_error(413, "Request too large")
                return

            try:
                request_data = self.rfile.read(content_length).decode('utf-8')
            except UnicodeDecodeError:
                self._send_error(400, "Request body must be UTF-8 encoded")
                return

            response_data = self.method_registry.handle_request(request_data)

            if response_data is not None:
                self._send_json_response(response_data)
            else:
                # None is treated as "no response body". Presumably the reply is
                # delivered some other way (e.g. asynchronously) — confirm against
                # JSONRPCMethodRegistry.handle_request. Acknowledge with an empty 200.
                self.send_response(200)
                self.send_header('Content-Length', '0')
                if self.config.cors_enabled:
                    self._send_cors_headers()
                self.end_headers()

        except Exception as e:
            if self.config.debug:
                logging.exception("Error handling request")
            # NOTE(review): the exception text is returned to the client even when
            # debug is off; consider hiding it in production.
            self._send_error(500, f"Internal server error: {str(e)}")

    def _parse_content_length(self) -> Optional[int]:
        """Return the declared body length: 0 if the header is absent, None if malformed or negative."""
        raw = self.headers.get('Content-Length')
        if raw is None:
            return 0
        try:
            length = int(raw)
        except (TypeError, ValueError):
            return None
        return length if length >= 0 else None

    def do_OPTIONS(self):
        """Answer a CORS preflight request, or reject OPTIONS when CORS is disabled."""
        if not self.config.cors_enabled:
            self._send_error(405, "Method not allowed")
            return
        # send_response() writes the status line, which must precede every header.
        # 204 responses carry no body, so no Content-Length is sent.
        self.send_response(204)
        self._send_cors_headers()
        self.end_headers()

    def _send_json_response(self, data: str):
        """Send *data* (an already-serialized JSON string) as a 200 response."""
        body = data.encode('utf-8')
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        if self.config.cors_enabled:
            self._send_cors_headers()
        self.end_headers()
        self.wfile.write(body)

    def _send_cors_headers(self):
        """Emit permissive CORS headers allowing POST/OPTIONS from any origin."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')

    def _send_error(self, code: int, message: str):
        """Send a JSON-RPC-shaped error body with HTTP status *code*.

        The HTTP status code is also used as the ``error.code`` value in the body.
        """
        error_response = {
            "jsonrpc": "2.0",
            "id": None,
            "error": {
                "code": code,
                "message": message
            }
        }
        body = json.dumps(error_response).encode('utf-8')
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        if self.config.cors_enabled:
            self._send_cors_headers()
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        """Forward the standard access log only when debug mode is enabled."""
        if self.config.debug:
            super().log_message(format, *args)
|
||
|
||
|
||
class JSONRPCServer:
    """JSON-RPC server over HTTP, built on ``http.server.HTTPServer``.

    It can run in the calling thread (``start(blocking=True)``) or in a
    background daemon thread (``start(blocking=False)``). Used as a context
    manager, it stops the server on exit.
    """

    def __init__(self, config: Optional[ServerConfig] = None):
        self.config = config or ServerConfig()
        self.method_registry = JSONRPCMethodRegistry()
        self.server: Optional[HTTPServer] = None
        self.server_thread: Optional[threading.Thread] = None
        self.running = False

    def register_method(self, name: Optional[str] = None):
        """Return the registry's decorator for exposing a function as a JSON-RPC method."""
        return self.method_registry.register(name)

    def register_function(self, func: Callable, name: Optional[str] = None):
        """Register *func* directly (optionally under *name*) with the method registry."""
        self.method_registry.register_function(func, name)

    def start(self, blocking: bool = True):
        """Bind the listening socket and begin serving requests.

        Args:
            blocking: If True, serve in the calling thread until shutdown or
                Ctrl+C. If False, serve from a background daemon thread and return.

        Raises:
            RuntimeError: If the server is already running.
            OSError: If the socket cannot be bound (propagated from HTTPServer).
        """
        if self.running:
            raise RuntimeError("Server is already running")

        # http.server creates one handler per request; the factory injects the
        # registry and config into each handler instance.
        def handler_factory(*args, **kwargs):
            return JSONRPCHTTPHandler(self.method_registry, self.config, *args, **kwargs)

        self.server = HTTPServer((self.config.host, self.config.port), handler_factory)
        self.running = True

        # server_address holds the port actually bound, which differs from
        # config.port when an ephemeral port (0) was requested.
        bound_port = self.server.server_address[1]
        print(f"🚀 JSON-RPC Server started on http://{self.config.host}:{bound_port}")

        if blocking:
            try:
                self.server.serve_forever()
            except KeyboardInterrupt:
                pass
            finally:
                # Release the socket on Ctrl+C and on unexpected errors alike. This
                # is a no-op if another thread already called stop().
                self.stop()
        else:
            self.server_thread = threading.Thread(
                target=self.server.serve_forever,
                name="jsonrpc-http-server",
                daemon=True,
            )
            self.server_thread.start()

    def stop(self):
        """Shut the server down, close its socket, and join the serving thread (up to 5 s).

        Safe to call more than once. Must not be called from inside a request
        handler of this server: ``shutdown()`` waits for ``serve_forever`` to return.
        """
        server = self.server  # local copy: other threads may call stop() as well
        if server is None or not self.running:
            return
        # Clear the flag first so a concurrent stop() returns early.
        self.running = False

        print("🛑 Stopping JSON-RPC Server...")
        server.shutdown()
        server.server_close()

        if self.server_thread:
            self.server_thread.join(timeout=5)

    def __enter__(self):
        """Return self; the server is not started automatically."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the server; exceptions from the with-body still propagate."""
        self.stop()
|
||
|
||
|
||
# Asynchronous WebSocket server (requires the optional third-party 'websockets' package).
try:
    import websockets
    import asyncio

    class JSONRPCWebSocketServer:
        """JSON-RPC server that exchanges messages over WebSocket connections."""

        def __init__(self, config: Optional[ServerConfig] = None):
            self.config = config or ServerConfig()
            self.method_registry = JSONRPCMethodRegistry()
            # Connections currently being served by handle_client(), in connect order.
            self.clients: List[Any] = []

        def register_method(self, name: Optional[str] = None):
            """Return the registry's decorator for exposing a function as a JSON-RPC method."""
            return self.method_registry.register(name)

        async def handle_client(self, websocket, path=None):
            """Serve one connection: answer every incoming message until the peer disconnects.

            ``path`` is optional because recent ``websockets`` releases call the
            handler with the connection only; older releases also pass the path.
            """
            self.clients.append(websocket)
            try:
                async for message in websocket:
                    try:
                        response = self.method_registry.handle_request(message)
                        # A None result means there is nothing to send back (the
                        # HTTP handler treats it the same way); send(None) would raise.
                        if response is not None:
                            await websocket.send(response)
                    except websockets.exceptions.ConnectionClosed:
                        # Peer went away; let the outer handler end the session.
                        raise
                    except Exception as e:
                        error_response = {
                            "jsonrpc": "2.0",
                            "id": None,
                            "error": {
                                "code": -32603,
                                "message": f"Internal error: {str(e)}"
                            }
                        }
                        await websocket.send(json.dumps(error_response))
            except websockets.exceptions.ConnectionClosed:
                pass
            finally:
                if websocket in self.clients:
                    self.clients.remove(websocket)

        async def broadcast(self, method: str, params: Any = None):
            """Send a JSON-RPC notification (no ``id``) to every connected client.

            Clients whose connection turns out to be closed are dropped from
            ``self.clients``.
            """
            if not self.clients:
                return

            notification = {
                "jsonrpc": "2.0",
                "method": method,
            }
            # JSON-RPC 2.0 allows "params" to be omitted, but when present it must
            # be an array or object, never null.
            if params is not None:
                notification["params"] = params
            message = json.dumps(notification)

            disconnected = []
            # Iterate over a snapshot: handle_client() may remove entries from
            # self.clients while this coroutine is suspended in send().
            for client in list(self.clients):
                try:
                    await client.send(message)
                except websockets.exceptions.ConnectionClosed:
                    disconnected.append(client)

            # handle_client() may already have removed some of these.
            for client in disconnected:
                if client in self.clients:
                    self.clients.remove(client)

        async def start(self):
            """Start listening and serve until this coroutine is cancelled."""
            print(f"🚀 JSON-RPC WebSocket Server started on ws://{self.config.host}:{self.config.port}")

            async with websockets.serve(
                self.handle_client,
                self.config.host,
                self.config.port
            ):
                await asyncio.Future()  # never completes: run until cancelled

except ImportError:
    class JSONRPCWebSocketServer:
        """Placeholder used when 'websockets' is not installed; instantiating it raises ImportError."""

        def __init__(self, *args, **kwargs):
            raise ImportError("WebSocket server requires 'websockets' package")
|
||
|
||
|
||
# Convenience factories

def create_server(config: Optional[ServerConfig] = None) -> JSONRPCServer:
    """Build an HTTP :class:`JSONRPCServer` from *config* without starting it."""
    server = JSONRPCServer(config)
    return server
|
||
|
||
|
||
def create_websocket_server(config: Optional[ServerConfig] = None) -> JSONRPCWebSocketServer:
    """Build a :class:`JSONRPCWebSocketServer` from *config* without starting it.

    Raises ImportError when the optional 'websockets' package is unavailable.
    """
    ws_server = JSONRPCWebSocketServer(config)
    return ws_server
|