604 lines
23 KiB
Python
604 lines
23 KiB
Python
import asyncio
|
||
import json
|
||
import os
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Optional, List, Dict, Any, Set
|
||
|
||
import aiohttp
|
||
import uvicorn
|
||
from fastapi import FastAPI, Request, HTTPException, Path
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
|
||
from workflow_service import comfyui_client
|
||
from workflow_service import database
|
||
from workflow_service.utils.s3_client import upload_file_to_s3
|
||
from workflow_service.config import Settings
|
||
|
||
# Application-wide configuration, loaded once when the module is imported.
settings = Settings()

web_app = FastAPI(title="ComfyUI Workflow Service & Management API")

# Cross-origin (CORS) support: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True likely
# lets any website issue credentialed browser requests against this API;
# restrict allow_origins to known domains in production.
web_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
||
|
||
|
||
# --- Pydantic Response Models for New Endpoints ---


# Queue occupancy counts for one ComfyUI server; nested inside ServerStatus.
# NOTE(review): values come from comfyui_client.get_server_status()["queue_details"],
# presumably the server's own queue — confirm against that helper.
class ServerQueueDetails(BaseModel):
    running_count: int  # number of running queue items
    pending_count: int  # number of pending queue items
|
||
|
||
|
||
# Response item of GET /api/servers/status: the static configuration of one
# ComfyUI server merged with the live status fetched when the request is served.
class ServerStatus(BaseModel):
    server_index: int  # position of this server in settings.SERVERS
    http_url: str
    ws_url: str
    input_dir: str  # input folder configured for the server (created at startup)
    output_dir: str  # output folder configured for the server (created at startup)
    is_reachable: bool  # live flag reported by comfyui_client.get_server_status
    is_free: bool  # live flag reported by comfyui_client.get_server_status
    queue_details: ServerQueueDetails
|
||
|
||
|
||
# Metadata for one regular file in a server's input or output directory.
class FileDetails(BaseModel):
    name: str  # bare file name (no directory component)
    size_kb: float  # file size divided by 1024, rounded to 2 decimals
    modified_at: datetime  # mtime as a naive local datetime (datetime.fromtimestamp)
|
||
|
||
|
||
# Response body of GET /api/servers/{server_index}/files.
class ServerFiles(BaseModel):
    server_index: int  # position of the server in settings.SERVERS
    http_url: str
    input_files: List[FileDetails]  # sorted newest-first by modified_at
    output_files: List[FileDetails]  # sorted newest-first by modified_at
|
||
|
||
|
||
@web_app.on_event("startup")
async def startup_event():
    """Prepare the service when it starts.

    Initialises the database first. Then, for every configured ComfyUI server,
    it ensures the input and output directories exist and logs their absolute
    paths. A ValueError raised while reading the server configuration or
    creating the folders is logged instead of aborting startup.
    """
    await database.init_db()
    try:
        configured = settings.SERVERS
        print(f"检测到 {len(configured)} 个 ComfyUI 服务器配置。")
        for srv in configured:
            print(f" - 正在为服务器 {srv.http_url} 准备目录...")
            # Input first, then output — same order as the log lines.
            for folder, label in ((srv.input_dir, "输入目录"), (srv.output_dir, "输出目录")):
                os.makedirs(folder, exist_ok=True)
                print(f" - {label}: {os.path.abspath(folder)}")
    except ValueError as err:
        print(f"错误: 无法在启动时初始化服务器目录: {err}")
