import asyncio
import json
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any, Set
from urllib.parse import urlparse

import aiohttp
import uvicorn
from fastapi import FastAPI, Request, HTTPException, Path
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from workflow_service import comfyui_client
from workflow_service import database
from workflow_service import s3_client
from workflow_service import workflow_parser
from workflow_service.comfyui_client import ComfyUIExecutionError
from workflow_service.config import Settings

settings = Settings()
web_app = FastAPI(title="ComfyUI Workflow Service & Management API")


# --- Pydantic response models for the server-monitoring endpoints ---

class ServerQueueDetails(BaseModel):
    """Running / pending job counts reported for one ComfyUI server."""
    running_count: int
    pending_count: int


class ServerStatus(BaseModel):
    """Configuration plus live status of one configured ComfyUI server."""
    server_index: int  # position of the server in settings.SERVERS
    http_url: str
    ws_url: str
    input_dir: str
    output_dir: str
    is_reachable: bool
    is_free: bool
    queue_details: ServerQueueDetails


class FileDetails(BaseModel):
    """Name, size and modification time of a single file."""
    name: str
    size_kb: float  # st_size / 1024, rounded to 2 decimals
    modified_at: datetime  # local-time datetime built from st_mtime


class ServerFiles(BaseModel):
    """Files currently present in one server's input and output directories."""
    server_index: int
    http_url: str
    input_files: List[FileDetails]
    output_files: List[FileDetails]


@web_app.on_event("startup")
async def startup_event():
    """Initialise the database and create input/output directories for every configured server.

    A ValueError raised while reading ``settings.SERVERS`` (or creating the
    directories) is reported and does not abort startup.
    """
    await database.init_db()
    try:
        servers = settings.SERVERS
        print(f"检测到 {len(servers)} 个 ComfyUI 服务器配置。")
        for server in servers:
            print(f" - 正在为服务器 {server.http_url} 准备目录...")
            os.makedirs(server.input_dir, exist_ok=True)
            print(f" - 输入目录: {os.path.abspath(server.input_dir)}")
            os.makedirs(server.output_dir, exist_ok=True)
            print(f" - 输出目录: {os.path.abspath(server.output_dir)}")
    except ValueError as e:
        print(f"错误: 无法在启动时初始化服务器目录: {e}")


# --- Section 1: workflow management API ---
BASE_MANAGEMENT_PATH = "/api/workflow"


@web_app.post(BASE_MANAGEMENT_PATH, status_code=201)
async def publish_workflow_endpoint(request: Request):
    """Store a workflow under the given name.

    Expects a JSON object body with non-empty ``name`` and ``workflow`` fields.

    Raises:
        HTTPException 400: body is not a JSON object, a field is missing, or
            the body / persistence layer raises ValueError.
        HTTPException 500: any other persistence failure.
    """
    try:
        data = await request.json()
        # A JSON array/scalar body has no .get(); treat it as "fields missing"
        # rather than letting an AttributeError surface as a 500.
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="`name` and `workflow` fields are required.")
        name, wf_json = data.get("name"), data.get("workflow")
        if not name or not wf_json:
            raise HTTPException(status_code=400, detail="`name` and `workflow` fields are required.")
        await database.save_workflow(name, json.dumps(wf_json))
        print(f"Workflow '{name}' published.")
        return JSONResponse(
            content={"status": "success", "message": f"Workflow '{name}' published."},
            status_code=201)
    except HTTPException:
        # Deliberate 4xx responses must not be re-wrapped as 500 below.
        raise
    except ValueError as e:
        # Also covers malformed JSON bodies (JSONDecodeError subclasses ValueError).
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}") from e


@web_app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
async def get_all_workflows_endpoint():
    """Return every stored workflow as reported by the database layer."""
    try:
        return await database.get_all_workflows()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {e}") from e


