422 lines
18 KiB
Python
422 lines
18 KiB
Python
import asyncio
|
||
import json
|
||
import os
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Optional, List, Dict, Any, Set
|
||
|
||
import aiohttp
|
||
import uvicorn
|
||
from fastapi import FastAPI, Request, HTTPException, Path
|
||
from fastapi.responses import JSONResponse
|
||
from pydantic import BaseModel
|
||
|
||
from workflow_service import comfyui_client
|
||
from workflow_service import database
|
||
from workflow_service import s3_client
|
||
from workflow_service import workflow_parser
|
||
from workflow_service.comfyui_client import ComfyUIExecutionError
|
||
from workflow_service.config import Settings
|
||
|
||
# Service configuration (ComfyUI server list, S3 bucket name, ...), loaded once at import time.
settings = Settings()

# The FastAPI application; the __main__ block serves it as "workflow_service.main:web_app".
web_app = FastAPI(title="ComfyUI Workflow Service & Management API")
|
||
|
||
|
||
# --- Pydantic response models for the server-monitoring endpoints ---


class ServerQueueDetails(BaseModel):
    """Queue counters reported for a single ComfyUI server."""

    # Number of jobs reported as currently running.
    running_count: int
    # Number of jobs reported as waiting in the queue.
    pending_count: int
||
|
||
|
||
class ServerStatus(BaseModel):
    """Configuration plus live status of one configured ComfyUI server."""

    # Position of the server in settings.SERVERS.
    server_index: int
    # HTTP and WebSocket endpoints taken from the server's configuration.
    http_url: str
    ws_url: str
    # Local directories configured as the server's input / output folders.
    input_dir: str
    output_dir: str
    # Live probe results, as reported by comfyui_client.get_server_status.
    is_reachable: bool
    is_free: bool
    # Running / pending queue counters.
    queue_details: ServerQueueDetails
|
||
|
||
|
||
class FileDetails(BaseModel):
    """Metadata for one file found in a server's input or output folder."""

    # Bare file name (no directory component).
    name: str
    # File size in KiB.
    size_kb: float
    # Last-modification time of the file.
    # NOTE(review): built from st_mtime via datetime.fromtimestamp, i.e. naive local time.
    modified_at: datetime
|
||
|
||
|
||
class ServerFiles(BaseModel):
    """Listing of a server's input and output folders."""

    # Position of the server in settings.SERVERS and its HTTP endpoint.
    server_index: int
    http_url: str
    # Files in the configured input folder.
    input_files: List[FileDetails]
    # Files in the configured output folder.
    output_files: List[FileDetails]
|
||
|
||
|
||
@web_app.on_event("startup")
async def startup_event():
    """Initialise the database and create the input/output folders of every configured server.

    A ValueError raised during the directory-setup phase (presumably from parsing
    settings.SERVERS -- verify) is printed instead of aborting startup; any other
    exception propagates.
    """
    await database.init_db()

    try:
        configured = settings.SERVERS
        print(f"检测到 {len(configured)} 个 ComfyUI 服务器配置。")
        for srv in configured:
            print(f" - 正在为服务器 {srv.http_url} 准备目录...")
            # Same handling for both folders: create if missing, then report the absolute path.
            for label, directory in (("输入目录", srv.input_dir), ("输出目录", srv.output_dir)):
                os.makedirs(directory, exist_ok=True)
                print(f" - {label}: {os.path.abspath(directory)}")
    except ValueError as err:
        print(f"错误: 无法在启动时初始化服务器目录: {err}")
|
||
|
||
|
||
# --- Section 1: workflow management API ---