|
||
|
||
|
||
@web_app.post("/api/workflow", status_code=201)
async def publish_workflow_endpoint(request: Request):
    """
    Publish (store) a workflow.

    The request body must be a JSON object containing ``name`` and
    ``workflow``. The workflow is serialised with ``json.dumps`` and saved via
    ``database.save_workflow``.

    Returns:
        201 JSONResponse with ``status`` and ``message``.

    Raises:
        HTTPException 400: invalid JSON, a non-object body, missing fields, or
            a ValueError raised while saving.
        HTTPException 500: any other failure while saving.
    """
    try:
        data = await request.json()
        # A JSON array or scalar has no .get(); reject it as a client error
        # instead of letting AttributeError surface as a 500.
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object."
            )
        name, wf_json = data.get("name"), data.get("workflow")
        if not name or not wf_json:
            raise HTTPException(
                status_code=400, detail="`name` and `workflow` fields are required."
            )
        await database.save_workflow(name, json.dumps(wf_json))
        print(f"Workflow '{name}' published.")
        return JSONResponse(
            content={"status": "success", "message": f"Workflow '{name}' published."},
            status_code=201,
        )
    except HTTPException:
        # Deliberate client errors above must not be rewrapped as 500s.
        raise
    except ValueError as e:
        # Includes json.JSONDecodeError from a malformed request body.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}") from e
|
||
|
||
|
||
@web_app.get("/api/workflow", response_model=List[dict])
async def get_all_workflows_endpoint():
    """
    List every published workflow exactly as ``database.get_all_workflows``
    returns it; any database failure is reported as HTTP 500.
    """
    try:
        workflows = await database.get_all_workflows()
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {err}")
    return workflows
|
||
|
||
|
||
@web_app.delete("/api/workflow/{workflow_name:path}")
async def delete_workflow_endpoint(
    workflow_name: str = Path(
        ...,
        description="The full, unique name of the workflow to delete, e.g., 'my_workflow [20250101120000]'",
    )
):
    """
    Delete a workflow by its full name.

    The ``:path`` converter lets the name contain slashes.

    Returns:
        ``{"status": "deleted", "name": workflow_name}`` on success.

    Raises:
        HTTPException 404: ``database.delete_workflow`` reported nothing deleted.
        HTTPException 500: the database call itself failed.
    """
    try:
        success = await database.delete_workflow(workflow_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}") from e

    # The 404 is raised outside the try so the catch-all above cannot turn it into a 500.
    if not success:
        raise HTTPException(
            status_code=404, detail=f"Workflow '{workflow_name}' not found"
        )
    return {"status": "deleted", "name": workflow_name}
|
||
|
||
|
||
@web_app.get("/api/workflow/{base_name:path}")
async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = None):
    """
    Return a stored workflow together with its parsed API specification.

    ``version`` is passed straight to ``database.get_workflow``; when it is
    omitted the 404 message refers to the latest version.

    Returns:
        ``{"workflow": <decoded workflow JSON>, "api_spec": <parse_api_spec result>}``

    Raises:
        HTTPException 404: no matching workflow.
        HTTPException 500: the stored JSON or its API spec could not be parsed.
    """
    record = await database.get_workflow(base_name, version)

    if not record:
        version_part = f" with version '{version}'" if version else " (latest)"
        raise HTTPException(
            status_code=404, detail=f"Workflow '{base_name}'{version_part} not found."
        )

    try:
        decoded = json.loads(record["workflow_json"])
        spec = comfyui_client.parse_api_spec(decoded)
    except Exception as err:
        raise HTTPException(
            status_code=500, detail=f"Failed to parse workflow specification: {err}"
        )
    return {"workflow": decoded, "api_spec": spec}