# The ``:path`` converter lets the full workflow name (which may contain
# spaces and brackets) be passed as the trailing path segment.
@web_app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
async def delete_workflow_endpoint(workflow_name: str = Path(..., description="The full, unique name of the workflow to delete, e.g., 'my_workflow [20250101120000]'")):
    """Delete a workflow by its full name.

    Raises:
        HTTPException 404: no workflow with that name exists.
        HTTPException 500: the database layer raised.
    """
    # Only the database call is guarded, so the 404 below is not turned into a 500.
    try:
        success = await database.delete_workflow(workflow_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}") from e
    if not success:
        raise HTTPException(status_code=404, detail=f"Workflow '{workflow_name}' not found")
    return {"status": "deleted", "name": workflow_name}


# --- Section 2: workflow execution API ---

def get_files_in_dir(directory: str) -> Set[str]:
    """Return the full paths of all non-hidden files under *directory* (recursive).

    Only file names starting with '.' are skipped; hidden sub-directories are
    still walked.
    """
    file_set = set()
    for root, _, files in os.walk(directory):
        for filename in files:
            if not filename.startswith('.'):
                file_set.add(os.path.join(root, filename))
    return file_set


async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
    """Stream *url* to *save_path* in 8 KiB chunks; raises on a non-2xx response."""
    async with session.get(url) as response:
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            while True:
                chunk = await response.content.read(8192)
                if not chunk:
                    break
                f.write(chunk)