# URL prefix shared by the publish / list / delete workflow endpoints.
BASE_MANAGEMENT_PATH = "/api/workflow"
|
||
|
||
|
||
@web_app.post(BASE_MANAGEMENT_PATH, status_code=201)
async def publish_workflow_endpoint(request: Request):
    """Publish (store) a workflow definition.

    Expects a JSON object body with:
      * ``name``     -- the workflow's full name, passed as-is to database.save_workflow.
      * ``workflow`` -- the workflow JSON, re-serialised with json.dumps before storage.

    Returns 201 on success.
    Responds 400 in these cases:
      * the body is not a JSON object;
      * either field is missing or empty;
      * request.json() or database.save_workflow raises ValueError (a malformed JSON
        body raises json.JSONDecodeError, which subclasses ValueError).
    Responds 500 for any other failure.
    """
    try:
        data = await request.json()
        # A JSON array / string / number has no .get(); report it as a client error
        # instead of letting AttributeError surface as a 500.
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400,
                detail="Request body must be a JSON object with `name` and `workflow` fields.",
            )
        name, wf_json = data.get("name"), data.get("workflow")
        if not name or not wf_json:
            raise HTTPException(status_code=400, detail="`name` and `workflow` fields are required.")
        await database.save_workflow(name, json.dumps(wf_json))
        print(f"Workflow '{name}' published.")
        return JSONResponse(
            content={"status": "success", "message": f"Workflow '{name}' published."},
            status_code=201,
        )
    except HTTPException:
        # Deliberate client errors raised above must keep their status code; without
        # this clause the generic handler below re-wrapped them as 500s.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}") from e
|
||
|
||
|
||
@web_app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
async def get_all_workflows_endpoint():
    """List every published workflow, exactly as returned by database.get_all_workflows.

    Any failure while reading is reported as HTTP 500.
    """
    try:
        workflows = await database.get_all_workflows()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {exc}")
    return workflows
|
||
|
||
|
||
# Route lives under BASE_MANAGEMENT_PATH like the other management endpoints; the
# ``:path`` converter lets the workflow name contain '/' characters.
@web_app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
async def delete_workflow_endpoint(
    workflow_name: str = Path(
        ...,
        description="The full, unique name of the workflow to delete, e.g., 'my_workflow [20250101120000]'",
    ),
):
    """Delete a workflow by its full name.

    Returns ``{"status": "deleted", "name": ...}`` on success.
    Responds 404 when database.delete_workflow reports nothing was deleted.
    Responds 500 when the database call itself fails.
    """
    # Only the database call belongs in the try: previously the 404 raised below was
    # caught by the broad handler and turned into a 500.
    try:
        success = await database.delete_workflow(workflow_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}") from e

    if not success:
        raise HTTPException(status_code=404, detail=f"Workflow '{workflow_name}' not found")
    return {"status": "deleted", "name": workflow_name}
|
||
|
||
|
||
# --- Section 2: workflow execution API ---


def get_files_in_dir(directory: str) -> Set[str]:
    """Return the paths of all non-hidden files under *directory*, recursively.

    Paths are built with ``os.path.join(root, name)`` from os.walk, so they are relative
    or absolute exactly as *directory* is. Only file names that start with '.' are
    skipped; files inside hidden sub-directories are still included. A directory that
    does not exist yields an empty set.
    """
    return {
        os.path.join(dirpath, fname)
        for dirpath, _subdirs, fnames in os.walk(directory)
        for fname in fnames
        if not fname.startswith('.')
    }
|
||
|
||
|
||
async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
    """Stream the resource at *url* into the local file *save_path*.

    Raises whatever ``response.raise_for_status()`` raises for a non-success status, and
    propagates network / filesystem errors. If the transfer fails after *save_path* was
    opened (including task cancellation), the partially written file is deleted so no
    truncated file is left behind.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        f = open(save_path, 'wb')
        try:
            with f:
                async for chunk in response.content.iter_chunked(8192):
                    f.write(chunk)
        except BaseException:
            # BaseException so cancellation also triggers cleanup. The file is closed by
            # the `with` above before we get here, so removal also works on Windows.
            try:
                os.remove(save_path)
            except OSError:
                # Best-effort: never let cleanup mask the original download error.
                pass
            raise
|
||
|
||
|
||
async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Upload a local output file to settings.S3_BUCKET_NAME and return the upload result.

    The return value is whatever s3_client.upload_file_to_s3 returns; the caller treats
    it as the file's S3 URL. The object key is ``outputs/<base_name>/<uuid4>_<file name>``.
    The random prefix keeps identically named outputs from different runs from
    overwriting each other.
    """
    unique_prefix = uuid.uuid4()
    object_key = "outputs/{}/{}_{}".format(base_name, unique_prefix, os.path.basename(file_path))
    return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, object_key)
