# ComfyUI-WorkflowPublisher/workflow_service/main.py
#
# (Captured from a web code viewer: 604 lines, 23 KiB, Python.)
#
# Note: the viewer flagged this file as containing ambiguous Unicode characters.
# That is intentional: the file uses Chinese comments and messages, whose
# punctuation can presumably be mistaken for ASCII look-alikes.

import asyncio
import json
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any, Set
import aiohttp
import uvicorn
from fastapi import FastAPI, Request, HTTPException, Path
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from workflow_service import comfyui_client
from workflow_service import database
from workflow_service.utils.s3_client import upload_file_to_s3
from workflow_service.config import Settings
settings = Settings()

web_app = FastAPI(title="ComfyUI Workflow Service & Management API")

# Cross-origin (CORS) support: every origin, HTTP method and request header is
# accepted. In production, narrow ``allow_origins`` to the specific domains
# that need access.
web_app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# --- Pydantic Response Models for New Endpoints ---
class ServerQueueDetails(BaseModel):
    """Queue counters for one ComfyUI server.

    Built from the ``"queue_details"`` entry returned by
    ``comfyui_client.get_server_status`` (see ``get_servers_status_endpoint``).
    """

    running_count: int  # presumably jobs currently executing on that server — confirm in comfyui_client
    pending_count: int  # presumably jobs waiting in that server's queue — confirm in comfyui_client
class ServerStatus(BaseModel):
    """Configuration plus live status of one configured ComfyUI server.

    Returned (as a list) by ``GET /api/servers/status``. The configuration
    fields come from ``settings.SERVERS``; the live fields come from
    ``comfyui_client.get_server_status``.
    """

    server_index: int  # position of the server in settings.SERVERS
    http_url: str  # server's configured HTTP base URL
    ws_url: str  # server's configured WebSocket URL
    input_dir: str  # configured input directory (created at startup if missing)
    output_dir: str  # configured output directory (created at startup if missing)
    is_reachable: bool  # live flag reported by comfyui_client.get_server_status
    is_free: bool  # live flag from get_server_status; presumably "no running/pending work" — confirm
    queue_details: ServerQueueDetails  # live queue counters for this server
class FileDetails(BaseModel):
    """Metadata for a single file inside a server's input/output directory."""

    name: str  # bare file name (no directory component)
    size_kb: float  # size in KiB (st_size / 1024), rounded to 2 decimals by the lister
    modified_at: datetime  # mtime via datetime.fromtimestamp, i.e. naive local time
class ServerFiles(BaseModel):
    """Listing of one server's input and output directories.

    Returned by ``GET /api/servers/{server_index}/files``.
    """

    server_index: int  # position of the server in settings.SERVERS
    http_url: str  # server's configured HTTP base URL
    input_files: List[FileDetails]  # files in input_dir, newest first
    output_files: List[FileDetails]  # files in output_dir, newest first
@web_app.on_event("startup")
async def startup_event():
    """Initialise the database and create every server's input/output directories.

    A ``ValueError`` raised while reading the server configuration is reported
    and swallowed, so the service still starts. Other errors (for example an
    ``OSError`` from ``os.makedirs``) propagate and abort startup.
    """
    await database.init_db()
    try:
        configured = settings.SERVERS
        print(f"检测到 {len(configured)} 个 ComfyUI 服务器配置。")
        for cfg in configured:
            print(f" - 正在为服务器 {cfg.http_url} 准备目录...")
            in_dir = cfg.input_dir
            os.makedirs(in_dir, exist_ok=True)
            print(f" - 输入目录: {os.path.abspath(in_dir)}")
            out_dir = cfg.output_dir
            os.makedirs(out_dir, exist_ok=True)
            print(f" - 输出目录: {os.path.abspath(out_dir)}")
    except ValueError as err:
        print(f"错误: 无法在启动时初始化服务器目录: {err}")
@web_app.post("/api/workflow", status_code=201)
async def publish_workflow_endpoint(request: Request):
    """Publish (store) a workflow.

    Expects a JSON object body ``{"name": <str>, "workflow": <workflow JSON>}``.
    The workflow is serialised with ``json.dumps`` and saved through
    ``database.save_workflow``.

    Returns:
        201 JSON response with a success message.

    Raises:
        HTTPException: 400 if the body is not a JSON object, ``name`` or
            ``workflow`` is missing/empty, or a ``ValueError`` is raised
            (e.g. malformed JSON or a rejection from ``save_workflow``);
            500 for any other failure.
    """
    try:
        data = await request.json()
        if not isinstance(data, dict):
            # A list/scalar body would otherwise fail on .get() with an AttributeError (500).
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object."
            )
        name, wf_json = data.get("name"), data.get("workflow")
        if not name or not wf_json:
            raise HTTPException(
                status_code=400, detail="`name` and `workflow` fields are required."
            )
        await database.save_workflow(name, json.dumps(wf_json))
        print(f"Workflow '{name}' published.")
        return JSONResponse(
            content={"status": "success", "message": f"Workflow '{name}' published."},
            status_code=201,
        )
    except HTTPException:
        # Deliberate HTTP errors raised above must pass through unchanged;
        # otherwise the generic handler below would turn them into 500s.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to save workflow: {e}"
        ) from e
@web_app.get("/api/workflow", response_model=List[dict])
async def get_all_workflows_endpoint():
    """List every published workflow.

    Returns whatever ``database.get_all_workflows`` yields; any failure there is
    reported as an HTTP 500.
    """
    try:
        workflows = await database.get_all_workflows()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {exc}")
    return workflows
@web_app.delete("/api/workflow/{workflow_name:path}")
async def delete_workflow_endpoint(
    workflow_name: str = Path(
        ...,
        description="The full, unique name of the workflow to delete, e.g., 'my_workflow [20250101120000]'",
    )
):
    """Delete a workflow by its full (versioned) name.

    Args:
        workflow_name: Full workflow name including the version suffix. It is
            captured with the ``:path`` converter, so it may contain slashes.

    Returns:
        ``{"status": "deleted", "name": workflow_name}`` on success.

    Raises:
        HTTPException: 404 when ``database.delete_workflow`` reports nothing
            was deleted; 500 when the database call itself fails.
    """
    try:
        success = await database.delete_workflow(workflow_name)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to delete workflow: {e}"
        ) from e
    # Raise the 404 outside the try block: previously it was caught by the
    # generic handler and reported to clients as a 500.
    if not success:
        raise HTTPException(
            status_code=404, detail=f"Workflow '{workflow_name}' not found"
        )
    return {"status": "deleted", "name": workflow_name}
@web_app.get("/api/workflow/{base_name:path}")
async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = None):
    """Return a stored workflow together with its derived API specification.

    Args:
        base_name: Workflow base name (path parameter, may contain slashes).
        version: Optional version (query parameter); the latest is used when omitted.

    Raises:
        HTTPException: 404 if no matching workflow exists; 500 if the stored
            JSON cannot be loaded or ``parse_api_spec`` fails.
    """
    workflow_data = await database.get_workflow(base_name, version)
    if not workflow_data:
        suffix = f" with version '{version}'" if version else " (latest)"
        detail = f"Workflow '{base_name}'" + suffix + " not found."
        raise HTTPException(status_code=404, detail=detail)

    try:
        parsed = json.loads(workflow_data["workflow_json"])
        spec = comfyui_client.parse_api_spec(parsed)
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Failed to parse workflow specification: {exc}"
        )
    return {"workflow": parsed, "api_spec": spec}
@web_app.post("/api/run")
async def run_workflow(
    request: Request, workflow_name: str, workflow_version: Optional[str] = None
):
    """Submit a published workflow for asynchronous execution.

    Returns immediately (HTTP 202) with a ``workflow_run_id``; callers poll
    ``GET /api/run/{workflow_run_id}`` for the execution status.

    Args:
        request: Incoming request whose JSON body is forwarded to the queue as
            the workflow's input parameters.
        workflow_name: Workflow base name (query parameter).
        workflow_version: Optional version (query parameter); the latest is
            used when omitted.

    Raises:
        HTTPException: 400 for an empty ``workflow_name`` or a body that is not
            valid JSON; 404 if the workflow does not exist; 500 for any other
            failure while loading or queueing the workflow.
    """
    if not workflow_name:
        raise HTTPException(status_code=400, detail="`workflow_name` 字段是必需的")

    try:
        data = await request.json()
    except ValueError as e:
        # json.JSONDecodeError / UnicodeDecodeError: a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"请求体不是有效的 JSON: {e}") from e

    try:
        # Load the workflow definition.
        workflow_data = await database.get_workflow(workflow_name, workflow_version)
        if not workflow_data:
            detail = (
                f"工作流 '{workflow_name}'"
                + (f" 带版本 '{workflow_version}'" if workflow_version else " (最新版)")
                + " 未找到。"
            )
            raise HTTPException(status_code=404, detail=detail)
        workflow = json.loads(workflow_data["workflow_json"])
        api_spec = comfyui_client.parse_api_spec(workflow)
        # Hand the job to the queue manager.
        workflow_run_id = await comfyui_client.submit_workflow_to_queue(
            workflow_name=workflow_name,
            workflow_data=workflow,
            api_spec=api_spec,
            request_data=data,
        )
    except HTTPException:
        # Preserve deliberate status codes (e.g. the 404 above) instead of
        # letting the generic handler below rewrite them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"提交工作流失败: {str(e)}") from e

    return JSONResponse(
        content={
            "workflow_run_id": workflow_run_id,
            "status": "queued",
            "message": "工作流已提交到队列,正在等待执行",
        },
        status_code=202,
    )
@web_app.get("/api/run/{workflow_run_id}")
async def get_run_status(workflow_run_id: str):
    """Return the execution status of an asynchronous workflow run.

    The value is passed through unchanged from
    ``comfyui_client.queue_manager.get_task_status``; any exception it raises
    becomes an HTTP 500.
    """
    manager = comfyui_client.queue_manager
    try:
        return await manager.get_task_status(workflow_run_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(exc)}")
@web_app.get("/api/metrics")
async def get_metrics():
    """Return an overview of the asynchronous task queue.

    Returns:
        A dict with the number of pending and running tasks held by
        ``comfyui_client.queue_manager``, the number of configured ComfyUI
        servers, and a hard-coded ``"queue_manager_status": "active"``.

    Raises:
        HTTPException: 500 if any of the counts cannot be read.
    """
    try:
        queue_manager = comfyui_client.queue_manager
        return {
            "pending_tasks": len(queue_manager.pending_tasks),
            "running_tasks": len(queue_manager.running_tasks),
            # Number of configured ComfyUI servers: the same list that
            # /api/servers/status reports on (not the running-task count).
            "total_servers": len(settings.SERVERS),
            "queue_manager_status": "active",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取队列状态失败: {str(e)}")
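# Synchronous workflow execution API with S3 upload (同步工作流执行API, 带S3上传功能).
# NOTE: this endpoint is disabled: everything below is commented out, so
# /api/run_sync/ is not registered. The code references ComfyUIExecutionError,
# which this module does not import, so re-enabling it also needs that import.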
# @web_app.post("/api/run_sync/")
# async def execute_workflow_sync_endpoint(
# base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None
# ):
# """
# 同步执行工作流支持S3文件上传
# """
# cleanup_paths = []
# try:
# # 1. 获取工作流定义
# workflow_data = await database.get_workflow(base_name, version)
# if not workflow_data:
# detail = (
# f"工作流 '{base_name}'"
# + (f" 带版本 '{version}'" if version else " (最新版)")
# + " 未找到。"
# )
# raise HTTPException(status_code=404, detail=detail)
# workflow = json.loads(workflow_data["workflow_json"])
# api_spec = comfyui_client.parse_api_spec(workflow)
# request_data = {k.lower(): v for k, v in request_data_raw.items()}
# # 2. [核心改动] 使用智能调度选择服务器
# try:
# selected_server = await comfyui_client.select_server_for_execution()
# except Exception as e:
# raise HTTPException(
# status_code=503, detail=f"无法选择ComfyUI服务器执行任务: {e}"
# )
# server_input_dir = selected_server.input_dir
# server_output_dir = selected_server.output_dir
# # 3. 下载文件到选定服务器的输入目录
# async with aiohttp.ClientSession() as session:
# for param_name, param_def in api_spec["inputs"].items():
# if param_def["type"] == "UploadFile" and param_name in request_data:
# image_url = request_data[param_name]
# if not isinstance(image_url, str) or not image_url.startswith(
# "http"
# ):
# raise HTTPException(
# status_code=400,
# detail=f"参数 '{param_name}' 必须是一个有效的URL。",
# )
# filename = f"api_download_{uuid.uuid4().hex}{os.path.splitext(image_url.split('/')[-1])[1] or '.dat'}"
# save_path = os.path.join(server_input_dir, filename)
# try:
# await download_file_from_url(session, image_url, save_path)
# # 直接更新 request_data以便后续的 patch_workflow 使用
# request_data[param_name] = filename
# cleanup_paths.append(save_path)
# except Exception as e:
# raise HTTPException(
# status_code=500,
# detail=f"为 '{param_name}' 从 {image_url} 下载文件失败。错误: {e}",
# )
# # 4. Patch工作流生成最终的API Prompt
# patched_workflow = comfyui_client.patch_workflow(
# workflow, api_spec, request_data
# )
# prompt_to_run = comfyui_client.convert_workflow_to_prompt_api_format(
# patched_workflow
# )
# # 5. 执行前快照,并在选定服务器上执行工作流
# files_before = get_files_in_dir(server_output_dir)
# output_nodes = await comfyui_client.execute_prompt_on_server_legacy(
# prompt_to_run, selected_server
# )
# # 6. 处理输出(与之前逻辑相同)
# files_after = get_files_in_dir(server_output_dir)
# new_files = files_after - files_before
# output_response = {}
# if new_files:
# s3_urls = []
# for file_path in new_files:
# cleanup_paths.append(file_path)
# try:
# s3_urls.append(await handle_file_upload(file_path, base_name))
# except Exception as e:
# print(f"上传文件 {file_path} 到S3时出错: {e}")
# if s3_urls:
# output_response["output_files"] = s3_urls
# for final_param_name, param_def in api_spec["outputs"].items():
# node_id = param_def["node_id"]
# if node_id in output_nodes:
# node_output = output_nodes[node_id]
# original_output_name = param_def["output_name"]
# output_value = None
# if original_output_name in node_output:
# output_value = node_output[original_output_name]
# elif "text" in node_output: # 备选,例如对于'ShowText'节点
# output_value = node_output["text"]
# if output_value is None:
# continue
# if isinstance(output_value, list):
# output_value = output_value[0] if output_value else None
# output_response[final_param_name] = output_value
# return JSONResponse(content=output_response)
# # [核心改动] 捕获来自ComfyUI的执行失败异常
# except ComfyUIExecutionError as e:
# print(f"捕获到ComfyUI执行错误: {e.error_data}")
# # NOTE: a 502 Bad Gateway (upstream server failure) was intended here, but the code below actually returns 500.
# # detail 中包含结构化的错误信息,方便客户端处理
# raise HTTPException(
# status_code=500,
# detail={
# "message": "工作流在上游ComfyUI节点中执行失败。",
# "error_details": e.error_data,
# },
# )
# except Exception as e:
# # 捕获其他所有异常,作为通用的服务器内部错误
# print(f"执行工作流时发生未知错误: {e}")
# raise HTTPException(status_code=500, detail=str(e))
# finally:
# if cleanup_paths:
# print(f"正在清理 {len(cleanup_paths)} 个临时文件...")
# for path in cleanup_paths:
# try:
# if os.path.exists(path):
# os.remove(path)
# print(f" - 已删除: {path}")
# except Exception as e:
# print(f" - 删除 {path} 时出错: {e}")
@web_app.get(
    "/api/servers/status", response_model=List[ServerStatus], tags=["Server Monitoring"]
)
async def get_servers_status_endpoint():
    """Report configuration plus live status for every configured ComfyUI server.

    All servers are probed concurrently over one shared aiohttp session. Each
    entry's ``server_index`` is the server's position in ``settings.SERVERS``.
    An empty configuration yields an empty list.
    """
    configured = settings.SERVERS
    if not configured:
        return []

    async with aiohttp.ClientSession() as session:
        probes = (comfyui_client.get_server_status(cfg, session) for cfg in configured)
        live_statuses = await asyncio.gather(*probes)

    return [
        ServerStatus(
            server_index=idx,
            http_url=cfg.http_url,
            ws_url=cfg.ws_url,
            input_dir=cfg.input_dir,
            output_dir=cfg.output_dir,
            is_reachable=live["is_reachable"],
            is_free=live["is_free"],
            queue_details=live["queue_details"],
        )
        for idx, (cfg, live) in enumerate(zip(configured, live_statuses))
    ]
async def _get_folder_contents(path: str) -> List[FileDetails]:
    """List the regular files directly inside *path*, newest first.

    The blocking directory scan runs in a worker thread via
    ``asyncio.to_thread``. Sub-directories are not descended into.

    Args:
        path: Directory to list.

    Returns:
        One ``FileDetails`` per regular file, sorted by modification time in
        descending order. The list is empty if *path* is not a directory or
        cannot be scanned.
    """
    if not os.path.isdir(path):
        return []

    def sync_list_files(dir_path: str) -> List[FileDetails]:
        files: List[FileDetails] = []
        try:
            # The context manager closes the directory handle promptly, even
            # when iteration stops early because of an error.
            with os.scandir(dir_path) as entries:
                for entry in entries:
                    try:
                        if not entry.is_file():
                            continue
                        stat = entry.stat()
                    except OSError as e:
                        # A file that vanished or became unreadable between the
                        # listing and the stat() call should not abort the scan.
                        print(f"无法读取文件信息 {entry.path}: {e}")
                        continue
                    files.append(
                        FileDetails(
                            name=entry.name,
                            size_kb=round(stat.st_size / 1024, 2),
                            modified_at=datetime.fromtimestamp(stat.st_mtime),
                        )
                    )
        except OSError as e:
            print(f"无法扫描目录 {dir_path}: {e}")
        return sorted(files, key=lambda x: x.modified_at, reverse=True)

    return await asyncio.to_thread(sync_list_files, path)
@web_app.get(
    "/api/servers/{server_index}/files",
    response_model=ServerFiles,
    tags=["Server Monitoring"],
)
async def list_server_files_endpoint(
    server_index: int = Path(..., ge=0, description="服务器在配置列表中的索引")
):
    """List the files in one ComfyUI server's input and output directories.

    Args:
        server_index: Position of the server in ``settings.SERVERS`` (>= 0).

    Raises:
        HTTPException: 404 when no servers are configured or the index is out
            of range.
    """
    servers = settings.SERVERS
    if server_index >= len(servers):
        if not servers:
            # Avoid the nonsensical "valid indices are 0 to -1" message.
            raise HTTPException(status_code=404, detail="未配置任何 ComfyUI 服务器。")
        raise HTTPException(
            status_code=404,
            detail=f"服务器索引 {server_index} 超出范围。有效索引为 0 到 {len(servers) - 1}",
        )
    server_config = servers[server_index]
    # Scan both directories concurrently; each scan runs in a worker thread.
    input_files, output_files = await asyncio.gather(
        _get_folder_contents(server_config.input_dir),
        _get_folder_contents(server_config.output_dir),
    )
    return ServerFiles(
        server_index=server_index,
        http_url=server_config.http_url,
        input_files=input_files,
        output_files=output_files,
    )
@web_app.get("/", response_class=JSONResponse)
def read_root():
    """Return a quick-start guide to this service's API as a static JSON document.

    Each step names an endpoint and gives an example ``curl`` call against
    ``127.0.0.1:18000`` (the host/port used by the ``__main__`` entry point).
    Endpoints listed here must match the routes registered in this module.
    """
    guide = {
        "service_name": "ComfyUI Workflow Service & Management API",
        "description": "一个用于发布、管理和执行 ComfyUI 工作流的API服务。它将 ComfyUI 的图形化工作流抽象为标准的 RESTful API 端点。",
        "quick_start_guide": [
            {
                "step": 1,
                "action": "发布一个工作流 (Publish a Workflow)",
                "description": "从 ComfyUI 保存你的工作流 (API格式),然后使用一个唯一的名称将其发布到服务中。名称中必须包含一个版本号,格式为 `[YYYYMMDDHHMMSS]`。",
                "endpoint": "/api/workflow",
                "example_curl": 'curl -X POST http://127.0.0.1:18000/api/workflow -H \'Content-Type: application/json\' -d \'{\n "name": "my_t2i [20250801120000]",\n "workflow": { ... your_workflow_api_json ... }\n}\'',
            },
            {
                "step": 2,
                "action": "异步执行工作流 (Execute Workflow Async with S3 Upload)",
                "description": "提交工作流到异步队列立即返回任务ID支持长时间运行的工作流。执行完成后自动上传输出文件到S3。",
                "endpoint": "/api/run",
                "example_curl": 'curl -X POST "http://127.0.0.1:18000/api/run?workflow_name=my_t2i&workflow_version=20250801120000" -H \'Content-Type: application/json\' -d \'{\n "prompt_prompt": "a beautiful cat sitting on a roof",\n "sampler_seed": 12345\n}\'',
            },
            {
                "step": 3,
                "action": "查询异步任务状态 (Check Async Task Status)",
                "description": "通过任务ID查询异步工作流的执行状态。",
                "endpoint": "/api/run/{workflow_run_id}",
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/run/your-task-id-here'",
            },
            {
                "step": 4,
                "action": "查看队列状态 (Check Queue Status)",
                "description": "查看当前异步任务队列的状态。",
                "endpoint": "/api/metrics",
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/metrics'",
            },
            {
                # The spec is served by GET /api/workflow/{base_name:path}
                # (get_one_workflow_endpoint); there is no /api/spec/ route.
                "step": 5,
                "action": "查看工作流的API规范 (Inspect Workflow API Spec)",
                "description": "一旦发布你可以查询工作流的API规范以了解需要提供哪些输入参数以及可以期望哪些输出。",
                "endpoint": "/api/workflow/{base_name}",
                "parameters": [
                    {
                        "name": "base_name",
                        "description": "工作流的基础名称 (不含版本部分)。作为路径参数。",
                    },
                    {
                        "name": "version",
                        "description": "可选。工作流的具体版本号 (YYYYMMDDHHMMSS)。作为查询参数。如果省略,则获取最新版本。",
                    },
                ],
                "example_curl": "curl -X GET 'http://127.0.0.1:18000/api/workflow/my_t2i'",
            },
            {
                # The /api/run_sync/ handler is commented out in this module, so
                # the route is not registered; flag that instead of advertising it.
                "step": 6,
                "action": "执行工作流 (Execute a Workflow - Sync with S3 Upload)",
                "status": "disabled",
                "description": "注意: 该同步接口当前未启用 (服务端未注册 /api/run_sync/ 路由),请改用步骤 2 的异步接口 /api/run。启用后: 使用获取到的API规范通过HTTP POST请求来同步执行工作流。对于文件输入请提供可公开访问的URL。输出文件将被自动上传到S3并返回URL。",
                "endpoint": "/api/run_sync/",
                "parameters": [
                    {"name": "base_name", "description": "工作流的基础名称。"},
                    {
                        "name": "version",
                        "description": "可选。要执行的特定版本。如果省略,则执行最新版本。",
                    },
                ],
                "example_curl": "curl -X POST 'http://127.0.0.1:18000/api/run_sync/?base_name=my_t2i' -H 'Content-Type: application/json' -d '{\n \"prompt_prompt\": \"a beautiful cat sitting on a roof\",\n \"sampler_seed\": 12345\n}'",
            },
            {
                "step": 7,
                "action": "管理工作流 (Manage Workflows)",
                "description": "你也可以列出所有已发布的工作流或删除不再需要的工作流。",
                "endpoints": [
                    {
                        "method": "GET",
                        "path": "/api/workflow",
                        "description": "列出所有工作流。",
                    },
                    {
                        "method": "DELETE",
                        "path": "/api/workflow/{workflow_name}",
                        "description": "按完整名称删除一个工作流。",
                    },
                ],
                "example_curl_list": "curl -X GET http://127.0.0.1:18000/api/workflow",
                # Brackets are percent-encoded (%5B / %5D): curl otherwise treats
                # "[...]" in a URL as a glob range and rejects the command.
                "example_curl_delete": "curl -X DELETE 'http://127.0.0.1:18000/api/workflow/my_t2i%20%5B20250801120000%5D'",
            },
        ],
        "notes": "请确保在 curl 命令中对包含空格或特殊字符的工作流名称进行URL编码。",
    }
    return JSONResponse(content=guide)
def get_files_in_dir(directory: str) -> Set[str]:
    """Return the paths of all non-hidden files under *directory*, recursively.

    A file is hidden when its name starts with ".". Hidden *directories* are
    still walked; only the file-name check applies. A missing directory yields
    an empty set, because ``os.walk`` yields nothing for it.
    """
    return {
        os.path.join(dirpath, fname)
        for dirpath, _subdirs, fnames in os.walk(directory)
        for fname in fnames
        if not fname.startswith(".")
    }
async def download_file_from_url(
    session: aiohttp.ClientSession, url: str, save_path: str
) -> None:
    """Stream the resource at *url* into *save_path*.

    Args:
        session: Open aiohttp session used for the GET request.
        url: Resource to download.
        save_path: Destination file; created or overwritten.

    Raises:
        aiohttp.ClientResponseError: For a non-2xx response. No file is created
            in that case.
        Any error raised while streaming or writing (including cancellation).
            The partially written file is removed before the error propagates.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        try:
            # NOTE(review): file writes are synchronous and briefly block the event loop.
            with open(save_path, "wb") as f:
                async for chunk in response.content.iter_chunked(8192):
                    f.write(chunk)
        except BaseException:
            # BaseException also covers asyncio.CancelledError. Don't leave a
            # truncated file behind: callers only track the path for cleanup
            # after a successful download.
            try:
                os.remove(save_path)
            except OSError:
                pass  # best-effort cleanup; the original error is what matters
            raise
async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Upload a local output file to the configured S3 bucket.

    The object key is ``outputs/<base_name>/<uuid4>_<file name>``; the random
    UUID keeps keys unique across repeated runs. Returns whatever
    ``upload_file_to_s3`` returns (used by callers as the file's URL).
    """
    unique_tag = uuid.uuid4()
    object_key = f"outputs/{base_name}/{unique_tag}_{os.path.basename(file_path)}"
    bucket = settings.S3_BUCKET_NAME
    return await upload_file_to_s3(file_path, bucket, object_key)
if __name__ == "__main__":
    # Development entry point. The app is passed as an import string because
    # uvicorn requires that form when reload=True.
    uvicorn.run(
        "workflow_service.main:web_app",
        reload=True,
        host="0.0.0.0",
        port=18000,
    )