add WAAS server demo

kyj@bowong.ai 2025-07-31 16:25:52 +08:00
parent 591989e826
commit c69d3cc326
10 changed files with 631 additions and 159 deletions


@@ -5,4 +5,87 @@
- Workflow upload
- Workflow version control
- Workflow loading
- Requires a configured workflow server (see workflow_server_demo.py)
- Requires a configured workflow server (see **WAAS (Workflow-as-a-Service) Demo API Server**)
# WAAS (Workflow-as-a-Service) Demo API Server
- Path: ./workflow_service
- Nodes must be named according to these rules:
  - **Input node titles**: prefix **INPUT_**
  - **Output node titles (except file-generating nodes)**: prefix **OUTPUT_**
- Supported input nodes:
  - comfyui-core
    - Load Image
  - ComfyUI-VideoHelperSuite
    - Load Video (Upload)
  - comfyui-easy-use
    - Integer
    - String
    - Float
- Supported output nodes:
  - any node that generates files (images/videos) in the output folder
  - comfyui-easy-use
    - Show Any (`easy showAnything`)
- Database
  - Type: SQLite (workflows_service.sqlite)
  - Schema (a sketch of the name/version convention follows the block below)
```
CREATE TABLE IF NOT EXISTS workflows (
name TEXT PRIMARY KEY,
base_name TEXT NOT NULL,
version TEXT NOT NULL,
workflow_json TEXT NOT NULL
)
```
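The `database` module later in this commit derives `base_name` and `version` by splitting a `[YYYYMMDDHHMMSS]` timestamp suffix off the workflow `name`. A minimal sketch of that convention, using a hypothetical workflow name:
```python
import re

name = "my_workflow [20250731162552]"  # hypothetical versioned workflow name

version = re.search(r"\[(20\d{12})\]$", name).group(1)       # "20250731162552"
base_name = re.sub(r"\s*\[(20\d{12})\]$", "", name).strip()  # "my_workflow"
```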
- Routes
  - GET /api/workflow: list workflows
  - POST /api/workflow: add a workflow
  - DELETE /api/workflow: delete a workflow
  - GET /api/run/{base_name}: get a workflow's input/output metadata (a client sketch follows the example below)
```
Input:
  *base_name: workflow name
  version: workflow version
Output:
  JSON
{
"inputs": {
"image_image": {
"node_id": "13",
"type": "UploadFile",
"widget_name": "image"
},
"prefix_value": {
"node_id": "22",
"type": "int",
"widget_name": "value"
}
},
"outputs": {
"text_output": {
"node_id": "21",
"class_type": "easy showAnything",
"output_name": "output",
"output_index": 0
}
}
}
```
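A minimal client sketch for fetching this metadata, assuming the FastAPI service below is running locally on port 18000. Note that main.py in this commit registers the metadata route as `GET /api/spec/{base_name}`, so that path is used here; the workflow name is hypothetical:
```python
import asyncio
import aiohttp

async def fetch_spec(base_name: str, version: str | None = None) -> dict:
    # main.py serves the input/output metadata at /api/spec/{base_name}
    params = {"version": version} if version else {}
    async with aiohttp.ClientSession() as session:
        url = f"http://127.0.0.1:18000/api/spec/{base_name}"
        async with session.get(url, params=params) as resp:
            resp.raise_for_status()
            return await resp.json()

spec = asyncio.run(fetch_spec("my_workflow"))  # hypothetical workflow name
print(spec["inputs"], spec["outputs"])
```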
  - POST /api/run/{base_name}: execute a workflow (a client sketch follows the example below)
```
Input:
  *base_name: workflow name
  version: workflow version
Output:
  JSON
{
"output_files": [
"https://cdn.roasmax.cn/outputs/测试/4e91e429-c848-4f66-885c-98a83c745872_111_00001_.png"
],
"text_output": [
"output\\111_00001_.png"
]
}
```
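A matching sketch for executing a workflow: the request body is a flat JSON object keyed by the generated parameter names, and `UploadFile` inputs are passed as downloadable URLs. The workflow name and input URL below are hypothetical:
```python
import asyncio
import aiohttp

async def run_workflow(base_name: str, inputs: dict, version: str | None = None) -> dict:
    params = {"version": version} if version else {}
    async with aiohttp.ClientSession() as session:
        url = f"http://127.0.0.1:18000/api/run/{base_name}"
        async with session.post(url, json=inputs, params=params) as resp:
            resp.raise_for_status()
            return await resp.json()

# Parameter names come from the metadata above: image_image expects a
# downloadable file URL, prefix_value an int.
result = asyncio.run(run_workflow("my_workflow", {
    "image_image": "https://example.com/input.png",
    "prefix_value": 42,
}))
print(result.get("output_files"), result.get("text_output"))
```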


@@ -1,158 +0,0 @@
import sqlite3
import json
from http.server import HTTPServer, BaseHTTPRequestHandler
import urllib.parse

# --- Configuration ---
DATABASE_FILE = "workflows.sqlite"
PORT = 8000

# --- Database initialization ---
def init_db():
    """Initialize the database, creating the table if it does not exist."""
    with sqlite3.connect(DATABASE_FILE) as conn:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS workflows (
                name TEXT PRIMARY KEY,
                workflow_json TEXT NOT NULL
            )
        """)
        conn.commit()
    print(f"Database '{DATABASE_FILE}' is ready.")

# --- API handler ---
class PersistentAPIHandler(BaseHTTPRequestHandler):
    def _send_cors_headers(self):
        """Send CORS headers."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, DELETE, OPTIONS')
        self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type")

    def do_OPTIONS(self):
        """Handle CORS preflight requests."""
        self.send_response(200, "ok")
        self._send_cors_headers()
        self.end_headers()

    def do_GET(self):
        """Handle requests to fetch all workflows."""
        if self.path == '/api/workflow':
            try:
                with sqlite3.connect(DATABASE_FILE) as conn:
                    conn.row_factory = sqlite3.Row  # allow dict-like access to rows
                    cursor = conn.cursor()
                    cursor.execute("SELECT name, workflow_json FROM workflows")
                    rows = cursor.fetchall()
                # Convert database rows into the JSON shape the frontend expects
                workflows_list = [
                    {"name": row["name"], "workflow": json.loads(row["workflow_json"])}
                    for row in rows
                ]
                self.send_response(200)
                self.send_header('Content-type', 'application/json')
                self._send_cors_headers()
                self.end_headers()
                self.wfile.write(json.dumps(workflows_list).encode())
            except Exception as e:
                self._send_error(f"Error fetching workflows: {e}")
        else:
            self._send_error("Path not found", 404)

    def do_POST(self):
        """Handle publish/update workflow requests."""
        if self.path == '/api/workflow':
            try:
                content_length = int(self.headers['Content-Length'])
                post_data = self.rfile.read(content_length)
                data = json.loads(post_data)
                name = data.get('name')
                workflow_data = data.get('workflow')
                if not name or not workflow_data:
                    return self._send_error("Missing 'name' or 'workflow' in request body", 400)
                workflow_json_str = json.dumps(workflow_data)
                with sqlite3.connect(DATABASE_FILE) as conn:
                    cursor = conn.cursor()
                    # INSERT OR REPLACE conveniently handles both create and update
                    cursor.execute(
                        "INSERT OR REPLACE INTO workflows (name, workflow_json) VALUES (?, ?)",
                        (name, workflow_json_str)
                    )
                    conn.commit()
                print(f"--- Saved/updated workflow: {name} ---")
                self.send_response(200)
                self.send_header('Content-type', 'application/json')
                self._send_cors_headers()
                self.end_headers()
                self.wfile.write(json.dumps({"status": "success", "name": name}).encode())
            except Exception as e:
                self._send_error(f"Error saving workflow: {e}")
        else:
            self._send_error("Path not found", 404)

    def do_DELETE(self):
        """Handle deleting a specific workflow."""
        parts = self.path.split('/api/workflow/')
        if len(parts) == 2 and parts[1]:
            workflow_name_to_delete = urllib.parse.unquote(parts[1])
            try:
                with sqlite3.connect(DATABASE_FILE) as conn:
                    cursor = conn.cursor()
                    cursor.execute("DELETE FROM workflows WHERE name = ?", (workflow_name_to_delete,))
                    conn.commit()
                    # Check whether a row was actually deleted
                    if cursor.rowcount > 0:
                        print(f"--- Deleted workflow: {workflow_name_to_delete} ---")
                        self.send_response(200)
                        self.send_header('Content-type', 'application/json')
                        self._send_cors_headers()
                        self.end_headers()
                        self.wfile.write(json.dumps({"status": "deleted", "name": workflow_name_to_delete}).encode())
                    else:
                        self._send_error("Workflow not found, nothing to delete", 404)
            except Exception as e:
                self._send_error(f"Error deleting workflow: {e}")
        else:
            self._send_error("Invalid delete path", 400)

    def _send_error(self, message, code=500):
        self.send_response(code)
        self.send_header('Content-type', 'application/json')
        self._send_cors_headers()
        self.end_headers()
        self.wfile.write(json.dumps({"status": "error", "message": message}).encode())

# --- Main entry point ---
def run(server_class=HTTPServer, handler_class=PersistentAPIHandler, port=PORT):
    # 1. Initialize the database
    init_db()
    # 2. Start the server
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    print(f"\nPersistent API server started at http://localhost:{port} ...")
    print(f"Database file: ./{DATABASE_FILE}")
    print(" - POST   /api/workflow        -> publish/update a workflow")
    print(" - GET    /api/workflow        -> fetch all workflows")
    print(" - DELETE /api/workflow/<name> -> delete a workflow")
    print("\nPress Ctrl+C to stop the server.")
    httpd.serve_forever()

if __name__ == '__main__':
    run()

workflow_service/.env Normal file

@@ -0,0 +1,13 @@
# ComfyUI server addresses
COMFYUI_URL="ws://127.0.0.1:8188/ws"
COMFYUI_HTTP_URL="http://127.0.0.1:8188"
# Absolute path! e.g.: /home/user/ComfyUI/input
COMFYUI_INPUT_DIR="F:\ComfyUI\input"
# Absolute path! e.g.: /home/user/ComfyUI/output
COMFYUI_OUTPUT_DIR="F:\ComfyUI\output"
# AWS S3 configuration
S3_BUCKET_NAME="modal-media-cache"
AWS_ACCESS_KEY_ID="AKIAYRH5NGRSWHN2L4M6"
AWS_SECRET_ACCESS_KEY="kfAqoOmIiyiywi25xaAkJUQbZ/EKDnzvI6NRCW1l"
AWS_REGION_NAME="ap-northeast-2"


@@ -0,0 +1,74 @@
import websockets
import json
import uuid
import aiohttp
import random

from config import settings

async def queue_prompt(prompt: dict, client_id: str) -> str:
    """Submit a workflow job to the ComfyUI queue via HTTP POST."""
    # [Cache busting] Add a random input to every node to force re-execution
    for node_id in prompt:
        prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random()
    payload = {"prompt": prompt, "client_id": client_id}
    async with aiohttp.ClientSession() as session:
        http_url = f"{settings.COMFYUI_HTTP_URL}/prompt"
        async with session.post(http_url, json=payload) as response:
            response.raise_for_status()
            result = await response.json()
            if "prompt_id" not in result:
                raise Exception(f"Invalid response from ComfyUI /prompt endpoint: {result}")
            return result["prompt_id"]

async def get_execution_results(prompt_id: str, client_id: str) -> dict:
    """
    Aggregate the outputs of all 'executed' events over a WebSocket connection
    until the whole execution has finished.
    """
    ws_url = f"{settings.COMFYUI_URL}?clientId={client_id}"
    aggregated_outputs = {}
    async with websockets.connect(ws_url) as websocket:
        while True:
            try:
                out = await websocket.recv()
                if isinstance(out, str):
                    message = json.loads(out)
                    msg_type = message.get('type')
                    data = message.get('data')
                    # We only care about events for our prompt_id
                    if data and data.get('prompt_id') == prompt_id:
                        if msg_type == 'executed':
                            # Aggregate each node's output
                            node_id = data.get('node')
                            output_data = data.get('output')
                            if node_id and output_data:
                                aggregated_outputs[node_id] = output_data
                                print(f"Output received for node {node_id} (Prompt ID: {prompt_id})")
                        # Detect the end of execution: the official UI treats
                        # "executing" with node == null as the finish signal
                        elif msg_type == 'executing' and data.get('node') is None:
                            print(f"Execution finished for Prompt ID: {prompt_id}")
                            return aggregated_outputs
            except websockets.exceptions.ConnectionClosed as e:
                print(f"WebSocket connection closed for {prompt_id}. Returning aggregated results. Error: {e}")
                return aggregated_outputs  # treat a closed connection as the end as well
            except Exception as e:
                print(f"An error occurred for {prompt_id}: {e}")
                break
    return aggregated_outputs

async def run_workflow(prompt: dict) -> dict:
    """Main coordinator: queue the job, then wait for the results."""
    client_id = str(uuid.uuid4())
    prompt_id = await queue_prompt(prompt, client_id)
    print(f"Workflow successfully queued with Prompt ID: {prompt_id}")
    results = await get_execution_results(prompt_id, client_id)
    return results
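A minimal usage sketch for this client, assuming ComfyUI is reachable at the configured address and the prompt is already in ComfyUI's API format (the two nodes below are hypothetical):
```python
import asyncio
import comfyui_client

# Hypothetical two-node prompt in ComfyUI API (prompt) format:
# node "2" consumes node "1"'s output 0.
prompt = {
    "1": {"class_type": "LoadImage", "inputs": {"image": "input.png"}},
    "2": {"class_type": "PreviewImage", "inputs": {"images": ["1", 0]}},
}

outputs = asyncio.run(comfyui_client.run_workflow(prompt))
print(outputs)  # {node_id: output_data, ...} aggregated from 'executed' events
```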


@@ -0,0 +1,14 @@
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8')

    COMFYUI_URL: str = "ws://127.0.0.1:8188/ws"
    COMFYUI_HTTP_URL: str = "http://127.0.0.1:8188"
    COMFYUI_INPUT_DIR: str
    COMFYUI_OUTPUT_DIR: str
    S3_BUCKET_NAME: str
    AWS_ACCESS_KEY_ID: str
    AWS_SECRET_ACCESS_KEY: str
    AWS_REGION_NAME: str

settings = Settings()


@@ -0,0 +1,59 @@
import aiosqlite
import json
import re

DATABASE_FILE = "workflows_service.sqlite"

async def init_db():
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute("""
            CREATE TABLE IF NOT EXISTS workflows (
                name TEXT PRIMARY KEY,
                base_name TEXT NOT NULL,
                version TEXT NOT NULL,
                workflow_json TEXT NOT NULL
            )
        """)
        await db.commit()
    print(f"Database '{DATABASE_FILE}' is ready.")

async def save_workflow(name: str, workflow_json: str):
    version_match = re.search(r"\[(20\d{12})\]$", name)
    if not version_match:
        raise ValueError("Workflow name must have a version suffix like [YYYYMMDDHHMMSS]")
    version = version_match.group(1)
    base_name = re.sub(r"\s*\[(20\d{12})\]$", "", name).strip()
    async with aiosqlite.connect(DATABASE_FILE) as db:
        await db.execute(
            "INSERT OR REPLACE INTO workflows (name, base_name, version, workflow_json) VALUES (?, ?, ?, ?)",
            (name, base_name, version, workflow_json)
        )
        await db.commit()

async def get_all_workflows() -> list[dict]:
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT name, workflow_json FROM workflows")
        rows = await cursor.fetchall()
        return [{"name": row["name"], "workflow": json.loads(row["workflow_json"])} for row in rows]

async def get_latest_workflow_by_base_name(base_name: str) -> dict | None:
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute(
            "SELECT * FROM workflows WHERE base_name = ? ORDER BY version DESC LIMIT 1",
            (base_name,)
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

async def get_workflow_by_version(base_name: str, version: str) -> dict | None:
    name = f"{base_name} [{version}]"
    async with aiosqlite.connect(DATABASE_FILE) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM workflows WHERE name = ?", (name,))
        row = await cursor.fetchone()
        return dict(row) if row else None

async def delete_workflow(name: str) -> bool:
    async with aiosqlite.connect(DATABASE_FILE) as db:
        cursor = await db.execute("DELETE FROM workflows WHERE name = ?", (name,))
        await db.commit()
        return cursor.rowcount > 0
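A quick usage sketch for this module (the workflow name and JSON are hypothetical):
```python
import asyncio
import json
import database

async def demo():
    await database.init_db()
    # The [YYYYMMDDHHMMSS] suffix is mandatory; save_workflow() splits it into
    # base_name "my_workflow" and version "20250731162552".
    await database.save_workflow("my_workflow [20250731162552]",
                                 json.dumps({"nodes": [], "links": []}))
    latest = await database.get_latest_workflow_by_base_name("my_workflow")
    print(latest["name"], latest["version"])

asyncio.run(demo())
```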

workflow_service/main.py Normal file

@@ -0,0 +1,212 @@
import uvicorn
from fastapi import FastAPI, Request, HTTPException, Path
from fastapi.responses import JSONResponse
from typing import Optional, List, Dict, Any, Set
import json, uuid, os, shutil, aiohttp, database, workflow_parser, comfyui_client, s3_client
from config import settings

app = FastAPI(title="ComfyUI Workflow Service & Management API")

@app.on_event("startup")
async def startup_event():
    await database.init_db()
    os.makedirs(settings.COMFYUI_INPUT_DIR, exist_ok=True)
    os.makedirs(settings.COMFYUI_OUTPUT_DIR, exist_ok=True)

# --- Section 1: Workflow management API (unchanged) ---
# ... (identical to the previous version) ...
BASE_MANAGEMENT_PATH = "/api/workflow"

@app.post(BASE_MANAGEMENT_PATH, status_code=200)
async def publish_workflow_endpoint(request: Request):
    try:
        data = await request.json()
        name, wf_json = data.get("name"), data.get("workflow")
        await database.save_workflow(name, json.dumps(wf_json))
        return JSONResponse(content={"status": "success", "message": f"Workflow '{name}' published."}, status_code=200)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}")

@app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
async def get_all_workflows_endpoint():
    try:
        return await database.get_all_workflows()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {e}")

@app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
async def delete_workflow_endpoint(workflow_name: str = Path(..., title="...")):
    try:
        success = await database.delete_workflow(workflow_name)
        if success:
            return {"status": "deleted", "name": workflow_name}
        else:
            raise HTTPException(status_code=404, detail="Workflow not found")
    except HTTPException:
        raise  # let the 404 pass through instead of rewrapping it as a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}")

# --- Section 2: Workflow execution API (core changes) ---
def get_files_in_dir(directory: str) -> Set[str]:
    file_set = set()
    for root, _, files in os.walk(directory):
        for filename in files:
            if not filename.startswith('.'):
                file_set.add(os.path.join(root, filename))
    return file_set

async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
    async with session.get(url) as response:
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            while True:
                chunk = await response.content.read(8192)
                if not chunk:
                    break
                f.write(chunk)

async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Helper: upload a file to S3 and return its URL."""
    s3_object_name = f"outputs/{base_name}/{uuid.uuid4()}_{os.path.basename(file_path)}"
    return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, s3_object_name)

@app.post("/api/run/{base_name}")
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
    cleanup_paths = []
    try:
        # 1. Fetch the workflow and process the inputs (same as the previous version)
        # ...
        if version:
            workflow_data = await database.get_workflow_by_version(base_name, version)
        else:
            workflow_data = await database.get_latest_workflow_by_base_name(base_name)
        if not workflow_data:
            raise HTTPException(status_code=404, detail=f"Workflow '{base_name}' not found.")
        workflow = json.loads(workflow_data['workflow_json'])
        api_spec = workflow_parser.parse_api_spec(workflow)
        request_data = {k.lower(): v for k, v in request_data_raw.items()}
        async with aiohttp.ClientSession() as session:
            for param_name, spec in api_spec["inputs"].items():
                if spec["type"] == "UploadFile" and param_name in request_data:
                    image_url = request_data[param_name]
                    if not isinstance(image_url, str) or not image_url.startswith('http'):
                        raise HTTPException(status_code=400, detail=f"Parameter '{param_name}' must be a valid URL.")
                    original_filename = image_url.split('/')[-1].split('?')[0]
                    _, file_extension = os.path.splitext(original_filename)
                    if not file_extension:
                        file_extension = '.dat'
                    filename = f"api_download_{uuid.uuid4()}{file_extension}"
                    save_path = os.path.join(settings.COMFYUI_INPUT_DIR, filename)
                    try:
                        await download_file_from_url(session, image_url, save_path)
                        request_data[param_name] = filename
                        cleanup_paths.append(save_path)
                    except Exception as e:
                        raise HTTPException(status_code=500,
                                            detail=f"Failed to download file for '{param_name}' from {image_url}. Error: {e}")

        # 2. Snapshot the output directory before execution
        files_before = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)

        # 3. Patch, convert and execute (cache busting has moved into the client)
        patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
        prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)
        output_nodes = await comfyui_client.run_workflow(prompt_to_run)

        # 4. Snapshot after execution and compute the difference
        files_after = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)
        new_files = files_after - files_before

        # 5. [Core fix] Handle all outputs uniformly
        output_response = {}
        processed_files = set()  # files already handled via the filesystem snapshot

        # 5.1 First handle all newly generated files
        if new_files:
            s3_urls = []
            for file_path in new_files:
                cleanup_paths.append(file_path)
                processed_files.add(os.path.basename(file_path))
                try:
                    s3_urls.append(await handle_file_upload(file_path, base_name))
                except Exception as e:
                    print(f"Error uploading file {file_path} to S3: {e}")
            if s3_urls:
                output_response["output_files"] = s3_urls

        # 5.2 Then handle the non-file outputs returned over WebSocket,
        # checking whether a text output is actually a file path
        for final_param_name, spec in api_spec["outputs"].items():
            node_id = spec["node_id"]
            if node_id in output_nodes:
                node_output = output_nodes[node_id]
                original_output_name = spec["output_name"]
                if original_output_name in node_output:
                    output_value = node_output[original_output_name]
                    # Unwrap lists
                    if isinstance(output_value, list):
                        output_value = output_value[0] if output_value else None
                    # Check whether a text output is a file path we have not discovered yet
                    if isinstance(output_value, str) and (
                            '.png' in output_value or '.jpg' in output_value or '.mp4' in output_value or 'output' in output_value):
                        potential_filename = os.path.basename(output_value.replace('\\', '/'))
                        if potential_filename not in processed_files:
                            # A new file path; try to find it in the output directory
                            potential_path = os.path.join(settings.COMFYUI_OUTPUT_DIR, potential_filename)
                            if os.path.exists(potential_path):
                                print(f"Found extra file from text output: {potential_path}")
                                cleanup_paths.append(potential_path)
                                processed_files.add(potential_filename)
                                try:
                                    s3_url = await handle_file_upload(potential_path, base_name)
                                    # Add it to the output_files list as well
                                    if "output_files" not in output_response:
                                        output_response["output_files"] = []
                                    output_response["output_files"].append(s3_url)
                                except Exception as e:
                                    print(f"Error uploading extra file {potential_path} to S3: {e}")
                                continue  # handled; skip emitting it as a text output
                    output_response[final_param_name] = output_value
                elif "text" in node_output:
                    output_value = node_output["text"]
                    # Not a file, so emit it as a plain value
                    output_response[final_param_name] = output_value
        return output_response
    finally:
        # Cleanup is unchanged
        print(f"Cleaning up {len(cleanup_paths)} temporary files...")
        for path in cleanup_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
                    print(f" - Deleted: {path}")
            except Exception as e:
                print(f" - Error deleting {path}: {e}")

# --- Section 3: Workflow metadata/spec API (unchanged) ---
# ... (this section is identical to the previous version) ...
@app.get("/api/spec/{base_name}")
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
    # ...
    if version:
        workflow_data = await database.get_workflow_by_version(base_name, version)
    else:
        workflow_data = await database.get_latest_workflow_by_base_name(base_name)
    if not workflow_data:
        detail = f"Workflow '{base_name}'" + (f" with version '{version}'" if version else "") + " not found."
        raise HTTPException(status_code=404, detail=detail)
    try:
        workflow = json.loads(workflow_data['workflow_json'])
        return workflow_parser.parse_api_spec(workflow)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {e}")

@app.get("/")
def read_root():
    return {"message": "Welcome to the ComfyUI Workflow Service API!"}

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=18000)


@@ -0,0 +1,7 @@
fastapi
uvicorn[standard]
pydantic-settings
websockets
aiohttp
boto3
aiosqlite


@@ -0,0 +1,17 @@
import boto3
from config import settings
import asyncio

s3_client = boto3.client(
    's3',
    aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
    aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
    region_name=settings.AWS_REGION_NAME,
)

async def upload_file_to_s3(file_path: str, bucket: str, object_name: str) -> str:
    """Upload from a local file path."""
    loop = asyncio.get_event_loop()
    await loop.run_in_executor(None, lambda: s3_client.upload_file(file_path, bucket, object_name))
    return f"https://cdn.roasmax.cn/{object_name}"

async def upload_bytes_to_s3(file_bytes: bytes, bucket: str, object_name: str) -> str:
    """Upload directly from in-memory bytes (new function)."""
    loop = asyncio.get_event_loop()
    await loop.run_in_executor(None, lambda: s3_client.put_object(Body=file_bytes, Bucket=bucket, Key=object_name))
    return f"https://cdn.roasmax.cn/{object_name}"


@@ -0,0 +1,151 @@
import re
from collections import defaultdict

API_INPUT_PREFIX = "INPUT_"
API_OUTPUT_PREFIX = "OUTPUT_"

def parse_api_spec(workflow_data: dict) -> dict:
    """
    Parse a workflow and generate API parameter names following the
    convention '{base_name}_{attribute_name}_{optional_count}'.
    For example, a node titled "INPUT_prefix" with a widget named "value"
    yields the parameter "prefix_value"; a duplicate would become
    "prefix_value_2".
    """
    spec = {"inputs": {}, "outputs": {}}
    if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list):
        raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.")
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    # Counters used to disambiguate duplicate names
    input_name_counter = defaultdict(int)
    output_name_counter = defaultdict(int)
    for node_id, node in nodes_map.items():
        title = node.get("title")
        if not title:
            continue
        # --- Handle API inputs ---
        if title.startswith(API_INPUT_PREFIX):
            base_name = title[len(API_INPUT_PREFIX):].lower()
            if "inputs" in node:
                for a_input in node.get("inputs", []):
                    # Only widget inputs controlled directly by the user matter here
                    if a_input.get("link") is None and "widget" in a_input:
                        widget_name = a_input["widget"]["name"].lower()
                        # Build the base parameter name
                        param_name_candidate = f"{base_name}_{widget_name}"
                        # Disambiguate duplicates
                        input_name_counter[param_name_candidate] += 1
                        count = input_name_counter[param_name_candidate]
                        final_param_name = f"{param_name_candidate}_{count}" if count > 1 else param_name_candidate
                        # Type inference
                        input_type_str = a_input.get("type", "STRING").upper()
                        if "COMBO" in input_type_str:
                            # On loader nodes, a widget named 'image' or 'video' is a file upload
                            if a_input["widget"]["name"] in ["image", "video"]:
                                param_type = "UploadFile"
                            else:
                                param_type = "string"  # other loader parameters default to string
                        elif "INT" in input_type_str:
                            param_type = "int"
                        elif "FLOAT" in input_type_str:
                            param_type = "float"
                        else:
                            param_type = "string"
                        spec["inputs"][final_param_name] = {
                            "node_id": node_id,
                            "type": param_type,
                            "widget_name": a_input["widget"]["name"]  # keep the original widget name for patching
                        }
        # --- Handle API outputs ---
        elif title.startswith(API_OUTPUT_PREFIX):
            base_name = title[len(API_OUTPUT_PREFIX):].lower()
            if "outputs" in node:
                for an_output in node.get("outputs", []):
                    output_name = an_output["name"].lower()
                    # Build the base parameter name
                    param_name_candidate = f"{base_name}_{output_name}"
                    # Disambiguate duplicates
                    output_name_counter[param_name_candidate] += 1
                    count = output_name_counter[param_name_candidate]
                    final_param_name = f"{param_name_candidate}_{count}" if count > 1 else param_name_candidate
                    spec["outputs"][final_param_name] = {
                        "node_id": node_id,
                        "class_type": node.get("type"),
                        "output_name": an_output["name"],  # keep the original output name for result matching
                        "output_index": node["outputs"].index(an_output)  # keep the index
                    }
    return spec

# --- The remaining functions are unchanged ---
def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict:
    # (identical to the previous version, no changes needed)
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")
    nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]}
    for param_name, value in request_data.items():
        if param_name in api_spec["inputs"]:
            spec = api_spec["inputs"][param_name]
            node_id = spec["node_id"]
            if node_id not in nodes_map:
                continue
            target_node = nodes_map[node_id]
            widget_name_to_patch = spec["widget_name"]
            widgets_values = target_node.get("widgets_values", {})
            if isinstance(widgets_values, dict):
                widgets_values[widget_name_to_patch] = value
            elif isinstance(widgets_values, list):
                widget_cursor = 0
                for input_config in target_node.get("inputs", []):
                    if "widget" in input_config:
                        if input_config["widget"].get("name") == widget_name_to_patch:
                            if widget_cursor < len(widgets_values):
                                target_type = str
                                if spec['type'] == 'int':
                                    target_type = int
                                elif spec['type'] == 'float':
                                    target_type = float
                                widgets_values[widget_cursor] = target_type(value)
                            break
                        widget_cursor += 1
            target_node["widgets_values"] = widgets_values
    workflow_data["nodes"] = list(nodes_map.values())
    return workflow_data

def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict:
    # (identical to the previous version, no changes needed)
    if "nodes" not in workflow_data:
        raise ValueError("Invalid workflow format")
    prompt_api_format, link_map = {}, {}
    for link in workflow_data.get("links", []):
        link_map[link[0]] = [str(link[1]), link[2]]
    for node in workflow_data["nodes"]:
        node_id = str(node["id"])
        inputs_dict = {}
        widgets_values = node.get("widgets_values", [])
        if isinstance(widgets_values, dict):
            for key, val in widgets_values.items():
                if not isinstance(val, dict):
                    inputs_dict[key] = val
        elif isinstance(widgets_values, list):
            widget_idx_counter = 0
            for input_config in node.get("inputs", []):
                if "widget" in input_config:
                    if widget_idx_counter < len(widgets_values):
                        inputs_dict[input_config["name"]] = widgets_values[widget_idx_counter]
                    widget_idx_counter += 1
        for input_config in node.get("inputs", []):
            if "link" in input_config and input_config["link"] is not None:
                link_id = input_config["link"]
                if link_id in link_map:
                    inputs_dict[input_config["name"]] = link_map[link_id]
        prompt_api_format[node_id] = {"class_type": node["type"], "inputs": inputs_dict}
    return prompt_api_format
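For reference, a small sketch of what `convert_workflow_to_prompt_api_format` produces for a hypothetical two-node workflow in the UI's nodes/links format:
```python
import workflow_parser

# Hypothetical minimal workflow: node 2's "anything" input is wired to
# node 1's output 0 via link 5; a link entry is
# [link_id, source_node_id, source_output_index, target_node_id, target_input_index, type].
workflow = {
    "nodes": [
        {"id": 1, "type": "easy int",
         "inputs": [{"name": "value", "widget": {"name": "value"}}],
         "outputs": [{"name": "int"}],
         "widgets_values": [42]},
        {"id": 2, "type": "easy showAnything",
         "inputs": [{"name": "anything", "link": 5}],
         "outputs": [],
         "widgets_values": []},
    ],
    "links": [[5, 1, 0, 2, 0, "INT"]],
}

prompt = workflow_parser.convert_workflow_to_prompt_api_format(workflow)
# -> {'1': {'class_type': 'easy int', 'inputs': {'value': 42}},
#     '2': {'class_type': 'easy showAnything', 'inputs': {'anything': ['1', 0]}}}
print(prompt)
```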