|
||
|
||
|
||
def _input_download_filename(url: str) -> str:
    """Build a unique local file name for a downloaded input, keeping the URL's extension.

    Only the URL *path* is inspected, so query strings / fragments (e.g. signed-URL tokens)
    never end up in the extension. '.dat' is used when the path has no extension.
    """
    from urllib.parse import urlparse  # function-scope import keeps this change self-contained

    last_segment = urlparse(url).path.rsplit('/', 1)[-1]
    extension = os.path.splitext(last_segment)[1] or '.dat'
    return f"api_download_{uuid.uuid4().hex}{extension}"


async def _stage_input_files(api_spec: Dict[str, Any], request_data: Dict[str, Any],
                             input_dir: str, cleanup_paths: List[str]) -> None:
    """Download every ``UploadFile`` input supplied as a URL into *input_dir*.

    *request_data* is updated in place: each such parameter is replaced by the local file
    name, which is what workflow_parser.patch_workflow receives. Each target path is
    appended to *cleanup_paths* BEFORE the download starts, so the caller also deletes a
    partially written file.

    Raises HTTPException 400 for a value that is not an http(s) URL, and HTTPException
    500 when a download fails.
    """
    async with aiohttp.ClientSession() as session:
        for param_name, spec in api_spec["inputs"].items():
            if spec["type"] != "UploadFile" or param_name not in request_data:
                continue

            image_url = request_data[param_name]
            if not isinstance(image_url, str) or not image_url.startswith(('http://', 'https://')):
                raise HTTPException(status_code=400, detail=f"参数 '{param_name}' 必须是一个有效的URL。")

            # NOTE(review): the server fetches arbitrary caller-supplied URLs; consider an
            # allow-list if this API is exposed to untrusted clients (SSRF).
            filename = _input_download_filename(image_url)
            save_path = os.path.join(input_dir, filename)
            cleanup_paths.append(save_path)
            try:
                await download_file_from_url(session, image_url, save_path)
            except Exception as e:
                raise HTTPException(status_code=500,
                                    detail=f"为 '{param_name}' 从 {image_url} 下载文件失败。错误: {e}") from e
            request_data[param_name] = filename


def _extract_declared_outputs(api_spec: Dict[str, Any], output_nodes: Dict[str, Any]) -> Dict[str, Any]:
    """Map each output declared in *api_spec* to the value its node produced.

    Lookup order:
      1. ``spec["output_name"]`` in the node's output;
      2. otherwise a ``"text"`` key (e.g. ShowText-style nodes).
    A list value is reduced to its first element, or None if the list is empty.
    An output is omitted when its node is absent, neither key is present, or the value
    is None.
    """
    result: Dict[str, Any] = {}
    for final_param_name, spec in api_spec["outputs"].items():
        node_id = spec["node_id"]
        if node_id not in output_nodes:
            continue
        node_output = output_nodes[node_id]
        original_output_name = spec["output_name"]

        if original_output_name in node_output:
            value = node_output[original_output_name]
        elif "text" in node_output:
            value = node_output["text"]
        else:
            continue
        if value is None:
            continue

        if isinstance(value, list):
            value = value[0] if value else None
        result[final_param_name] = value
    return result


def _remove_files(paths: List[str]) -> None:
    """Best-effort deletion of temporary files. Errors are printed, never raised."""
    if not paths:
        return
    print(f"正在清理 {len(paths)} 个临时文件...")
    for path in paths:
        try:
            if os.path.exists(path):
                os.remove(path)
                print(f" - 已删除: {path}")
        except Exception as e:
            print(f" - 删除 {path} 时出错: {e}")