|
||
|
||
|
||
@web_app.post("/api/run")
async def run_workflow(
    request: Request, workflow_name: str, workflow_version: Optional[str] = None
):
    """
    Execute a stored workflow asynchronously.

    The workflow is submitted to the queue and a task ID is returned at once;
    callers poll ``GET /api/run/{workflow_run_id}`` for progress.

    Query parameters:
        workflow_name: base name of the published workflow (required).
        workflow_version: optional version, passed through to
            ``database.get_workflow`` (None is labelled "latest" in the 404).

    The JSON request body holds the workflow's input values. It is forwarded
    unchanged as ``request_data`` to ``comfyui_client.submit_workflow_to_queue``.

    Returns:
        202 JSONResponse with ``workflow_run_id``, ``status`` and ``message``.

    Raises:
        HTTPException 400: the body is not valid JSON, or ``workflow_name`` is empty.
        HTTPException 404: no matching workflow exists.
        HTTPException 500: any other failure while loading or queueing.
    """
    # Parse the body on its own so malformed JSON is a client error (400)
    # rather than falling into the generic 500 handler below.
    try:
        data = await request.json()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"请求体必须是有效的JSON: {e}") from e

    if not workflow_name:
        raise HTTPException(status_code=400, detail="`workflow_name` 字段是必需的")

    try:
        # Load the workflow definition.
        workflow_data = await database.get_workflow(workflow_name, workflow_version)
        if not workflow_data:
            detail = (
                f"工作流 '{workflow_name}'"
                + (f" 带版本 '{workflow_version}'" if workflow_version else " (最新版)")
                + " 未找到。"
            )
            raise HTTPException(status_code=404, detail=detail)

        workflow = json.loads(workflow_data["workflow_json"])
        api_spec = comfyui_client.parse_api_spec(workflow)

        # Submit to the queue.
        workflow_run_id = await comfyui_client.submit_workflow_to_queue(
            workflow_name=workflow_name,
            workflow_data=workflow,
            api_spec=api_spec,
            request_data=data,
        )

        return JSONResponse(
            content={
                "workflow_run_id": workflow_run_id,
                "status": "queued",
                "message": "工作流已提交到队列,正在等待执行",
            },
            status_code=202,
        )
    except HTTPException:
        # Keep deliberate client-facing errors (e.g. the 404 above) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"提交工作流失败: {str(e)}") from e
