add 增加waas服务器demo
This commit is contained in:
parent
591989e826
commit
c69d3cc326
85
README.md
85
README.md
|
|
@ -5,4 +5,87 @@
|
|||
- 工作流上传
|
||||
- 工作流版本控制
|
||||
- 工作流加载
|
||||
- 需配置工作流服务器(详见workflow_server_demo.py文件)
|
||||
- 需配置工作流服务器 (详见**WAAS(工作流即服务) Demo API服务器**)
|
||||
|
||||
|
||||
# WAAS(工作流即服务) Demo API服务器
|
||||
- 路径: ./workflow_service
|
||||
- 必须按照指定规则命名
|
||||
- **输入节点名**: 前缀 **INPUT_**
|
||||
- **除生成文件节点外输出节点名**: 前缀 **OUTPUT_**
|
||||
- 支持输入节点:
|
||||
- comfyui-core
|
||||
- 加载图像
|
||||
- ComfyUI-VideoHelperSuite
|
||||
- Load Video(Upload)
|
||||
- comfyui-easy-use
|
||||
- 整数
|
||||
- 字符串
|
||||
- 浮点数
|
||||
- 支持输出节点:
|
||||
- 所有在output文件夹中生成文件(图片/视频)的节点
|
||||
- comfyui-easy-use
|
||||
- 展示任何
|
||||
- 数据库
|
||||
- 类型: Sqlite (workflows_service.sqlite)
|
||||
- 数据库结构
|
||||
```
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
name TEXT PRIMARY KEY,
|
||||
base_name TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
workflow_json TEXT NOT NULL
|
||||
)
|
||||
```
|
||||
- 路由
|
||||
- GET /api/workflow: 列出工作流
|
||||
- POST /api/workflow: 添加工作流
|
||||
- DELETE /api/workflow: 删除工作流
|
||||
- GET /api/run/{base_name}: 获取工作流输入输出元数据
|
||||
```
|
||||
输入:
|
||||
*base_name: 工作流名称
|
||||
version: 工作流版本
|
||||
|
||||
输出:
|
||||
Json
|
||||
{
|
||||
"inputs": {
|
||||
"image_image": {
|
||||
"node_id": "13",
|
||||
"type": "UploadFile",
|
||||
"widget_name": "image"
|
||||
},
|
||||
"prefix_value": {
|
||||
"node_id": "22",
|
||||
"type": "int",
|
||||
"widget_name": "value"
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"text_output": {
|
||||
"node_id": "21",
|
||||
"class_type": "easy showAnything",
|
||||
"output_name": "output",
|
||||
"output_index": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
- POST /api/run/{base_name}: 执行工作流
|
||||
```
|
||||
输入:
|
||||
*base_name: 工作流名称
|
||||
version: 工作流版本
|
||||
|
||||
输出:
|
||||
Json
|
||||
{
|
||||
"output_files": [
|
||||
"https://cdn.roasmax.cn/outputs/测试/4e91e429-c848-4f66-885c-98a83c745872_111_00001_.png"
|
||||
],
|
||||
"text_output": [
|
||||
"output\\111_00001_.png"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
import sqlite3
|
||||
import json
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
import urllib.parse
|
||||
|
||||
# --- 配置 ---
|
||||
DATABASE_FILE = "workflows.sqlite"
|
||||
PORT = 8000
|
||||
|
||||
|
||||
# --- 数据库初始化 ---
|
||||
def init_db():
|
||||
"""初始化数据库,如果表不存在则创建它"""
|
||||
with sqlite3.connect(DATABASE_FILE) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
name TEXT PRIMARY KEY,
|
||||
workflow_json TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
print(f"数据库 '{DATABASE_FILE}' 已准备就绪。")
|
||||
|
||||
|
||||
# --- API 处理器 ---
|
||||
class PersistentAPIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def _send_cors_headers(self):
|
||||
"""发送CORS头部"""
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'GET, POST, DELETE, OPTIONS')
|
||||
self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type")
|
||||
|
||||
def do_OPTIONS(self):
|
||||
"""处理预检请求"""
|
||||
self.send_response(200, "ok")
|
||||
self._send_cors_headers()
|
||||
self.end_headers()
|
||||
|
||||
def do_GET(self):
|
||||
"""处理获取所有工作流的请求"""
|
||||
if self.path == '/api/workflow':
|
||||
try:
|
||||
with sqlite3.connect(DATABASE_FILE) as conn:
|
||||
conn.row_factory = sqlite3.Row # 让查询结果可以像字典一样访问
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT name, workflow_json FROM workflows")
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# 将数据库行转换为前端期望的JSON格式
|
||||
workflows_list = [
|
||||
{"name": row["name"], "workflow": json.loads(row["workflow_json"])}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self._send_cors_headers()
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(workflows_list).encode())
|
||||
|
||||
except Exception as e:
|
||||
self._send_error(f"获取工作流时出错: {e}")
|
||||
else:
|
||||
self._send_error("路径未找到", 404)
|
||||
|
||||
def do_POST(self):
|
||||
"""处理发布/更新工作流的请求"""
|
||||
if self.path == '/api/workflow':
|
||||
try:
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = self.rfile.read(content_length)
|
||||
data = json.loads(post_data)
|
||||
|
||||
name = data.get('name')
|
||||
workflow_data = data.get('workflow')
|
||||
|
||||
if not name or not workflow_data:
|
||||
return self._send_error("请求体中缺少 'name' 或 'workflow'", 400)
|
||||
|
||||
workflow_json_str = json.dumps(workflow_data)
|
||||
|
||||
with sqlite3.connect(DATABASE_FILE) as conn:
|
||||
cursor = conn.cursor()
|
||||
# 使用 INSERT OR REPLACE 来处理创建和更新,非常方便
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO workflows (name, workflow_json) VALUES (?, ?)",
|
||||
(name, workflow_json_str)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
print(f"--- 已保存/更新工作流: {name} ---")
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self._send_cors_headers()
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "success", "name": name}).encode())
|
||||
|
||||
except Exception as e:
|
||||
self._send_error(f"保存工作流时出错: {e}")
|
||||
else:
|
||||
self._send_error("路径未找到", 404)
|
||||
|
||||
def do_DELETE(self):
|
||||
"""处理删除指定工作流的请求"""
|
||||
parts = self.path.split('/api/workflow/')
|
||||
if len(parts) == 2 and parts[1]:
|
||||
workflow_name_to_delete = urllib.parse.unquote(parts[1])
|
||||
try:
|
||||
with sqlite3.connect(DATABASE_FILE) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM workflows WHERE name = ?", (workflow_name_to_delete,))
|
||||
conn.commit()
|
||||
|
||||
# 检查是否真的有行被删除了
|
||||
if cursor.rowcount > 0:
|
||||
print(f"--- 已删除工作流: {workflow_name_to_delete} ---")
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self._send_cors_headers()
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "deleted", "name": workflow_name_to_delete}).encode())
|
||||
else:
|
||||
self._send_error("工作流未找到,无法删除", 404)
|
||||
|
||||
except Exception as e:
|
||||
self._send_error(f"删除工作流时出错: {e}")
|
||||
else:
|
||||
self._send_error("无效的删除路径", 400)
|
||||
|
||||
def _send_error(self, message, code=500):
|
||||
self.send_response(code)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self._send_cors_headers()
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"status": "error", "message": message}).encode())
|
||||
|
||||
|
||||
# --- 主执行函数 ---
|
||||
def run(server_class=HTTPServer, handler_class=PersistentAPIHandler, port=PORT):
|
||||
# 1. 初始化数据库
|
||||
init_db()
|
||||
|
||||
# 2. 启动服务器
|
||||
server_address = ('', port)
|
||||
httpd = server_class(server_address, handler_class)
|
||||
print(f"\n持久化API服务器已在 http://localhost:{port} 上启动...")
|
||||
print(f"数据库文件: ./{DATABASE_FILE}")
|
||||
print(" - POST /api/workflow -> 发布/更新工作流")
|
||||
print(" - GET /api/workflow -> 获取所有工作流")
|
||||
print(" - DELETE /api/workflow/<name> -> 删除指定工作流")
|
||||
print("\n按 Ctrl+C 停止服务器。")
|
||||
httpd.serve_forever()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
# ComfyUI服务器地址
|
||||
COMFYUI_URL="ws://127.0.0.1:8188/ws"
|
||||
COMFYUI_HTTP_URL="http://127.0.0.1:8188"
|
||||
# 绝对路径!例如: /home/user/ComfyUI/input
|
||||
COMFYUI_INPUT_DIR="F:\ComfyUI\input"
|
||||
# 绝对路径!例如: /home/user/ComfyUI/output
|
||||
COMFYUI_OUTPUT_DIR="F:\ComfyUI\output"
|
||||
|
||||
# AWS S3 配置
|
||||
S3_BUCKET_NAME="modal-media-cache"
|
||||
AWS_ACCESS_KEY_ID="AKIAYRH5NGRSWHN2L4M6"
|
||||
AWS_SECRET_ACCESS_KEY="kfAqoOmIiyiywi25xaAkJUQbZ/EKDnzvI6NRCW1l"
|
||||
AWS_REGION_NAME="ap-northeast-2"
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
import websockets
|
||||
import json
|
||||
import uuid
|
||||
import aiohttp
|
||||
import random
|
||||
from config import settings
|
||||
|
||||
|
||||
async def queue_prompt(prompt: dict, client_id: str) -> str:
|
||||
"""通过HTTP POST将工作流任务提交到ComfyUI队列。"""
|
||||
# [缓存破解] 在每个节点中添加一个随机数输入,强制重新执行
|
||||
for node_id in prompt:
|
||||
prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
|
||||
|
||||
payload = {"prompt": prompt, "client_id": client_id}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
http_url = f"{settings.COMFYUI_HTTP_URL}/prompt"
|
||||
async with session.post(http_url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
if "prompt_id" not in result:
|
||||
raise Exception(f"Invalid response from ComfyUI /prompt endpoint: {result}")
|
||||
return result["prompt_id"]
|
||||
|
||||
|
||||
async def get_execution_results(prompt_id: str, client_id: str) -> dict:
|
||||
"""
|
||||
通过WebSocket连接,聚合所有'executed'事件的输出,
|
||||
直到整个执行流程结束。
|
||||
"""
|
||||
ws_url = f"{settings.COMFYUI_URL}?clientId={client_id}"
|
||||
aggregated_outputs = {}
|
||||
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
while True:
|
||||
try:
|
||||
out = await websocket.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
msg_type = message.get('type')
|
||||
data = message.get('data')
|
||||
|
||||
# 我们只关心与我们prompt_id相关的事件
|
||||
if data and data.get('prompt_id') == prompt_id:
|
||||
if msg_type == 'executed':
|
||||
# 聚合每个节点的输出
|
||||
node_id = data.get('node')
|
||||
output_data = data.get('output')
|
||||
if node_id and output_data:
|
||||
aggregated_outputs[node_id] = output_data
|
||||
print(f"Output received for node {node_id} (Prompt ID: {prompt_id})")
|
||||
|
||||
# 判断执行是否结束
|
||||
# 官方UI使用 "executing" 且 node is null 作为结束标志
|
||||
elif msg_type == 'executing' and data.get('node') is None:
|
||||
print(f"Execution finished for Prompt ID: {prompt_id}")
|
||||
return aggregated_outputs
|
||||
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"WebSocket connection closed for {prompt_id}. Returning aggregated results. Error: {e}")
|
||||
return aggregated_outputs # 连接关闭也视为结束
|
||||
except Exception as e:
|
||||
print(f"An error occurred for {prompt_id}: {e}")
|
||||
break
|
||||
return aggregated_outputs
|
||||
|
||||
|
||||
async def run_workflow(prompt: dict) -> dict:
|
||||
"""主协调函数:提交任务,然后等待结果。"""
|
||||
client_id = str(uuid.uuid4())
|
||||
prompt_id = await queue_prompt(prompt, client_id)
|
||||
print(f"Workflow successfully queued with Prompt ID: {prompt_id}")
|
||||
results = await get_execution_results(prompt_id, client_id)
|
||||
return results
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8')
|
||||
COMFYUI_URL: str = "ws://127.0.0.1:8188/ws"
|
||||
COMFYUI_HTTP_URL: str = "http://127.0.0.1:8188"
|
||||
COMFYUI_INPUT_DIR: str
|
||||
COMFYUI_OUTPUT_DIR: str
|
||||
S3_BUCKET_NAME: str
|
||||
AWS_ACCESS_KEY_ID: str
|
||||
AWS_SECRET_ACCESS_KEY: str
|
||||
AWS_REGION_NAME: str
|
||||
|
||||
settings = Settings()
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
import aiosqlite
|
||||
import json
|
||||
import re
|
||||
|
||||
DATABASE_FILE = "workflows_service.sqlite"
|
||||
|
||||
async def init_db():
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
name TEXT PRIMARY KEY,
|
||||
base_name TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
workflow_json TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
await db.commit()
|
||||
print(f"数据库 '{DATABASE_FILE}' 已准备就绪。")
|
||||
|
||||
async def save_workflow(name: str, workflow_json: str):
|
||||
version_match = re.search(r"\[(20\d{12})\]$", name)
|
||||
if not version_match:
|
||||
raise ValueError("Workflow name must have a version suffix like [YYYYMMDDHHMMSS]")
|
||||
version = version_match.group(1)
|
||||
base_name = re.sub(r"\s*\[(20\d{12})\]$", "", name).strip()
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
await db.execute(
|
||||
"INSERT OR REPLACE INTO workflows (name, base_name, version, workflow_json) VALUES (?, ?, ?, ?)",
|
||||
(name, base_name, version, workflow_json)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def get_all_workflows() -> list[dict]:
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT name, workflow_json FROM workflows")
|
||||
rows = await cursor.fetchall()
|
||||
return [{"name": row["name"], "workflow": json.loads(row["workflow_json"])} for row in rows]
|
||||
|
||||
async def get_latest_workflow_by_base_name(base_name: str) -> dict | None:
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM workflows WHERE base_name = ? ORDER BY version DESC LIMIT 1", (base_name,))
|
||||
row = await cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_workflow_by_version(base_name: str, version: str) -> dict | None:
|
||||
name = f"{base_name} [{version}]"
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM workflows WHERE name = ?", (name,))
|
||||
row = await cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
async def delete_workflow(name: str) -> bool:
|
||||
async with aiosqlite.connect(DATABASE_FILE) as db:
|
||||
cursor = await db.execute("DELETE FROM workflows WHERE name = ?", (name,))
|
||||
await db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
import uvicorn
|
||||
from fastapi import FastAPI, Request, HTTPException, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Optional, List, Dict, Any, Set
|
||||
import json, uuid, os, shutil, aiohttp, database, workflow_parser, comfyui_client, s3_client
|
||||
from config import settings
|
||||
|
||||
app = FastAPI(title="ComfyUI Workflow Service & Management API")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event(): await database.init_db(); os.makedirs(settings.COMFYUI_INPUT_DIR,
|
||||
exist_ok=True); os.makedirs(
|
||||
settings.COMFYUI_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
|
||||
# --- Section 1: 工作流管理API (无改动) ---
|
||||
# ... (代码与上一版完全相同) ...
|
||||
BASE_MANAGEMENT_PATH = "/api/workflow"
|
||||
|
||||
|
||||
@app.post(BASE_MANAGEMENT_PATH, status_code=200)
|
||||
async def publish_workflow_endpoint(request: Request):
|
||||
try:
|
||||
data = await request.json(); name, wf_json = data.get("name"), data.get(
|
||||
"workflow"); await database.save_workflow(name, json.dumps(wf_json)); return JSONResponse(
|
||||
content={"status": "success", "message": f"Workflow '{name}' published."}, status_code=200)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}")
|
||||
|
||||
|
||||
@app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
|
||||
async def get_all_workflows_endpoint():
|
||||
try:
|
||||
return await database.get_all_workflows()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get workflows: {e}")
|
||||
|
||||
|
||||
@app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
|
||||
async def delete_workflow_endpoint(workflow_name: str = Path(..., title="...")):
|
||||
try:
|
||||
success = await database.delete_workflow(workflow_name);
|
||||
if success:
|
||||
return {"status": "deleted", "name": workflow_name};
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}")
|
||||
|
||||
|
||||
# --- Section 2: 工作流执行API (核心改动) ---
|
||||
def get_files_in_dir(directory: str) -> Set[str]:
|
||||
file_set = set()
|
||||
for root, _, files in os.walk(directory):
|
||||
for filename in files:
|
||||
if not filename.startswith('.'): file_set.add(os.path.join(root, filename))
|
||||
return file_set
|
||||
|
||||
|
||||
async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
|
||||
async with session.get(url) as response:
|
||||
response.raise_for_status();
|
||||
with open(save_path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await response.content.read(8192);
|
||||
if not chunk:
|
||||
break;
|
||||
f.write(chunk)
|
||||
|
||||
|
||||
async def handle_file_upload(file_path: str, base_name: str) -> str:
|
||||
"""辅助函数:上传文件到S3并返回URL"""
|
||||
s3_object_name = f"outputs/{base_name}/{uuid.uuid4()}_{os.path.basename(file_path)}"
|
||||
return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, s3_object_name)
|
||||
|
||||
|
||||
@app.post("/api/run/{base_name}")
|
||||
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
|
||||
cleanup_paths = []
|
||||
try:
|
||||
# 1. 获取工作流和处理输入 (与上一版相同)
|
||||
# ...
|
||||
if version:
|
||||
workflow_data = await database.get_workflow_by_version(base_name, version)
|
||||
else:
|
||||
workflow_data = await database.get_latest_workflow_by_base_name(base_name)
|
||||
if not workflow_data: raise HTTPException(status_code=404, detail=f"Workflow '{base_name}' not found.")
|
||||
workflow = json.loads(workflow_data['workflow_json'])
|
||||
api_spec = workflow_parser.parse_api_spec(workflow)
|
||||
request_data = {k.lower(): v for k, v in request_data_raw.items()}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for param_name, spec in api_spec["inputs"].items():
|
||||
if spec["type"] == "UploadFile" and param_name in request_data:
|
||||
image_url = request_data[param_name]
|
||||
if not isinstance(image_url, str) or not image_url.startswith('http'): raise HTTPException(
|
||||
status_code=400, detail=f"Parameter '{param_name}' must be a valid URL.")
|
||||
original_filename = image_url.split('/')[-1].split('?')[0];
|
||||
_, file_extension = os.path.splitext(original_filename)
|
||||
if not file_extension: file_extension = '.dat'
|
||||
filename = f"api_download_{uuid.uuid4()}{file_extension}"
|
||||
save_path = os.path.join(settings.COMFYUI_INPUT_DIR, filename)
|
||||
try:
|
||||
await download_file_from_url(session, image_url, save_path);
|
||||
request_data[param_name] = filename; cleanup_paths.append(save_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500,
|
||||
detail=f"Failed to download file for '{param_name}' from {image_url}. Error: {e}")
|
||||
|
||||
# 2. 执行前快照
|
||||
files_before = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)
|
||||
|
||||
# 3. Patch, 转换并执行 (缓存破解已移入client)
|
||||
patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
|
||||
prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)
|
||||
output_nodes = await comfyui_client.run_workflow(prompt_to_run)
|
||||
|
||||
# 4. 执行后快照并计算差异
|
||||
files_after = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)
|
||||
new_files = files_after - files_before
|
||||
|
||||
# 5. [核心修正] 统一处理所有输出
|
||||
output_response = {}
|
||||
processed_files = set() # 记录已通过文件系统快照处理的文件
|
||||
|
||||
# 5.1 首先处理所有新生成的文件
|
||||
if new_files:
|
||||
s3_urls = []
|
||||
for file_path in new_files:
|
||||
cleanup_paths.append(file_path)
|
||||
processed_files.add(os.path.basename(file_path))
|
||||
try:
|
||||
s3_urls.append(await handle_file_upload(file_path, base_name))
|
||||
except Exception as e:
|
||||
print(f"Error uploading file {file_path} to S3: {e}")
|
||||
if s3_urls: output_response["output_files"] = s3_urls
|
||||
|
||||
# 5.2 然后处理WebSocket返回的非文件输出,并检查文本输出是否是文件路径
|
||||
for final_param_name, spec in api_spec["outputs"].items():
|
||||
node_id = spec["node_id"]
|
||||
if node_id in output_nodes:
|
||||
node_output = output_nodes[node_id]
|
||||
original_output_name = spec["output_name"]
|
||||
|
||||
if original_output_name in node_output:
|
||||
output_value = node_output[original_output_name]
|
||||
# 展开列表
|
||||
if isinstance(output_value, list):
|
||||
output_value = output_value[0] if output_value else None
|
||||
|
||||
# 检查文本输出是否是未被发现的文件路径
|
||||
if isinstance(output_value, str) and (
|
||||
'.png' in output_value or '.jpg' in output_value or '.mp4' in output_value or 'output' in output_value):
|
||||
potential_filename = os.path.basename(output_value.replace('\\', '/'))
|
||||
if potential_filename not in processed_files:
|
||||
# 这是一个新的文件路径,尝试在output目录中找到它
|
||||
potential_path = os.path.join(settings.COMFYUI_OUTPUT_DIR, potential_filename)
|
||||
if os.path.exists(potential_path):
|
||||
print(f"Found extra file from text output: {potential_path}")
|
||||
cleanup_paths.append(potential_path)
|
||||
processed_files.add(potential_filename)
|
||||
try:
|
||||
s3_url = await handle_file_upload(potential_path, base_name)
|
||||
# 将它也加入output_files列表
|
||||
if "output_files" not in output_response: output_response["output_files"] = []
|
||||
output_response["output_files"].append(s3_url)
|
||||
except Exception as e:
|
||||
print(f"Error uploading extra file {potential_path} to S3: {e}")
|
||||
continue # 处理完毕,跳过将其作为文本输出
|
||||
output_response[final_param_name] = output_value
|
||||
elif "text" in node_output:
|
||||
output_value = node_output["text"]
|
||||
# 如果不是文件,则作为普通值输出
|
||||
output_response[final_param_name] = output_value
|
||||
|
||||
return output_response
|
||||
|
||||
finally:
|
||||
# 清理操作保持不变
|
||||
print(f"Cleaning up {len(cleanup_paths)} temporary files...")
|
||||
for path in cleanup_paths:
|
||||
try:
|
||||
if os.path.exists(path): os.remove(path); print(f" - Deleted: {path}")
|
||||
except Exception as e:
|
||||
print(f" - Error deleting {path}: {e}")
|
||||
|
||||
|
||||
# --- Section 3: 工作流元数据/规范API (无改动) ---
|
||||
# ... (此部分代码与上一版完全相同) ...
|
||||
@app.get("/api/spec/{base_name}")
|
||||
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
|
||||
# ...
|
||||
if version:
|
||||
workflow_data = await database.get_workflow_by_version(base_name, version)
|
||||
else:
|
||||
workflow_data = await database.get_latest_workflow_by_base_name(base_name)
|
||||
if not workflow_data:
|
||||
detail = f"Workflow '{base_name}'" + (f" with version '{version}'" if version else "") + " not found."
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
try:
|
||||
workflow = json.loads(workflow_data['workflow_json'])
|
||||
return workflow_parser.parse_api_spec(workflow)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {e}")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_root(): return {"message": "Welcome to the ComfyUI Workflow Service API!"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="127.0.0.1", port=18000)
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic-settings
|
||||
websockets
|
||||
aiohttp
|
||||
boto3
|
||||
aiosqlite
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
import boto3
|
||||
from config import settings
|
||||
import asyncio
|
||||
|
||||
s3_client = boto3.client('s3', aws_access_key_id=settings.AWS_ACCESS_KEY_ID, aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, region_name=settings.AWS_REGION_NAME)
|
||||
|
||||
async def upload_file_to_s3(file_path: str, bucket: str, object_name: str) -> str:
|
||||
"""从本地文件路径上传"""
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: s3_client.upload_file(file_path, bucket, object_name))
|
||||
return f"https://cdn.roasmax.cn/{object_name}"
|
||||
|
||||
async def upload_bytes_to_s3(file_bytes: bytes, bucket: str, object_name: str) -> str:
|
||||
"""直接从内存中的bytes上传 (新函数)"""
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: s3_client.put_object(Body=file_bytes, Bucket=bucket, Key=object_name))
|
||||
return f"https://cdn.roasmax.cn/{object_name}"
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
API_INPUT_PREFIX = "INPUT_"
|
||||
API_OUTPUT_PREFIX = "OUTPUT_"
|
||||
|
||||
|
||||
def parse_api_spec(workflow_data: dict) -> dict:
|
||||
"""
|
||||
解析工作流,并根据规范 '{基础名}_{属性名}_{可选计数}' 生成API参数名。
|
||||
"""
|
||||
spec = {"inputs": {}, "outputs": {}}
|
||||
if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
|
||||
raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")
|
||||
|
||||
nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
|
||||
|
||||
# 用于处理重名的计数器
|
||||
input_name_counter = defaultdict(int)
|
||||
output_name_counter = defaultdict(int)
|
||||
|
||||
for node_id, node in nodes_map.items():
|
||||
title = node.get("title")
|
||||
if not title: continue
|
||||
|
||||
# --- 处理API输入 ---
|
||||
if title.startswith(API_INPUT_PREFIX):
|
||||
base_name = title[len(API_INPUT_PREFIX):].lower()
|
||||
if "inputs" in node:
|
||||
for a_input in node.get("inputs", []):
|
||||
# 只关心由用户直接控制的widget输入
|
||||
if a_input.get("link") is None and "widget" in a_input:
|
||||
widget_name = a_input["widget"]["name"].lower()
|
||||
|
||||
# 构建基础参数名
|
||||
param_name_candidate = f"{base_name}_{widget_name}"
|
||||
|
||||
# 处理重名
|
||||
input_name_counter[param_name_candidate] += 1
|
||||
count = input_name_counter[param_name_candidate]
|
||||
final_param_name = f"{param_name_candidate}_{count}" if count > 1 else param_name_candidate
|
||||
|
||||
# 类型推断
|
||||
input_type_str = a_input.get("type", "STRING").upper()
|
||||
if "COMBO" in a_input.get("type"):
|
||||
# 如果是加载类节点,名字为'image'或'video'的widget是文件上传
|
||||
if a_input["widget"]["name"] in ["image", "video"]:
|
||||
param_type = "UploadFile"
|
||||
else:
|
||||
param_type = "string" # 加载节点其他参数默认为string
|
||||
elif "INT" in input_type_str:
|
||||
param_type = "int"
|
||||
elif "FLOAT" in input_type_str:
|
||||
param_type = "float"
|
||||
else:
|
||||
param_type = "string"
|
||||
|
||||
spec["inputs"][final_param_name] = {
|
||||
"node_id": node_id,
|
||||
"type": param_type,
|
||||
"widget_name": a_input["widget"]["name"] # 保留原始widget名用于patch
|
||||
}
|
||||
|
||||
# --- 处理API输出 ---
|
||||
elif title.startswith(API_OUTPUT_PREFIX):
|
||||
base_name = title[len(API_OUTPUT_PREFIX):].lower()
|
||||
|
||||
if "outputs" in node:
|
||||
for an_output in node.get("outputs", []):
|
||||
output_name = an_output["name"].lower()
|
||||
|
||||
# 构建基础参数名
|
||||
param_name_candidate = f"{base_name}_{output_name}"
|
||||
|
||||
# 处理重名
|
||||
output_name_counter[param_name_candidate] += 1
|
||||
count = output_name_counter[param_name_candidate]
|
||||
final_param_name = f"{param_name_candidate}_{count}" if count > 1 else param_name_candidate
|
||||
|
||||
spec["outputs"][final_param_name] = {
|
||||
"node_id": node_id,
|
||||
"class_type": node.get("type"),
|
||||
"output_name": an_output["name"], # 保留原始输出名用于结果匹配
|
||||
"output_index": node["outputs"].index(an_output) # 保留索引
|
||||
}
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
# --- 其他函数保持不变 ---
|
||||
|
||||
def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
|
||||
# (此函数与上一版完全相同,无需改动)
|
||||
if "nodes" not in workflow_data: raise ValueError("Invalid workflow format")
|
||||
nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
|
||||
for param_name, value in request_data.items():
|
||||
if param_name in api_spec["inputs"]:
|
||||
spec = api_spec["inputs"][param_name]
|
||||
node_id = spec["node_id"]
|
||||
if node_id not in nodes_map: continue
|
||||
target_node = nodes_map[node_id]
|
||||
widget_name_to_patch = spec["widget_name"]
|
||||
widgets_values = target_node.get("widgets_values", {})
|
||||
if isinstance(widgets_values, dict):
|
||||
widgets_values[widget_name_to_patch] = value
|
||||
elif isinstance(widgets_values, list):
|
||||
widget_cursor = 0
|
||||
for input_config in target_node.get("inputs", []):
|
||||
if "widget" in input_config:
|
||||
if input_config["widget"].get("name") == widget_name_to_patch:
|
||||
if widget_cursor < len(widgets_values):
|
||||
target_type = str
|
||||
if spec['type'] == 'int':
|
||||
target_type = int
|
||||
elif spec['type'] == 'float':
|
||||
target_type = float
|
||||
widgets_values[widget_cursor] = target_type(value)
|
||||
break
|
||||
widget_cursor += 1
|
||||
target_node["widgets_values"] = widgets_values
|
||||
workflow_data["nodes"] = list(nodes_map.values())
|
||||
return workflow_data
|
||||
|
||||
|
||||
def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
|
||||
# (此函数与上一版完全相同,无需改动)
|
||||
if "nodes" not in workflow_data: raise ValueError("Invalid workflow format")
|
||||
prompt_api_format, link_map = {}, {}
|
||||
for link in workflow_data.get("links", []):
|
||||
link_map[link[0]] = [str(link[1]), link[2]]
|
||||
for node in workflow_data["nodes"]:
|
||||
node_id = str(node["id"])
|
||||
inputs_dict = {}
|
||||
widgets_values = node.get("widgets_values", [])
|
||||
if isinstance(widgets_values, dict):
|
||||
for key, val in widgets_values.items():
|
||||
if not isinstance(val, dict): inputs_dict[key] = val
|
||||
elif isinstance(widgets_values, list):
|
||||
widget_idx_counter = 0
|
||||
for input_config in node.get("inputs", []):
|
||||
if "widget" in input_config:
|
||||
if widget_idx_counter < len(widgets_values):
|
||||
inputs_dict[input_config["name"]] = widgets_values[widget_idx_counter]
|
||||
widget_idx_counter += 1
|
||||
for input_config in node.get("inputs", []):
|
||||
if "link" in input_config and input_config["link"] is not None:
|
||||
link_id = input_config["link"]
|
||||
if link_id in link_map:
|
||||
inputs_dict[input_config["name"]] = link_map[link_id]
|
||||
prompt_api_format[node_id] = {"class_type": node["type"], "inputs": inputs_dict}
|
||||
return prompt_api_format
|
||||
Loading…
Reference in New Issue