@web_app.post("/api/run/")
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
    """Execute a published workflow on a ComfyUI server chosen by comfyui_client.

    Query params: ``base_name`` and an optional ``version`` (latest when omitted).
    JSON body: the workflow's API inputs. Keys are lower-cased before matching, and
    ``UploadFile`` inputs must be http(s) URLs.

    Steps:
      1. Load the workflow and parse its API spec.
      2. Select a server.
      3. Download inputs into that server's input dir.
      4. Patch the workflow and convert it to a prompt.
      5. Execute it.
      6. Upload files that appeared in the server's output dir to S3 and return the
         declared outputs plus ``output_files`` (S3 URLs).
    Downloaded inputs and new output files are deleted afterwards in every case.

    Errors:
      * 404 -- workflow not found.
      * 400 -- invalid input URL.
      * 503 -- no server could be selected.
      * 500 -- download failure, ComfyUI execution error (structured ``detail``), or
        any other error.
    """
    cleanup_paths: List[str] = []
    try:
        # 1. Load the workflow definition (explicit version, or latest).
        if version:
            workflow_data = await database.get_workflow_by_version(base_name, version)
        else:
            workflow_data = await database.get_latest_workflow_by_base_name(base_name)

        if not workflow_data:
            detail = f"工作流 '{base_name}'" + (f" 带版本 '{version}'" if version else " (最新版)") + " 未找到。"
            raise HTTPException(status_code=404, detail=detail)

        workflow = json.loads(workflow_data['workflow_json'])
        api_spec = workflow_parser.parse_api_spec(workflow)
        request_data = {k.lower(): v for k, v in request_data_raw.items()}

        # 2. Let the client's scheduling logic pick a ComfyUI server.
        try:
            selected_server = await comfyui_client.select_server_for_execution()
        except Exception as e:
            raise HTTPException(status_code=503, detail=f"无法选择ComfyUI服务器执行任务: {e}") from e
        server_output_dir = selected_server.output_dir

        # 3. Download URL inputs into the selected server's input directory.
        await _stage_input_files(api_spec, request_data, selected_server.input_dir, cleanup_paths)

        # 4. Patch the workflow and convert it into the API prompt format.
        patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
        prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)

        # 5. Snapshot the output dir, then execute on the selected server. os.walk is
        #    blocking, so the snapshots run in a worker thread.
        files_before = await asyncio.to_thread(get_files_in_dir, server_output_dir)
        output_nodes = await comfyui_client.execute_prompt_on_server(prompt_to_run, selected_server)

        # 6. Collect outputs. Every file that appeared since the snapshot is attributed
        #    to this run.
        # NOTE(review): concurrent jobs on the same server would pick up (and delete)
        # each other's files under this snapshot-diff scheme.
        files_after = await asyncio.to_thread(get_files_in_dir, server_output_dir)
        new_files = files_after - files_before

        output_response: Dict[str, Any] = {}
        s3_urls = []
        for file_path in new_files:
            cleanup_paths.append(file_path)
            try:
                s3_urls.append(await handle_file_upload(file_path, base_name))
            except Exception as e:
                # Best-effort: one failed upload does not fail the whole request.
                print(f"上传文件 {file_path} 到S3时出错: {e}")
        if s3_urls:
            output_response["output_files"] = s3_urls

        # Declared outputs are merged last, so they take precedence on a key clash.
        output_response.update(_extract_declared_outputs(api_spec, output_nodes))
        return JSONResponse(content=output_response)

    except HTTPException:
        # Keep deliberate 4xx/5xx responses (404, 400, 503, download 500) intact; the
        # generic handler below used to re-wrap them all as 500.
        raise
    except ComfyUIExecutionError as e:
        print(f"捕获到ComfyUI执行错误: {e.error_data}")
        # Upstream ComfyUI execution failure. The response carries a structured detail
        # so clients can inspect the error.
        # NOTE(review): an earlier comment intended 502 Bad Gateway here; 500 is kept
        # because existing clients may depend on it.
        raise HTTPException(
            status_code=500,
            detail={
                "message": "工作流在上游ComfyUI节点中执行失败。",
                "error_details": e.error_data,
            },
        ) from e
    except Exception as e:
        # Anything unexpected is reported as a generic internal server error.
        print(f"执行工作流时发生未知错误: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        _remove_files(cleanup_paths)
||
|
||
|
||
# --- Section 3: workflow metadata / API-spec endpoint ---

@web_app.get("/api/spec/")
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
    """Return the parsed API spec of a stored workflow.

    Uses the explicit *version* when given, otherwise the latest stored version.
    Responds 404 when no matching workflow exists, and 500 when the stored JSON cannot
    be decoded or parsed.
    """
    lookup = (
        database.get_workflow_by_version(base_name, version)
        if version
        else database.get_latest_workflow_by_base_name(base_name)
    )
    workflow_data = await lookup

    if not workflow_data:
        which = f" with version '{version}'" if version else " (latest)"
        raise HTTPException(status_code=404, detail=f"Workflow '{base_name}'{which} not found.")

    try:
        return workflow_parser.parse_api_spec(json.loads(workflow_data['workflow_json']))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {exc}")
|
||
|
||
|
||
# --- Server monitoring API ---


@web_app.get("/api/servers/status", response_model=List[ServerStatus], tags=["Server Monitoring"])
async def get_servers_status_endpoint():
    """
    Return the configuration and live status of every configured ComfyUI server.
    """
    servers = settings.SERVERS
    if not servers:
        return []

    # Probe all servers concurrently over one shared HTTP session; gather preserves order.
    async with aiohttp.ClientSession() as session:
        live_statuses = await asyncio.gather(
            *(comfyui_client.get_server_status(server, session) for server in servers)
        )

    return [
        ServerStatus(
            server_index=index,
            http_url=cfg.http_url,
            ws_url=cfg.ws_url,
            input_dir=cfg.input_dir,
            output_dir=cfg.output_dir,
            is_reachable=status["is_reachable"],
            is_free=status["is_free"],
            queue_details=status["queue_details"],
        )
        for index, (cfg, status) in enumerate(zip(servers, live_statuses))
    ]
|
||
|
||
|
||
async def _get_folder_contents(path: str) -> List[FileDetails]:
    """List the regular files directly inside *path* (non-recursive) with size and mtime.

    Returns an empty list when *path* is not a directory. The blocking scan runs in a
    worker thread via asyncio.to_thread. Results are sorted newest-first by modification
    time.
    """
    if not os.path.isdir(path):
        return []

    def sync_list_files(dir_path):
        files = []
        try:
            # The context manager closes the scandir iterator even if the scan fails midway.
            with os.scandir(dir_path) as entries:
                for entry in entries:
                    try:
                        if not entry.is_file():
                            continue
                        stat = entry.stat()
                    except OSError as e:
                        # An entry can vanish or become unreadable between listing and
                        # stat(); skip it instead of truncating the whole listing.
                        print(f"无法读取文件信息 {entry.path}: {e}")
                        continue
                    files.append(
                        FileDetails(
                            name=entry.name,
                            size_kb=round(stat.st_size / 1024, 2),
                            # NOTE(review): naive local time (fromtimestamp without tz).
                            modified_at=datetime.fromtimestamp(stat.st_mtime),
                        )
                    )
        except OSError as e:
            print(f"无法扫描目录 {dir_path}: {e}")
        return sorted(files, key=lambda x: x.modified_at, reverse=True)

    return await asyncio.to_thread(sync_list_files, path)
|
||
|
||
|
||
@web_app.get("/api/servers/{server_index}/files", response_model=ServerFiles, tags=["Server Monitoring"])
async def list_server_files_endpoint(server_index: int = Path(..., ge=0, description="服务器在配置列表中的索引")):
    """
    Return the file listings of the input and output folders of one configured server.

    Negative indexes are rejected by ``ge=0``. An index past the end of settings.SERVERS
    yields 404.
    """
    servers = settings.SERVERS
    server_count = len(servers)
    if server_index >= server_count:
        raise HTTPException(
            status_code=404,
            detail=f"服务器索引 {server_index} 超出范围。有效索引为 0 到 {server_count - 1}。",
        )

    target = servers[server_index]
    # Both folder listings are produced concurrently.
    listings = await asyncio.gather(
        _get_folder_contents(target.input_dir),
        _get_folder_contents(target.output_dir),
    )
    return ServerFiles(
        server_index=server_index,
        http_url=target.http_url,
        input_files=listings[0],
        output_files=listings[1],
    )
|
||
|
||
|
||
@web_app.get("/", response_class=JSONResponse)
def read_root():
    """
    Return a JSON quick-start guide for this API.
    """
    guide = {
        "service_name": "ComfyUI Workflow Service & Management API",
        "description": "一个用于发布、管理和执行 ComfyUI 工作流的API服务。它将 ComfyUI 的图形化工作流抽象为标准的 RESTful API 端点。",
        "quick_start_guide": [
            {
                "step": 1,
                "action": "发布一个工作流 (Publish a Workflow)",
                "description": "从 ComfyUI 保存你的工作流 (API格式),然后使用一个唯一的名称将其发布到服务中。名称中必须包含一个版本号,格式为 `[YYYYMMDDHHMMSS]`。",
                "endpoint": f"POST {BASE_MANAGEMENT_PATH}",
                "example_curl": "curl -X POST http://127.0.0.1:18000/api/workflow -H 'Content-Type: application/json' -d '{\n \"name\": \"my_t2i [20250801120000]\",\n \"workflow\": { ... your_workflow_api_json ... }\n}'"
            },
            {
                "step": 2,
                "action": "查看工作流的API规范 (Inspect Workflow API Spec)",
                "description": "一旦发布,你可以查询工作流的API规范,以了解需要提供哪些输入参数以及可以期望哪些输出。",
                "endpoint": "/api/spec/",
                "parameters": [
                    {"name": "base_name", "description": "工作流的基础名称 (不含版本部分)。"},
                    {"name": "version",
                     "description": "可选。工作流的具体版本号 (YYYYMMDDHHMMSS)。如果省略,则获取最新版本。"}
                ],
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/spec/?base_name=my_t2i'"
            },
            {
                "step": 3,
                "action": "执行工作流 (Execute a Workflow)",
                "description": "使用获取到的API规范,通过HTTP POST请求来执行工作流。对于文件输入,请提供可公开访问的URL。输出文件将被上传到S3并返回URL。",
                "endpoint": "/api/run/",
                "parameters": [
                    {"name": "base_name", "description": "工作流的基础名称。"},
                    {"name": "version", "description": "可选。要执行的特定版本。如果省略,则执行最新版本。"}
                ],
                "example_curl": "curl -X POST 'http://127.0.0.1:18000/api/run/?base_name=my_t2i' -H 'Content-Type: application/json' -d '{\n \"prompt_prompt\": \"a beautiful cat sitting on a roof\",\n \"sampler_seed\": 12345\n}'"
            },
            {
                "step": 4,
                "action": "管理工作流 (Manage Workflows)",
                "description": "你也可以列出所有已发布的工作流或删除不再需要的工作流。",
                "endpoints": [
                    {"method": "GET", "path": f"{BASE_MANAGEMENT_PATH}", "description": "列出所有工作流。"},
                    {"method": "DELETE", "path": f"{BASE_MANAGEMENT_PATH}/{{workflow_name}}",
                     "description": "按完整名称删除一个工作流。"}
                ],
                "example_curl_list": f"curl -X GET http://127.0.0.1:18000{BASE_MANAGEMENT_PATH}",
                # '[' and ']' must be percent-encoded (%5B / %5D): curl treats unescaped
                # brackets in a URL as a glob range and rejects this URL otherwise.
                "example_curl_delete": f"curl -X DELETE 'http://127.0.0.1:18000{BASE_MANAGEMENT_PATH}/my_t2i%20%5B20250801120000%5D'"
            }
        ],
        "notes": "请确保在 curl 命令中对包含空格或特殊字符的工作流名称进行URL编码。"
    }
    return JSONResponse(content=guide)
|
||
|
||
|
||
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 18000, with auto-reload.
    # uvicorn needs the app as an import string (not the object) when reload=True.
    uvicorn.run(
        "workflow_service.main:web_app",
        host="0.0.0.0",
        port=18000,
        reload=True,
    )
|