|
||
|
||
|
||
@web_app.get("/api/run/{workflow_run_id}")
async def get_run_status(workflow_run_id: str):
    """
    Report the execution status of a queued workflow run.

    Returns whatever ``comfyui_client.queue_manager.get_task_status`` yields
    for the ID; any failure in that lookup becomes HTTP 500.
    """
    try:
        task_status = await comfyui_client.queue_manager.get_task_status(workflow_run_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
    return task_status
|
||
|
||
|
||
@web_app.get("/api/metrics")
async def get_metrics():
    """
    Give an overview of the task queue.

    Returns:
        A dict with the number of pending and running tasks held by
        ``comfyui_client.queue_manager``, the number of configured ComfyUI
        servers, and a static ``queue_manager_status`` marker.

    Raises:
        HTTPException 500: the counts could not be read.
    """
    try:
        queue_manager = comfyui_client.queue_manager
        pending_count = len(queue_manager.pending_tasks)
        running_count = len(queue_manager.running_tasks)
        # Count configured servers (as the /api/servers endpoints do) — not
        # running tasks, which would just duplicate running_tasks.
        total_servers = len(settings.SERVERS)

        return {
            "pending_tasks": pending_count,
            "running_tasks": running_count,
            "total_servers": total_servers,
            "queue_manager_status": "active",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取队列状态失败: {str(e)}") from e
|
||
|
||
|
||
# 同步工作流执行API(带S3上传功能)
|
||
# @web_app.post("/api/run_sync/")
|
||
# async def execute_workflow_sync_endpoint(
|
||
# base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None
|
||
# ):
|
||
# """
|
||
# 同步执行工作流(支持S3文件上传)
|
||
# """
|
||
# cleanup_paths = []
|
||
# try:
|
||
# # 1. 获取工作流定义
|
||
# workflow_data = await database.get_workflow(base_name, version)
|
||
|
||
# if not workflow_data:
|
||
# detail = (
|
||
# f"工作流 '{base_name}'"
|
||
# + (f" 带版本 '{version}'" if version else " (最新版)")
|
||
# + " 未找到。"
|
||
# )
|
||
# raise HTTPException(status_code=404, detail=detail)
|
||
|
||
# workflow = json.loads(workflow_data["workflow_json"])
|
||
# api_spec = comfyui_client.parse_api_spec(workflow)
|
||
# request_data = {k.lower(): v for k, v in request_data_raw.items()}
|
||
|
||
# # 2. [核心改动] 使用智能调度选择服务器
|
||
# try:
|
||
# selected_server = await comfyui_client.select_server_for_execution()
|
||
# except Exception as e:
|
||
# raise HTTPException(
|
||
# status_code=503, detail=f"无法选择ComfyUI服务器执行任务: {e}"
|
||
# )
|
||
|
||
# server_input_dir = selected_server.input_dir
|
||
# server_output_dir = selected_server.output_dir
|
||
|
||
# # 3. 下载文件到选定服务器的输入目录
|
||
# async with aiohttp.ClientSession() as session:
|
||
# for param_name, param_def in api_spec["inputs"].items():
|
||
# if param_def["type"] == "UploadFile" and param_name in request_data:
|
||
# image_url = request_data[param_name]
|
||
# if not isinstance(image_url, str) or not image_url.startswith(
|
||
# "http"
|
||
# ):
|
||
# raise HTTPException(
|
||
# status_code=400,
|
||
# detail=f"参数 '{param_name}' 必须是一个有效的URL。",
|
||
# )
|
||
|
||
# filename = f"api_download_{uuid.uuid4().hex}{os.path.splitext(image_url.split('/')[-1])[1] or '.dat'}"
|
||
# save_path = os.path.join(server_input_dir, filename)
|
||
# try:
|
||
# await download_file_from_url(session, image_url, save_path)
|
||
# # 直接更新 request_data,以便后续的 patch_workflow 使用
|
||
# request_data[param_name] = filename
|
||
# cleanup_paths.append(save_path)
|
||
# except Exception as e:
|
||
# raise HTTPException(
|
||
# status_code=500,
|
||
# detail=f"为 '{param_name}' 从 {image_url} 下载文件失败。错误: {e}",
|
||
# )
|
||
|
||
# # 4. Patch工作流,生成最终的API Prompt
|
||
# patched_workflow = comfyui_client.patch_workflow(
|
||
# workflow, api_spec, request_data
|
||
# )
|
||
# prompt_to_run = comfyui_client.convert_workflow_to_prompt_api_format(
|
||
# patched_workflow
|
||
# )
|
||
|
||
# # 5. 执行前快照,并在选定服务器上执行工作流
|
||
# files_before = get_files_in_dir(server_output_dir)
|
||
# output_nodes = await comfyui_client.execute_prompt_on_server_legacy(
|
||
# prompt_to_run, selected_server
|
||
# )
|
||
|
||
# # 6. 处理输出(与之前逻辑相同)
|
||
# files_after = get_files_in_dir(server_output_dir)
|
||
# new_files = files_after - files_before
|
||
|
||
# output_response = {}
|
||
# if new_files:
|
||
# s3_urls = []
|
||
# for file_path in new_files:
|
||
# cleanup_paths.append(file_path)
|
||
# try:
|
||
# s3_urls.append(await handle_file_upload(file_path, base_name))
|
||
# except Exception as e:
|
||
# print(f"上传文件 {file_path} 到S3时出错: {e}")
|
||
# if s3_urls:
|
||
# output_response["output_files"] = s3_urls
|
||
|
||
# for final_param_name, param_def in api_spec["outputs"].items():
|
||
# node_id = param_def["node_id"]
|
||
# if node_id in output_nodes:
|
||
# node_output = output_nodes[node_id]
|
||
# original_output_name = param_def["output_name"]
|
||
|
||
# output_value = None
|
||
# if original_output_name in node_output:
|
||
# output_value = node_output[original_output_name]
|
||
# elif "text" in node_output: # 备选,例如对于'ShowText'节点
|
||
# output_value = node_output["text"]
|
||
|
||
# if output_value is None:
|
||
# continue
|
||
|
||
# if isinstance(output_value, list):
|
||
# output_value = output_value[0] if output_value else None
|
||
|
||
# output_response[final_param_name] = output_value
|
||
|
||
# return JSONResponse(content=output_response)
|
||
# # [核心改动] 捕获来自ComfyUI的执行失败异常
|
||
# except ComfyUIExecutionError as e:
|
||
# print(f"捕获到ComfyUI执行错误: {e.error_data}")
|
||
# # 返回 502 Bad Gateway 状态码,表示上游服务器出错
|
||
# # detail 中包含结构化的错误信息,方便客户端处理
|
||
# raise HTTPException(
|
||
# status_code=500,
|
||
# detail={
|
||
# "message": "工作流在上游ComfyUI节点中执行失败。",
|
||
# "error_details": e.error_data,
|
||
# },
|
||
# )
|
||
# except Exception as e:
|
||
# # 捕获其他所有异常,作为通用的服务器内部错误
|
||
# print(f"执行工作流时发生未知错误: {e}")
|
||
# raise HTTPException(status_code=500, detail=str(e))
|
||
# finally:
|
||
# if cleanup_paths:
|
||
# print(f"正在清理 {len(cleanup_paths)} 个临时文件...")
|
||
# for path in cleanup_paths:
|
||
# try:
|
||
# if os.path.exists(path):
|
||
# os.remove(path)
|
||
# print(f" - 已删除: {path}")
|
||
# except Exception as e:
|
||
# print(f" - 删除 {path} 时出错: {e}")
|
||
|
||
|
||
@web_app.get(
    "/api/servers/status", response_model=List[ServerStatus], tags=["Server Monitoring"]
)
async def get_servers_status_endpoint():
    """
    Describe every configured ComfyUI server.

    Each entry combines the server's configuration with its live status. All
    servers are probed concurrently over one shared aiohttp session; an empty
    configuration yields an empty list.
    """
    servers = settings.SERVERS
    if not servers:
        return []

    async with aiohttp.ClientSession() as session:
        probes = await asyncio.gather(
            *(comfyui_client.get_server_status(server, session) for server in servers)
        )

    # gather() preserves order, so each probe lines up with its config entry.
    return [
        ServerStatus(
            server_index=idx,
            http_url=cfg.http_url,
            ws_url=cfg.ws_url,
            input_dir=cfg.input_dir,
            output_dir=cfg.output_dir,
            is_reachable=probe["is_reachable"],
            is_free=probe["is_free"],
            queue_details=probe["queue_details"],
        )
        for idx, (cfg, probe) in enumerate(zip(servers, probes))
    ]
|
||
|
||
|
||
async def _get_folder_contents(path: str) -> List[FileDetails]:
    """List the regular files in *path* with their size and modification time.

    The directory scan runs in a worker thread (``asyncio.to_thread``) so the
    blocking filesystem calls do not stall the event loop.

    Args:
        path: directory to scan (not recursive).

    Returns:
        FileDetails for every regular file, newest first. A path that is not a
        directory yields ``[]``. Entries that cannot be inspected are skipped
        and logged.
    """
    if not os.path.isdir(path):
        return []

    def sync_list_files(dir_path: str) -> List[FileDetails]:
        files: List[FileDetails] = []
        try:
            # Context manager closes the scandir handle even if iteration fails.
            with os.scandir(dir_path) as entries:
                for entry in entries:
                    try:
                        if not entry.is_file():
                            continue
                        stat = entry.stat()
                    except OSError as e:
                        # e.g. the file was removed between listing and stat;
                        # skip just this entry instead of abandoning the scan.
                        print(f"无法读取文件信息 {entry.path}: {e}")
                        continue
                    files.append(
                        FileDetails(
                            name=entry.name,
                            size_kb=round(stat.st_size / 1024, 2),
                            modified_at=datetime.fromtimestamp(stat.st_mtime),
                        )
                    )
        except OSError as e:
            print(f"无法扫描目录 {dir_path}: {e}")
        return sorted(files, key=lambda x: x.modified_at, reverse=True)

    return await asyncio.to_thread(sync_list_files, path)
|
||
|
||
|
||
@web_app.get(
    "/api/servers/{server_index}/files",
    response_model=ServerFiles,
    tags=["Server Monitoring"],
)
async def list_server_files_endpoint(
    server_index: int = Path(..., ge=0, description="服务器在配置列表中的索引")
):
    """
    List the files in one configured server's input and output folders.

    ``server_index`` is validated as non-negative by FastAPI. An index past
    the end of ``settings.SERVERS`` yields HTTP 404. Both folders are listed
    concurrently.
    """
    servers = settings.SERVERS
    if server_index >= len(servers):
        raise HTTPException(
            status_code=404,
            detail=f"服务器索引 {server_index} 超出范围。有效索引为 0 到 {len(servers) - 1}。",
        )

    target = servers[server_index]
    listings = await asyncio.gather(
        _get_folder_contents(target.input_dir),
        _get_folder_contents(target.output_dir),
    )
    inputs, outputs = listings

    return ServerFiles(
        server_index=server_index,
        http_url=target.http_url,
        input_files=inputs,
        output_files=outputs,
    )
|
||
|
||
|
||
@web_app.get("/", response_class=JSONResponse)
def read_root():
    """
    Return a quick-start guide to the service's HTTP endpoints as JSON.

    The guide is static data. The endpoint paths it lists are meant to match
    the routes registered on ``web_app`` in this module; keep them in sync when
    routes change.
    """
    guide = {
        "service_name": "ComfyUI Workflow Service & Management API",
        "description": "一个用于发布、管理和执行 ComfyUI 工作流的API服务。它将 ComfyUI 的图形化工作流抽象为标准的 RESTful API 端点。",
        "quick_start_guide": [
            {
                "step": 1,
                "action": "发布一个工作流 (Publish a Workflow)",
                "description": "从 ComfyUI 保存你的工作流 (API格式),然后使用一个唯一的名称将其发布到服务中。名称中必须包含一个版本号,格式为 `[YYYYMMDDHHMMSS]`。",
                "endpoint": "/api/workflow",
                "example_curl": 'curl -X POST http://127.0.0.1:18000/api/workflow -H \'Content-Type: application/json\' -d \'{\n "name": "my_t2i [20250801120000]",\n "workflow": { ... your_workflow_api_json ... }\n}\'',
            },
            {
                "step": 2,
                "action": "异步执行工作流 (Execute Workflow Async with S3 Upload)",
                "description": "提交工作流到异步队列,立即返回任务ID,支持长时间运行的工作流。执行完成后自动上传输出文件到S3。",
                "endpoint": "/api/run",
                "example_curl": 'curl -X POST "http://127.0.0.1:18000/api/run?workflow_name=my_t2i&workflow_version=20250801120000" -H \'Content-Type: application/json\' -d \'{\n "prompt_prompt": "a beautiful cat sitting on a roof",\n "sampler_seed": 12345\n}\'',
            },
            {
                "step": 3,
                "action": "查询异步任务状态 (Check Async Task Status)",
                "description": "通过任务ID查询异步工作流的执行状态。",
                "endpoint": "/api/run/{workflow_run_id}",
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/run/your-task-id-here'",
            },
            {
                "step": 4,
                "action": "查看队列状态 (Check Queue Status)",
                "description": "查看当前异步任务队列的状态。",
                "endpoint": "/api/metrics",
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/metrics'",
            },
            {
                "step": 5,
                "action": "查看工作流的API规范 (Inspect Workflow API Spec)",
                "description": "一旦发布,你可以查询工作流的API规范,以了解需要提供哪些输入参数以及可以期望哪些输出。",
                # The spec is served by GET /api/workflow/{base_name:path}
                # (get_one_workflow_endpoint); no /api/spec route exists.
                "endpoint": "/api/workflow/{base_name}",
                "parameters": [
                    {
                        "name": "base_name",
                        "description": "工作流的基础名称 (不含版本部分)。",
                    },
                    {
                        "name": "version",
                        "description": "可选。工作流的具体版本号 (YYYYMMDDHHMMSS)。如果省略,则获取最新版本。",
                    },
                ],
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/workflow/my_t2i'",
            },
            {
                "step": 6,
                # The /api/run_sync/ route is commented out in this module, so
                # the guide flags it as unavailable instead of advertising it.
                "action": "执行工作流 (Execute a Workflow - Sync with S3 Upload) [当前已禁用]",
                "description": "注意: 该同步端点目前已在服务中禁用 (调用将返回 404),请改用 /api/run 异步执行。使用获取到的API规范,通过HTTP POST请求来同步执行工作流。对于文件输入,请提供可公开访问的URL。输出文件将被自动上传到S3并返回URL。",
                "endpoint": "/api/run_sync/",
                "parameters": [
                    {"name": "base_name", "description": "工作流的基础名称。"},
                    {
                        "name": "version",
                        "description": "可选。要执行的特定版本。如果省略,则执行最新版本。",
                    },
                ],
                "example_curl": "curl -X POST 'http://127.0.0.1:18000/api/run_sync/?base_name=my_t2i' -H 'Content-Type: application/json' -d '{\n \"prompt_prompt\": \"a beautiful cat sitting on a roof\",\n \"sampler_seed\": 12345\n}'",
            },
            {
                "step": 7,
                "action": "管理工作流 (Manage Workflows)",
                "description": "你也可以列出所有已发布的工作流或删除不再需要的工作流。",
                "endpoints": [
                    {
                        "method": "GET",
                        "path": "/api/workflow",
                        "description": "列出所有工作流。",
                    },
                    {
                        "method": "DELETE",
                        "path": "/api/workflow/{workflow_name}",
                        "description": "按完整名称删除一个工作流。",
                    },
                ],
                "example_curl_list": "curl -X GET http://127.0.0.1:18000/api/workflow",
                "example_curl_delete": "curl -X DELETE 'http://127.0.0.1:18000/api/workflow/my_t2i%20[20250801120000]'",
            },
        ],
        "notes": "请确保在 curl 命令中对包含空格或特殊字符的工作流名称进行URL编码。",
    }
    return JSONResponse(content=guide)
|
||
|
||
|
||
def get_files_in_dir(directory: str) -> Set[str]:
    """Collect the full paths of all non-hidden files under *directory*.

    The walk is recursive. Only file *names* starting with "." are skipped;
    files inside dot-directories are still included. A missing directory
    yields an empty set, because ``os.walk`` produces nothing for it.
    """
    return {
        os.path.join(current_root, entry_name)
        for current_root, _subdirs, entry_names in os.walk(directory)
        for entry_name in entry_names
        if not entry_name.startswith(".")
    }
|
||
|
||
|
||
async def download_file_from_url(
    session: aiohttp.ClientSession, url: str, save_path: str
):
    """Stream the resource at *url* into the local file *save_path*.

    Args:
        session: an open aiohttp client session used for the GET request.
        url: the resource to download.
        save_path: destination file; it is created or overwritten.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status,
        and any I/O or network error hit while streaming. If the body stream
        fails or is cancelled, the partially written file is removed before
        the exception propagates, so no truncated input is left behind.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        try:
            with open(save_path, "wb") as f:
                async for chunk in response.content.iter_chunked(8192):
                    f.write(chunk)
        except BaseException:
            # BaseException also covers asyncio.CancelledError.
            try:
                os.remove(save_path)
            except OSError:
                pass  # best-effort cleanup; the original error is what matters
            raise
|
||
|
||
|
||
async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Upload one local output file to the configured S3 bucket.

    The object key is ``outputs/<base_name>/<uuid4>_<file name>``; the random
    UUID prefix keeps repeated uploads of the same file name from colliding.
    Returns whatever ``upload_file_to_s3`` returns (presumably the object's
    URL — confirm against that helper).
    """
    s3_object_name = f"outputs/{base_name}/{uuid.uuid4()}_{os.path.basename(file_path)}"
    bucket_name = settings.S3_BUCKET_NAME
    uploaded = await upload_file_to_s3(file_path, bucket_name, s3_object_name)
    return uploaded
|
||
|
||
|
||
if __name__ == "__main__":
    # Development entry point. The app is given as an import string, which
    # uvicorn needs in order to re-import the module when reload is enabled.
    uvicorn.run(
        "workflow_service.main:web_app",
        reload=True,
        host="0.0.0.0",
        port=18000,
    )
|