async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Upload *file_path* to S3 under ``outputs/<base_name>/<uuid>_<filename>`` and return what s3_client returns."""
    s3_object_name = f"outputs/{base_name}/{uuid.uuid4()}_{os.path.basename(file_path)}"
    return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, s3_object_name)


def _is_http_url(value: Any) -> bool:
    """True when *value* is a string with an http/https scheme and a host."""
    if not isinstance(value, str):
        return False
    parsed = urlparse(value)
    return parsed.scheme in ("http", "https") and bool(parsed.netloc)


def _download_filename_for_url(url: str) -> str:
    """Build a unique local filename for a downloaded input, keeping the URL's extension.

    Only the URL *path* is inspected, so query strings / fragments such as
    ``image.png?token=abc`` do not leak into the filename. Falls back to
    ``.dat`` when the last path segment has no plain alphanumeric extension.
    """
    last_segment = urlparse(url).path.rsplit('/', 1)[-1]
    ext = os.path.splitext(last_segment)[1]
    if not ext[1:].isalnum():
        ext = '.dat'
    return f"api_download_{uuid.uuid4().hex}{ext}"


async def _download_url_inputs(api_spec: Dict[str, Any], request_data: Dict[str, Any],
                               input_dir: str, cleanup_paths: List[str]) -> None:
    """Download every ``UploadFile`` input given as a URL into *input_dir*.

    Rewrites ``request_data[param]`` in place to the local filename so that
    ``patch_workflow`` sees the server-local name. Each destination path is
    registered in *cleanup_paths* BEFORE downloading, so a partially written
    file is still removed if the download fails.
    """
    async with aiohttp.ClientSession() as session:
        for param_name, spec in api_spec["inputs"].items():
            if spec["type"] != "UploadFile" or param_name not in request_data:
                continue
            image_url = request_data[param_name]
            if not _is_http_url(image_url):
                raise HTTPException(status_code=400, detail=f"参数 '{param_name}' 必须是一个有效的URL。")
            filename = _download_filename_for_url(image_url)
            save_path = os.path.join(input_dir, filename)
            cleanup_paths.append(save_path)
            try:
                await download_file_from_url(session, image_url, save_path)
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"为 '{param_name}' 从 {image_url} 下载文件失败。错误: {e}") from e
            request_data[param_name] = filename


async def _upload_new_output_files(new_files: Set[str], base_name: str,
                                   cleanup_paths: List[str]) -> List[str]:
    """Upload new output files to S3 (sorted for deterministic order) and return their URLs.

    Every file is scheduled for deletion regardless of upload success; upload
    failures are logged and skipped (best effort).
    """
    s3_urls = []
    for file_path in sorted(new_files):
        cleanup_paths.append(file_path)
        try:
            s3_urls.append(await handle_file_upload(file_path, base_name))
        except Exception as e:
            print(f"上传文件 {file_path} 到S3时出错: {e}")
    return s3_urls


def _collect_declared_outputs(api_spec: Dict[str, Any], output_nodes: Any) -> Dict[str, Any]:
    """Map each declared API output to the value produced by its ComfyUI node.

    Uses the node's declared output name, falling back to its ``"text"`` entry
    (e.g. for 'ShowText' nodes). List values are reduced to their first element
    (or None when empty); outputs whose node or value is missing are omitted.
    """
    collected: Dict[str, Any] = {}
    for final_param_name, spec in api_spec["outputs"].items():
        node_id = spec["node_id"]
        if node_id not in output_nodes:
            continue
        node_output = output_nodes[node_id]
        original_output_name = spec["output_name"]
        if original_output_name in node_output:
            output_value = node_output[original_output_name]
        elif "text" in node_output:
            output_value = node_output["text"]
        else:
            continue
        if output_value is None:
            continue
        if isinstance(output_value, list):
            output_value = output_value[0] if output_value else None
        collected[final_param_name] = output_value
    return collected


def _remove_temp_files(paths: List[str]) -> None:
    """Best-effort deletion of temporary input/output files; errors are logged, never raised."""
    if not paths:
        return
    print(f"正在清理 {len(paths)} 个临时文件...")
    for path in paths:
        try:
            if os.path.exists(path):
                os.remove(path)
                print(f" - 已删除: {path}")
        except Exception as e:
            print(f" - 删除 {path} 时出错: {e}")


@web_app.post("/api/run/")
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
    """Execute a stored workflow on a selected ComfyUI server and return its outputs.

    Request keys are lower-cased before matching against the workflow's API
    spec. URL inputs are downloaded to the server's input directory; new files
    in the server's output directory are uploaded to S3 and returned under
    ``output_files``; declared outputs are returned by their API names. All
    temporary files are deleted afterwards.

    Raises:
        HTTPException 404: workflow (or version) not found.
        HTTPException 400: a file input is not an http(s) URL.
        HTTPException 503: no ComfyUI server could be selected.
        HTTPException 502: ComfyUI reported an execution error.
        HTTPException 500: any other failure (including input download errors).
    """
    cleanup_paths: List[str] = []
    try:
        # 1. Resolve the workflow definition (a specific version or the latest).
        if version:
            workflow_data = await database.get_workflow_by_version(base_name, version)
        else:
            workflow_data = await database.get_latest_workflow_by_base_name(base_name)
        if not workflow_data:
            detail = f"工作流 '{base_name}'" + (f" 带版本 '{version}'" if version else " (最新版)") + " 未找到。"
            raise HTTPException(status_code=404, detail=detail)

        workflow = json.loads(workflow_data['workflow_json'])
        api_spec = workflow_parser.parse_api_spec(workflow)
        request_data = {k.lower(): v for k, v in request_data_raw.items()}

        # 2. Let the client pick a server for this run.
        try:
            selected_server = await comfyui_client.select_server_for_execution()
        except Exception as e:
            raise HTTPException(status_code=503, detail=f"无法选择ComfyUI服务器执行任务: {e}") from e

        # 3. Download URL file inputs into the selected server's input directory.
        await _download_url_inputs(api_spec, request_data, selected_server.input_dir, cleanup_paths)

        # 4. Patch the workflow with the request values and convert it to a prompt.
        patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
        prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)

        # 5. Snapshot the output directory around the run. The scans run in a
        # worker thread so os.walk does not block the event loop.
        # NOTE(review): new files are attributed by before/after diff; if two
        # requests ever run concurrently on the same server, each may claim
        # (and later delete) the other's outputs — confirm the scheduler
        # prevents that.
        output_dir = selected_server.output_dir
        files_before = await asyncio.to_thread(get_files_in_dir, output_dir)
        output_nodes = await comfyui_client.execute_prompt_on_server(prompt_to_run, selected_server)
        files_after = await asyncio.to_thread(get_files_in_dir, output_dir)

        # 6. Build the response: uploaded files first, then declared outputs.
        output_response: Dict[str, Any] = {}
        s3_urls = await _upload_new_output_files(files_after - files_before, base_name, cleanup_paths)
        if s3_urls:
            output_response["output_files"] = s3_urls
        output_response.update(_collect_declared_outputs(api_spec, output_nodes))
        return JSONResponse(content=output_response)

    except HTTPException:
        # Preserve deliberate 4xx/5xx responses raised above (previously they
        # were swallowed by the generic handler and turned into a 500).
        raise
    except ComfyUIExecutionError as e:
        print(f"捕获到ComfyUI执行错误: {e.error_data}")
        # 502 Bad Gateway: the upstream ComfyUI server failed. The detail is
        # structured so clients can inspect the node-level error.
        raise HTTPException(
            status_code=502,
            detail={
                "message": "工作流在上游ComfyUI节点中执行失败。",
                "error_details": e.error_data
            }
        ) from e
    except Exception as e:
        # Any other failure is reported as a generic internal server error.
        print(f"执行工作流时发生未知错误: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        _remove_temp_files(cleanup_paths)


# --- Section 3: workflow metadata / spec API ---

@web_app.get("/api/spec/")
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
    """Return the parsed API spec (inputs/outputs) of a workflow version, or of the latest one."""
    if version:
        workflow_data = await database.get_workflow_by_version(base_name, version)
    else:
        workflow_data = await database.get_latest_workflow_by_base_name(base_name)
    if not workflow_data:
        detail = f"Workflow '{base_name}'" + (f" with version '{version}'" if version else " (latest)") + " not found."
        raise HTTPException(status_code=404, detail=detail)
    try:
        workflow = json.loads(workflow_data['workflow_json'])
        return workflow_parser.parse_api_spec(workflow)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {e}") from e


# --- Section 4: server monitoring API ---

@web_app.get("/api/servers/status", response_model=List[ServerStatus], tags=["Server Monitoring"])
async def get_servers_status_endpoint():
    """Return configuration and live status for every configured ComfyUI server."""
    servers = settings.SERVERS
    if not servers:
        return []

    # Query all servers concurrently over one shared HTTP session.
    async with aiohttp.ClientSession() as session:
        live_statuses = await asyncio.gather(
            *(comfyui_client.get_server_status(server, session) for server in servers)
        )

    return [
        ServerStatus(
            server_index=i,
            http_url=server_config.http_url,
            ws_url=server_config.ws_url,
            input_dir=server_config.input_dir,
            output_dir=server_config.output_dir,
            is_reachable=status_data["is_reachable"],
            is_free=status_data["is_free"],
            queue_details=status_data["queue_details"]
        )
        for i, (server_config, status_data) in enumerate(zip(servers, live_statuses))
    ]


async def _get_folder_contents(path: str) -> List[FileDetails]:
    """List the regular files directly inside *path*, newest first.

    Returns an empty list if *path* is not a directory. The scan runs in a
    worker thread; scan errors are logged and yield the files seen so far.
    """
    if not os.path.isdir(path):
        return []

    def sync_list_files(dir_path):
        files = []
        try:
            for entry in os.scandir(dir_path):
                if entry.is_file():
                    stat = entry.stat()
                    files.append(
                        FileDetails(
                            name=entry.name,
                            size_kb=round(stat.st_size / 1024, 2),
                            modified_at=datetime.fromtimestamp(stat.st_mtime)
                        )
                    )
        except OSError as e:
            print(f"无法扫描目录 {dir_path}: {e}")
        return sorted(files, key=lambda x: x.modified_at, reverse=True)

    return await asyncio.to_thread(sync_list_files, path)


@web_app.get("/api/servers/{server_index}/files", response_model=ServerFiles, tags=["Server Monitoring"])
async def list_server_files_endpoint(server_index: int = Path(..., ge=0, description="服务器在配置列表中的索引")):
    """Return the files in the input and output directories of the server at *server_index*."""
    servers = settings.SERVERS
    if server_index >= len(servers):
        raise HTTPException(status_code=404, detail=f"服务器索引 {server_index} 超出范围。有效索引为 0 到 {len(servers) - 1}。")

    server_config = servers[server_index]
    input_files, output_files = await asyncio.gather(
        _get_folder_contents(server_config.input_dir),
        _get_folder_contents(server_config.output_dir)
    )
    return ServerFiles(
        server_index=server_index,
        http_url=server_config.http_url,
        input_files=input_files,
        output_files=output_files
    )


@web_app.get("/", response_class=JSONResponse)
def read_root():
    """Return a short usage guide for the API."""
    guide = {
        "service_name": "ComfyUI Workflow Service & Management API",
        "description": "一个用于发布、管理和执行 ComfyUI 工作流的API服务。它将 ComfyUI 的图形化工作流抽象为标准的 RESTful API 端点。",
        "quick_start_guide": [
            {
                "step": 1,
                "action": "发布一个工作流 (Publish a Workflow)",
                "description": "从 ComfyUI 保存你的工作流 (API格式),然后使用一个唯一的名称将其发布到服务中。名称中必须包含一个版本号,格式为 `[YYYYMMDDHHMMSS]`。",
                "endpoint": f"POST {BASE_MANAGEMENT_PATH}",
                "example_curl": "curl -X POST http://127.0.0.1:18000/api/workflow -H 'Content-Type: application/json' -d '{\n \"name\": \"my_t2i [20250801120000]\",\n \"workflow\": { ... your_workflow_api_json ... }\n}'"
            },
            {
                "step": 2,
                "action": "查看工作流的API规范 (Inspect Workflow API Spec)",
                "description": "一旦发布,你可以查询工作流的API规范,以了解需要提供哪些输入参数以及可以期望哪些输出。",
                "endpoint": "/api/spec/",
                "parameters": [
                    {"name": "base_name", "description": "工作流的基础名称 (不含版本部分)。"},
                    {"name": "version", "description": "可选。工作流的具体版本号 (YYYYMMDDHHMMSS)。如果省略,则获取最新版本。"}
                ],
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/spec/?base_name=my_t2i'"
            },
            {
                "step": 3,
                "action": "执行工作流 (Execute a Workflow)",
                "description": "使用获取到的API规范,通过HTTP POST请求来执行工作流。对于文件输入,请提供可公开访问的URL。输出文件将被上传到S3并返回URL。",
                "endpoint": "/api/run/",
                "parameters": [
                    {"name": "base_name", "description": "工作流的基础名称。"},
                    {"name": "version", "description": "可选。要执行的特定版本。如果省略,则执行最新版本。"}
                ],
                "example_curl": "curl -X POST 'http://127.0.0.1:18000/api/run/?base_name=my_t2i' -H 'Content-Type: application/json' -d '{\n \"prompt_prompt\": \"a beautiful cat sitting on a roof\",\n \"sampler_seed\": 12345\n}'"
            },
            {
                "step": 4,
                "action": "管理工作流 (Manage Workflows)",
                "description": "你也可以列出所有已发布的工作流或删除不再需要的工作流。",
                "endpoints": [
                    {"method": "GET", "path": f"{BASE_MANAGEMENT_PATH}", "description": "列出所有工作流。"},
                    {"method": "DELETE", "path": f"{BASE_MANAGEMENT_PATH}/{{workflow_name}}", "description": "按完整名称删除一个工作流。"}
                ],
                "example_curl_list": f"curl -X GET http://127.0.0.1:18000{BASE_MANAGEMENT_PATH}",
                "example_curl_delete": "curl -X DELETE 'http://127.0.0.1:18000/api/workflow/my_t2i%20[20250801120000]'"
            }
        ],
        "notes": "请确保在 curl 命令中对包含空格或特殊字符的工作流名称进行URL编码。"
    }
    return JSONResponse(content=guide)


if __name__ == "__main__":
    uvicorn.run("workflow_service.main:web_app", host="0.0.0.0", port=18000, reload=True)