From c69d3cc32616b34e9193218c3ef33c670ff44042 Mon Sep 17 00:00:00 2001 From: "kyj@bowong.ai" Date: Thu, 31 Jul 2025 16:25:52 +0800 Subject: [PATCH] =?UTF-8?q?add=20=E5=A2=9E=E5=8A=A0waas=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E5=99=A8demo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 85 ++++++++++- workflow_server_demo.py | 158 --------------------- workflow_service/.env | 13 ++ workflow_service/comfyui_client.py | 74 ++++++++++ workflow_service/config.py | 14 ++ workflow_service/database.py | 59 ++++++++ workflow_service/main.py | 212 ++++++++++++++++++++++++++++ workflow_service/requirements.txt | 7 + workflow_service/s3_client.py | 17 +++ workflow_service/workflow_parser.py | 151 ++++++++++++++++++++ 10 files changed, 631 insertions(+), 159 deletions(-) delete mode 100644 workflow_server_demo.py create mode 100644 workflow_service/.env create mode 100644 workflow_service/comfyui_client.py create mode 100644 workflow_service/config.py create mode 100644 workflow_service/database.py create mode 100644 workflow_service/main.py create mode 100644 workflow_service/requirements.txt create mode 100644 workflow_service/s3_client.py create mode 100644 workflow_service/workflow_parser.py diff --git a/README.md b/README.md index 059b2e1..08ac422 100644 --- a/README.md +++ b/README.md @@ -5,4 +5,87 @@ - 工作流上传 - 工作流版本控制 - 工作流加载 -- 需配置工作流服务器(详见workflow_server_demo.py文件) \ No newline at end of file +- 需配置工作流服务器 (详见**WAAS(工作流即服务) Demo API服务器**) + + +# WAAS(工作流即服务) Demo API服务器 +- 路径: ./workflow_service +- 必须按照指定规则命名 + - **输入节点名**: 前缀 **INPUT_** + - **除生成文件节点外输出节点名**: 前缀 **OUTPUT_** +- 支持输入节点: + - comfyui-core + - 加载图像 + - ComfyUI-VideoHelperSuite + - Load Video(Upload) + - comfyui-easy-use + - 整数 + - 字符串 + - 浮点数 +- 支持输出节点: + - 所有在output文件夹中生成文件(图片/视频)的节点 + - comfyui-easy-use + - 展示任何 +- 数据库 + - 类型: Sqlite (workflows_service.sqlite) +- 数据库结构 + ``` + CREATE TABLE IF NOT EXISTS workflows ( + name TEXT PRIMARY KEY, + base_name TEXT NOT NULL, + version TEXT NOT NULL, + workflow_json TEXT NOT NULL + ) + ``` +- 路由 + - GET /api/workflow: 列出工作流 + - POST /api/workflow: 添加工作流 + - DELETE /api/workflow: 删除工作流 + - GET /api/run/{base_name}: 获取工作流输入输出元数据 + ``` + 输入: + *base_name: 工作流名称 + version: 工作流版本 + + 输出: + Json + { + "inputs": { + "image_image": { + "node_id": "13", + "type": "UploadFile", + "widget_name": "image" + }, + "prefix_value": { + "node_id": "22", + "type": "int", + "widget_name": "value" + } + }, + "outputs": { + "text_output": { + "node_id": "21", + "class_type": "easy showAnything", + "output_name": "output", + "output_index": 0 + } + } + } + ``` + - POST /api/run/{base_name}: 执行工作流 + ``` + 输入: + *base_name: 工作流名称 + version: 工作流版本 + + 输出: + Json + { + "output_files": [ + "https://cdn.roasmax.cn/outputs/测试/4e91e429-c848-4f66-885c-98a83c745872_111_00001_.png" + ], + "text_output": [ + "output\\111_00001_.png" + ] + } + ``` \ No newline at end of file diff --git a/workflow_server_demo.py b/workflow_server_demo.py deleted file mode 100644 index 0ffda3e..0000000 --- a/workflow_server_demo.py +++ /dev/null @@ -1,158 +0,0 @@ -import sqlite3 -import json -from http.server import HTTPServer, BaseHTTPRequestHandler -import urllib.parse - -# --- 配置 --- -DATABASE_FILE = "workflows.sqlite" -PORT = 8000 - - -# --- 数据库初始化 --- -def init_db(): - """初始化数据库,如果表不存在则创建它""" - with sqlite3.connect(DATABASE_FILE) as conn: - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS workflows ( - name TEXT PRIMARY KEY, - workflow_json TEXT NOT NULL - ) - """) - conn.commit() - print(f"数据库 '{DATABASE_FILE}' 已准备就绪。") - - -# --- API 处理器 --- -class PersistentAPIHandler(BaseHTTPRequestHandler): - - def _send_cors_headers(self): - """发送CORS头部""" - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', 'GET, POST, DELETE, OPTIONS') - self.send_header("Access-Control-Allow-Headers", "X-Requested-With, Content-Type") - - def do_OPTIONS(self): - """处理预检请求""" - self.send_response(200, "ok") - self._send_cors_headers() - self.end_headers() - - def do_GET(self): - """处理获取所有工作流的请求""" - if self.path == '/api/workflow': - try: - with sqlite3.connect(DATABASE_FILE) as conn: - conn.row_factory = sqlite3.Row # 让查询结果可以像字典一样访问 - cursor = conn.cursor() - cursor.execute("SELECT name, workflow_json FROM workflows") - rows = cursor.fetchall() - - # 将数据库行转换为前端期望的JSON格式 - workflows_list = [ - {"name": row["name"], "workflow": json.loads(row["workflow_json"])} - for row in rows - ] - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self._send_cors_headers() - self.end_headers() - self.wfile.write(json.dumps(workflows_list).encode()) - - except Exception as e: - self._send_error(f"获取工作流时出错: {e}") - else: - self._send_error("路径未找到", 404) - - def do_POST(self): - """处理发布/更新工作流的请求""" - if self.path == '/api/workflow': - try: - content_length = int(self.headers['Content-Length']) - post_data = self.rfile.read(content_length) - data = json.loads(post_data) - - name = data.get('name') - workflow_data = data.get('workflow') - - if not name or not workflow_data: - return self._send_error("请求体中缺少 'name' 或 'workflow'", 400) - - workflow_json_str = json.dumps(workflow_data) - - with sqlite3.connect(DATABASE_FILE) as conn: - cursor = conn.cursor() - # 使用 INSERT OR REPLACE 来处理创建和更新,非常方便 - cursor.execute( - "INSERT OR REPLACE INTO workflows (name, workflow_json) VALUES (?, ?)", - (name, workflow_json_str) - ) - conn.commit() - - print(f"--- 已保存/更新工作流: {name} ---") - self.send_response(200) - self.send_header('Content-type', 'application/json') - self._send_cors_headers() - self.end_headers() - self.wfile.write(json.dumps({"status": "success", "name": name}).encode()) - - except Exception as e: - self._send_error(f"保存工作流时出错: {e}") - else: - self._send_error("路径未找到", 404) - - def do_DELETE(self): - """处理删除指定工作流的请求""" - parts = self.path.split('/api/workflow/') - if len(parts) == 2 and parts[1]: - workflow_name_to_delete = urllib.parse.unquote(parts[1]) - try: - with sqlite3.connect(DATABASE_FILE) as conn: - cursor = conn.cursor() - cursor.execute("DELETE FROM workflows WHERE name = ?", (workflow_name_to_delete,)) - conn.commit() - - # 检查是否真的有行被删除了 - if cursor.rowcount > 0: - print(f"--- 已删除工作流: {workflow_name_to_delete} ---") - self.send_response(200) - self.send_header('Content-type', 'application/json') - self._send_cors_headers() - self.end_headers() - self.wfile.write(json.dumps({"status": "deleted", "name": workflow_name_to_delete}).encode()) - else: - self._send_error("工作流未找到,无法删除", 404) - - except Exception as e: - self._send_error(f"删除工作流时出错: {e}") - else: - self._send_error("无效的删除路径", 400) - - def _send_error(self, message, code=500): - self.send_response(code) - self.send_header('Content-type', 'application/json') - self._send_cors_headers() - self.end_headers() - self.wfile.write(json.dumps({"status": "error", "message": message}).encode()) - - -# --- 主执行函数 --- -def run(server_class=HTTPServer, handler_class=PersistentAPIHandler, port=PORT): - # 1. 初始化数据库 - init_db() - - # 2. 启动服务器 - server_address = ('', port) - httpd = server_class(server_address, handler_class) - print(f"\n持久化API服务器已在 http://localhost:{port} 上启动...") - print(f"数据库文件: ./{DATABASE_FILE}") - print(" - POST /api/workflow -> 发布/更新工作流") - print(" - GET /api/workflow -> 获取所有工作流") - print(" - DELETE /api/workflow/ -> 删除指定工作流") - print("\n按 Ctrl+C 停止服务器。") - httpd.serve_forever() - - -if __name__ == '__main__': - run() \ No newline at end of file diff --git a/workflow_service/.env b/workflow_service/.env new file mode 100644 index 0000000..4c200a2 --- /dev/null +++ b/workflow_service/.env @@ -0,0 +1,13 @@ +# ComfyUI服务器地址 +COMFYUI_URL="ws://127.0.0.1:8188/ws" +COMFYUI_HTTP_URL="http://127.0.0.1:8188" +# 绝对路径!例如: /home/user/ComfyUI/input +COMFYUI_INPUT_DIR="F:\ComfyUI\input" +# 绝对路径!例如: /home/user/ComfyUI/output +COMFYUI_OUTPUT_DIR="F:\ComfyUI\output" + +# AWS S3 配置 +S3_BUCKET_NAME="modal-media-cache" +AWS_ACCESS_KEY_ID="AKIAYRH5NGRSWHN2L4M6" +AWS_SECRET_ACCESS_KEY="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" +AWS_REGION_NAME="ap-northeast-2" \ No newline at end of file diff --git a/workflow_service/comfyui_client.py b/workflow_service/comfyui_client.py new file mode 100644 index 0000000..956a166 --- /dev/null +++ b/workflow_service/comfyui_client.py @@ -0,0 +1,74 @@ +import websockets +import json +import uuid +import aiohttp +import random +from config import settings + + +async def queue_prompt(prompt: dict, client_id: str) -> str: + """通过HTTP POST将工作流任务提交到ComfyUI队列。""" + # [缓存破解] 在每个节点中添加一个随机数输入,强制重新执行 + for node_id in prompt: + prompt[node_id]["inputs"][f"cache_buster_{uuid.uuid4().hex}"] = random.random() + + payload = {"prompt": prompt, "client_id": client_id} + async with aiohttp.ClientSession() as session: + http_url = f"{settings.COMFYUI_HTTP_URL}/prompt" + async with session.post(http_url, json=payload) as response: + response.raise_for_status() + result = await response.json() + if "prompt_id" not in result: + raise Exception(f"Invalid response from ComfyUI /prompt endpoint: {result}") + return result["prompt_id"] + + +async def get_execution_results(prompt_id: str, client_id: str) -> dict: + """ + 通过WebSocket连接,聚合所有'executed'事件的输出, + 直到整个执行流程结束。 + """ + ws_url = f"{settings.COMFYUI_URL}?clientId={client_id}" + aggregated_outputs = {} + + async with websockets.connect(ws_url) as websocket: + while True: + try: + out = await websocket.recv() + if isinstance(out, str): + message = json.loads(out) + msg_type = message.get('type') + data = message.get('data') + + # 我们只关心与我们prompt_id相关的事件 + if data and data.get('prompt_id') == prompt_id: + if msg_type == 'executed': + # 聚合每个节点的输出 + node_id = data.get('node') + output_data = data.get('output') + if node_id and output_data: + aggregated_outputs[node_id] = output_data + print(f"Output received for node {node_id} (Prompt ID: {prompt_id})") + + # 判断执行是否结束 + # 官方UI使用 "executing" 且 node is null 作为结束标志 + elif msg_type == 'executing' and data.get('node') is None: + print(f"Execution finished for Prompt ID: {prompt_id}") + return aggregated_outputs + + except websockets.exceptions.ConnectionClosed as e: + print(f"WebSocket connection closed for {prompt_id}. Returning aggregated results. Error: {e}") + return aggregated_outputs # 连接关闭也视为结束 + except Exception as e: + print(f"An error occurred for {prompt_id}: {e}") + break + return aggregated_outputs + + +async def run_workflow(prompt: dict) -> dict: + """主协调函数:提交任务,然后等待结果。""" + client_id = str(uuid.uuid4()) + prompt_id = await queue_prompt(prompt, client_id) + print(f"Workflow successfully queued with Prompt ID: {prompt_id}") + results = await get_execution_results(prompt_id, client_id) + return results \ No newline at end of file diff --git a/workflow_service/config.py b/workflow_service/config.py new file mode 100644 index 0000000..cced6f6 --- /dev/null +++ b/workflow_service/config.py @@ -0,0 +1,14 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8') + COMFYUI_URL: str = "ws://127.0.0.1:8188/ws" + COMFYUI_HTTP_URL: str = "http://127.0.0.1:8188" + COMFYUI_INPUT_DIR: str + COMFYUI_OUTPUT_DIR: str + S3_BUCKET_NAME: str + AWS_ACCESS_KEY_ID: str + AWS_SECRET_ACCESS_KEY: str + AWS_REGION_NAME: str + +settings = Settings() \ No newline at end of file diff --git a/workflow_service/database.py b/workflow_service/database.py new file mode 100644 index 0000000..fb73cde --- /dev/null +++ b/workflow_service/database.py @@ -0,0 +1,59 @@ +import aiosqlite +import json +import re + +DATABASE_FILE = "workflows_service.sqlite" + +async def init_db(): + async with aiosqlite.connect(DATABASE_FILE) as db: + await db.execute(""" + CREATE TABLE IF NOT EXISTS workflows ( + name TEXT PRIMARY KEY, + base_name TEXT NOT NULL, + version TEXT NOT NULL, + workflow_json TEXT NOT NULL + ) + """) + await db.commit() + print(f"数据库 '{DATABASE_FILE}' 已准备就绪。") + +async def save_workflow(name: str, workflow_json: str): + version_match = re.search(r"\[(20\d{12})\]$", name) + if not version_match: + raise ValueError("Workflow name must have a version suffix like [YYYYMMDDHHMMSS]") + version = version_match.group(1) + base_name = re.sub(r"\s*\[(20\d{12})\]$", "", name).strip() + async with aiosqlite.connect(DATABASE_FILE) as db: + await db.execute( + "INSERT OR REPLACE INTO workflows (name, base_name, version, workflow_json) VALUES (?, ?, ?, ?)", + (name, base_name, version, workflow_json) + ) + await db.commit() + +async def get_all_workflows() -> list[dict]: + async with aiosqlite.connect(DATABASE_FILE) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT name, workflow_json FROM workflows") + rows = await cursor.fetchall() + return [{"name": row["name"], "workflow": json.loads(row["workflow_json"])} for row in rows] + +async def get_latest_workflow_by_base_name(base_name: str) -> dict | None: + async with aiosqlite.connect(DATABASE_FILE) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM workflows WHERE base_name = ? ORDER BY version DESC LIMIT 1", (base_name,)) + row = await cursor.fetchone() + return dict(row) if row else None + +async def get_workflow_by_version(base_name: str, version: str) -> dict | None: + name = f"{base_name} [{version}]" + async with aiosqlite.connect(DATABASE_FILE) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM workflows WHERE name = ?", (name,)) + row = await cursor.fetchone() + return dict(row) if row else None + +async def delete_workflow(name: str) -> bool: + async with aiosqlite.connect(DATABASE_FILE) as db: + cursor = await db.execute("DELETE FROM workflows WHERE name = ?", (name,)) + await db.commit() + return cursor.rowcount > 0 \ No newline at end of file diff --git a/workflow_service/main.py b/workflow_service/main.py new file mode 100644 index 0000000..4f91165 --- /dev/null +++ b/workflow_service/main.py @@ -0,0 +1,212 @@ +import uvicorn +from fastapi import FastAPI, Request, HTTPException, Path +from fastapi.responses import JSONResponse +from typing import Optional, List, Dict, Any, Set +import json, uuid, os, shutil, aiohttp, database, workflow_parser, comfyui_client, s3_client +from config import settings + +app = FastAPI(title="ComfyUI Workflow Service & Management API") + + +@app.on_event("startup") +async def startup_event(): await database.init_db(); os.makedirs(settings.COMFYUI_INPUT_DIR, + exist_ok=True); os.makedirs( + settings.COMFYUI_OUTPUT_DIR, exist_ok=True) + + +# --- Section 1: 工作流管理API (无改动) --- +# ... (代码与上一版完全相同) ... +BASE_MANAGEMENT_PATH = "/api/workflow" + + +@app.post(BASE_MANAGEMENT_PATH, status_code=200) +async def publish_workflow_endpoint(request: Request): + try: + data = await request.json(); name, wf_json = data.get("name"), data.get( + "workflow"); await database.save_workflow(name, json.dumps(wf_json)); return JSONResponse( + content={"status": "success", "message": f"Workflow '{name}' published."}, status_code=200) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}") + + +@app.get(BASE_MANAGEMENT_PATH, response_model=List[dict]) +async def get_all_workflows_endpoint(): + try: + return await database.get_all_workflows() + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to get workflows: {e}") + + +@app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}") +async def delete_workflow_endpoint(workflow_name: str = Path(..., title="...")): + try: + success = await database.delete_workflow(workflow_name); + if success: + return {"status": "deleted", "name": workflow_name}; + else: + raise HTTPException(status_code=404, detail="Workflow not found") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}") + + +# --- Section 2: 工作流执行API (核心改动) --- +def get_files_in_dir(directory: str) -> Set[str]: + file_set = set() + for root, _, files in os.walk(directory): + for filename in files: + if not filename.startswith('.'): file_set.add(os.path.join(root, filename)) + return file_set + + +async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str): + async with session.get(url) as response: + response.raise_for_status(); + with open(save_path, 'wb') as f: + while True: + chunk = await response.content.read(8192); + if not chunk: + break; + f.write(chunk) + + +async def handle_file_upload(file_path: str, base_name: str) -> str: + """辅助函数:上传文件到S3并返回URL""" + s3_object_name = f"outputs/{base_name}/{uuid.uuid4()}_{os.path.basename(file_path)}" + return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, s3_object_name) + + +@app.post("/api/run/{base_name}") +async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None): + cleanup_paths = [] + try: + # 1. 获取工作流和处理输入 (与上一版相同) + # ... + if version: + workflow_data = await database.get_workflow_by_version(base_name, version) + else: + workflow_data = await database.get_latest_workflow_by_base_name(base_name) + if not workflow_data: raise HTTPException(status_code=404, detail=f"Workflow '{base_name}' not found.") + workflow = json.loads(workflow_data['workflow_json']) + api_spec = workflow_parser.parse_api_spec(workflow) + request_data = {k.lower(): v for k, v in request_data_raw.items()} + async with aiohttp.ClientSession() as session: + for param_name, spec in api_spec["inputs"].items(): + if spec["type"] == "UploadFile" and param_name in request_data: + image_url = request_data[param_name] + if not isinstance(image_url, str) or not image_url.startswith('http'): raise HTTPException( + status_code=400, detail=f"Parameter '{param_name}' must be a valid URL.") + original_filename = image_url.split('/')[-1].split('?')[0]; + _, file_extension = os.path.splitext(original_filename) + if not file_extension: file_extension = '.dat' + filename = f"api_download_{uuid.uuid4()}{file_extension}" + save_path = os.path.join(settings.COMFYUI_INPUT_DIR, filename) + try: + await download_file_from_url(session, image_url, save_path); + request_data[param_name] = filename; cleanup_paths.append(save_path) + except Exception as e: + raise HTTPException(status_code=500, + detail=f"Failed to download file for '{param_name}' from {image_url}. Error: {e}") + + # 2. 执行前快照 + files_before = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR) + + # 3. Patch, 转换并执行 (缓存破解已移入client) + patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data) + prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow) + output_nodes = await comfyui_client.run_workflow(prompt_to_run) + + # 4. 执行后快照并计算差异 + files_after = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR) + new_files = files_after - files_before + + # 5. [核心修正] 统一处理所有输出 + output_response = {} + processed_files = set() # 记录已通过文件系统快照处理的文件 + + # 5.1 首先处理所有新生成的文件 + if new_files: + s3_urls = [] + for file_path in new_files: + cleanup_paths.append(file_path) + processed_files.add(os.path.basename(file_path)) + try: + s3_urls.append(await handle_file_upload(file_path, base_name)) + except Exception as e: + print(f"Error uploading file {file_path} to S3: {e}") + if s3_urls: output_response["output_files"] = s3_urls + + # 5.2 然后处理WebSocket返回的非文件输出,并检查文本输出是否是文件路径 + for final_param_name, spec in api_spec["outputs"].items(): + node_id = spec["node_id"] + if node_id in output_nodes: + node_output = output_nodes[node_id] + original_output_name = spec["output_name"] + + if original_output_name in node_output: + output_value = node_output[original_output_name] + # 展开列表 + if isinstance(output_value, list): + output_value = output_value[0] if output_value else None + + # 检查文本输出是否是未被发现的文件路径 + if isinstance(output_value, str) and ( + '.png' in output_value or '.jpg' in output_value or '.mp4' in output_value or 'output' in output_value): + potential_filename = os.path.basename(output_value.replace('\\', '/')) + if potential_filename not in processed_files: + # 这是一个新的文件路径,尝试在output目录中找到它 + potential_path = os.path.join(settings.COMFYUI_OUTPUT_DIR, potential_filename) + if os.path.exists(potential_path): + print(f"Found extra file from text output: {potential_path}") + cleanup_paths.append(potential_path) + processed_files.add(potential_filename) + try: + s3_url = await handle_file_upload(potential_path, base_name) + # 将它也加入output_files列表 + if "output_files" not in output_response: output_response["output_files"] = [] + output_response["output_files"].append(s3_url) + except Exception as e: + print(f"Error uploading extra file {potential_path} to S3: {e}") + continue # 处理完毕,跳过将其作为文本输出 + output_response[final_param_name] = output_value + elif "text" in node_output: + output_value = node_output["text"] + # 如果不是文件,则作为普通值输出 + output_response[final_param_name] = output_value + + return output_response + + finally: + # 清理操作保持不变 + print(f"Cleaning up {len(cleanup_paths)} temporary files...") + for path in cleanup_paths: + try: + if os.path.exists(path): os.remove(path); print(f" - Deleted: {path}") + except Exception as e: + print(f" - Error deleting {path}: {e}") + + +# --- Section 3: 工作流元数据/规范API (无改动) --- +# ... (此部分代码与上一版完全相同) ... +@app.get("/api/spec/{base_name}") +async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None): + # ... + if version: + workflow_data = await database.get_workflow_by_version(base_name, version) + else: + workflow_data = await database.get_latest_workflow_by_base_name(base_name) + if not workflow_data: + detail = f"Workflow '{base_name}'" + (f" with version '{version}'" if version else "") + " not found." + raise HTTPException(status_code=404, detail=detail) + try: + workflow = json.loads(workflow_data['workflow_json']) + return workflow_parser.parse_api_spec(workflow) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {e}") + + +@app.get("/") +def read_root(): return {"message": "Welcome to the ComfyUI Workflow Service API!"} + + +if __name__ == "__main__": + uvicorn.run(app, host="127.0.0.1", port=18000) diff --git a/workflow_service/requirements.txt b/workflow_service/requirements.txt new file mode 100644 index 0000000..c41e7c2 --- /dev/null +++ b/workflow_service/requirements.txt @@ -0,0 +1,7 @@ +fastapi +uvicorn[standard] +pydantic-settings +websockets +aiohttp +boto3 +aiosqlite \ No newline at end of file diff --git a/workflow_service/s3_client.py b/workflow_service/s3_client.py new file mode 100644 index 0000000..b39fbc5 --- /dev/null +++ b/workflow_service/s3_client.py @@ -0,0 +1,17 @@ +import boto3 +from config import settings +import asyncio + +s3_client = boto3.client('s3', aws_access_key_id=settings.AWS_ACCESS_KEY_ID, aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, region_name=settings.AWS_REGION_NAME) + +async def upload_file_to_s3(file_path: str, bucket: str, object_name: str) -> str: + """从本地文件路径上传""" + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: s3_client.upload_file(file_path, bucket, object_name)) + return f"https://cdn.roasmax.cn/{object_name}" + +async def upload_bytes_to_s3(file_bytes: bytes, bucket: str, object_name: str) -> str: + """直接从内存中的bytes上传 (新函数)""" + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: s3_client.put_object(Body=file_bytes, Bucket=bucket, Key=object_name)) + return f"https://cdn.roasmax.cn/{object_name}" \ No newline at end of file diff --git a/workflow_service/workflow_parser.py b/workflow_service/workflow_parser.py new file mode 100644 index 0000000..f4e8c22 --- /dev/null +++ b/workflow_service/workflow_parser.py @@ -0,0 +1,151 @@ +import re +from collections import defaultdict + +API_INPUT_PREFIX = "INPUT_" +API_OUTPUT_PREFIX = "OUTPUT_" + + +def parse_api_spec(workflow_data: dict) -> dict: + """ + 解析工作流,并根据规范 '{基础名}_{属性名}_{可选计数}' 生成API参数名。 + """ + spec = {"inputs": {}, "outputs": {}} + if "nodes" not in workflow_data or not isinstance(workflow_data["nodes"], list): + raise ValueError("Invalid workflow format: 'nodes' key not found or is not a list.") + + nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]} + + # 用于处理重名的计数器 + input_name_counter = defaultdict(int) + output_name_counter = defaultdict(int) + + for node_id, node in nodes_map.items(): + title = node.get("title") + if not title: continue + + # --- 处理API输入 --- + if title.startswith(API_INPUT_PREFIX): + base_name = title[len(API_INPUT_PREFIX):].lower() + if "inputs" in node: + for a_input in node.get("inputs", []): + # 只关心由用户直接控制的widget输入 + if a_input.get("link") is None and "widget" in a_input: + widget_name = a_input["widget"]["name"].lower() + + # 构建基础参数名 + param_name_candidate = f"{base_name}_{widget_name}" + + # 处理重名 + input_name_counter[param_name_candidate] += 1 + count = input_name_counter[param_name_candidate] + final_param_name = f"{param_name_candidate}_{count}" if count > 1 else param_name_candidate + + # 类型推断 + input_type_str = a_input.get("type", "STRING").upper() + if "COMBO" in a_input.get("type"): + # 如果是加载类节点,名字为'image'或'video'的widget是文件上传 + if a_input["widget"]["name"] in ["image", "video"]: + param_type = "UploadFile" + else: + param_type = "string" # 加载节点其他参数默认为string + elif "INT" in input_type_str: + param_type = "int" + elif "FLOAT" in input_type_str: + param_type = "float" + else: + param_type = "string" + + spec["inputs"][final_param_name] = { + "node_id": node_id, + "type": param_type, + "widget_name": a_input["widget"]["name"] # 保留原始widget名用于patch + } + + # --- 处理API输出 --- + elif title.startswith(API_OUTPUT_PREFIX): + base_name = title[len(API_OUTPUT_PREFIX):].lower() + + if "outputs" in node: + for an_output in node.get("outputs", []): + output_name = an_output["name"].lower() + + # 构建基础参数名 + param_name_candidate = f"{base_name}_{output_name}" + + # 处理重名 + output_name_counter[param_name_candidate] += 1 + count = output_name_counter[param_name_candidate] + final_param_name = f"{param_name_candidate}_{count}" if count > 1 else param_name_candidate + + spec["outputs"][final_param_name] = { + "node_id": node_id, + "class_type": node.get("type"), + "output_name": an_output["name"], # 保留原始输出名用于结果匹配 + "output_index": node["outputs"].index(an_output) # 保留索引 + } + + return spec + + +# --- 其他函数保持不变 --- + +def patch_workflow(workflow_data: dict, api_spec: dict, request_data: dict) -> dict: + # (此函数与上一版完全相同,无需改动) + if "nodes" not in workflow_data: raise ValueError("Invalid workflow format") + nodes_map = {str(node["id"]): node for node in workflow_data["nodes"]} + for param_name, value in request_data.items(): + if param_name in api_spec["inputs"]: + spec = api_spec["inputs"][param_name] + node_id = spec["node_id"] + if node_id not in nodes_map: continue + target_node = nodes_map[node_id] + widget_name_to_patch = spec["widget_name"] + widgets_values = target_node.get("widgets_values", {}) + if isinstance(widgets_values, dict): + widgets_values[widget_name_to_patch] = value + elif isinstance(widgets_values, list): + widget_cursor = 0 + for input_config in target_node.get("inputs", []): + if "widget" in input_config: + if input_config["widget"].get("name") == widget_name_to_patch: + if widget_cursor < len(widgets_values): + target_type = str + if spec['type'] == 'int': + target_type = int + elif spec['type'] == 'float': + target_type = float + widgets_values[widget_cursor] = target_type(value) + break + widget_cursor += 1 + target_node["widgets_values"] = widgets_values + workflow_data["nodes"] = list(nodes_map.values()) + return workflow_data + + +def convert_workflow_to_prompt_api_format(workflow_data: dict) -> dict: + # (此函数与上一版完全相同,无需改动) + if "nodes" not in workflow_data: raise ValueError("Invalid workflow format") + prompt_api_format, link_map = {}, {} + for link in workflow_data.get("links", []): + link_map[link[0]] = [str(link[1]), link[2]] + for node in workflow_data["nodes"]: + node_id = str(node["id"]) + inputs_dict = {} + widgets_values = node.get("widgets_values", []) + if isinstance(widgets_values, dict): + for key, val in widgets_values.items(): + if not isinstance(val, dict): inputs_dict[key] = val + elif isinstance(widgets_values, list): + widget_idx_counter = 0 + for input_config in node.get("inputs", []): + if "widget" in input_config: + if widget_idx_counter < len(widgets_values): + inputs_dict[input_config["name"]] = widgets_values[widget_idx_counter] + widget_idx_counter += 1 + for input_config in node.get("inputs", []): + if "link" in input_config and input_config["link"] is not None: + link_id = input_config["link"] + if link_id in link_map: + inputs_dict[input_config["name"]] = link_map[link_id] + prompt_api_format[node_id] = {"class_type": node["type"], "inputs": inputs_dict} + return prompt_api_format \ No newline at end of file