ComfyUI-WorkflowPublisher/__init__.py

335 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
import tempfile
import traceback
import urllib
import urllib.parse

import aiohttp
from aiohttp import web

from server import PromptServer
# Absolute path of the directory that contains this plugin file.
NODE_DIR = os.path.dirname(os.path.abspath(__file__))

# Persisted publisher settings live next to the plugin code.
CONFIG_FILE = os.path.join(NODE_DIR, "config.json")

# Settings used when config.json is missing or unreadable.
# "host" is the base URL that "/api/workflow" is appended to when talking to
# the publishing server; an empty string means "not configured".
DEFAULT_CONFIG = dict(host="")
def load_config():
    """Load the publisher settings from CONFIG_FILE.

    If the file does not exist, cannot be decoded as JSON, or holds JSON that
    is not an object, it is (re)written with the defaults. Keys missing from
    a valid stored object are filled in from DEFAULT_CONFIG (in memory only;
    the file is not rewritten in that case).

    Returns:
        dict: a fresh settings dict. It never aliases DEFAULT_CONFIG, so
        callers may mutate it (e.g. ``config["host"] = ...``) without
        altering the module-level defaults.
    """
    if not os.path.exists(CONFIG_FILE):
        return _reset_config_file()
    try:
        with open(CONFIG_FILE, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (ValueError, FileNotFoundError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError (a
        # corrupt file); FileNotFoundError covers the file vanishing between
        # the exists() check and open().
        return _reset_config_file()
    if not isinstance(config, dict):
        # Valid JSON but not an object (e.g. a list): treat it as corrupt.
        return _reset_config_file()
    for key, value in DEFAULT_CONFIG.items():
        config.setdefault(key, value)
    return config


def _reset_config_file():
    """Write the default settings to CONFIG_FILE and return a copy of them.

    Writing is best-effort: if the file cannot be written, a warning is
    printed and the defaults are still returned so callers keep working.
    """
    # Shallow copy is sufficient while every default value is an immutable str.
    defaults = dict(DEFAULT_CONFIG)
    try:
        with open(CONFIG_FILE, "w", encoding="utf-8") as f:
            json.dump(defaults, f, indent=4, ensure_ascii=False)
    except OSError as exc:
        print(f"[WorkflowPublisher] Could not write default config to {CONFIG_FILE}: {exc}")
    return defaults
def save_config(config_data, path=None):
    """Persist *config_data* as pretty-printed UTF-8 JSON.

    The data is first written to a temporary file in the target directory and
    then moved over the target with os.replace. An interrupted or failed write
    therefore never leaves a truncated config.json behind. Such a truncated
    file would make load_config reset the settings to their defaults.
    Note that tempfile.mkstemp creates the file readable only by the owner.

    Args:
        config_data: JSON-serialisable settings mapping.
        path: file to write; defaults to CONFIG_FILE.

    Raises:
        OSError: if the file cannot be written.
        TypeError: if *config_data* is not JSON-serialisable (the existing
            file is left untouched).
    """
    if path is None:
        path = CONFIG_FILE
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(prefix=".config-", suffix=".json.tmp", dir=directory)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(config_data, f, indent=4, ensure_ascii=False)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# Placeholder node class. It never shows up in the graph, but it follows the
# standard shape ComfyUI expects from a custom-node module.
class WorkflowPublisherNode:
    """Do-nothing node backing the Workflow Publisher UI extension.

    It declares no inputs and no outputs, and its FUNCTION returns an empty
    tuple. The plugin's real behaviour lives in the HTTP routes and the JS
    front-end.
    """

    CATEGORY = "utilities"
    FUNCTION = "do_nothing"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    def __init__(self):
        """Nothing to initialise; the node carries no state."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare an empty set of required inputs."""
        no_inputs = {}
        return {"required": no_inputs}

    def do_nothing(self):
        """Run the node: it produces no outputs."""
        return tuple()
# -----------------
# API endpoint definitions
# -----------------
# Custom API routes registered on ComfyUI's PromptServer.
@PromptServer.instance.routes.get("/publisher/settings")
async def get_publisher_settings(request):
    """Return the current publisher settings as JSON.

    The body is the dict produced by load_config(). The response carries the
    same permissive CORS headers as the other /publisher routes.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    }
    return web.json_response(load_config(), headers=cors_headers)
@PromptServer.instance.routes.post("/publisher/settings")
async def save_publisher_settings(request):
    """Update the configured publishing host.

    Expects a JSON object body such as ``{"host": "http://example.com"}``.
    A missing or null "host" stores an empty string, which the other
    handlers treat as "Host not configured". Surrounding whitespace is
    stripped.

    Responses:
        200 {"status": "success", ...} once the settings are written.
        400 {"status": "error", ...} for a malformed body or non-string host.
        500 {"status": "error", ...} if reading the body, loading or saving fails.
    """
    # NOTE(review): the wildcard CORS policy (here and on the OPTIONS
    # preflight) lets any web page the user visits POST to this route and
    # change where workflows get published. The bundled JS front-end is
    # presumably served from ComfyUI's own origin and would not need CORS;
    # confirm before tightening.
    cors_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    }

    def reply(payload, status=200):
        return web.json_response(payload, status=status, headers=cors_headers)

    try:
        try:
            data = await request.json()
        except ValueError:
            # Empty or malformed body is a client error, not a server failure.
            return reply({"status": "error", "message": "Request body must be valid JSON"}, status=400)
        if not isinstance(data, dict):
            return reply({"status": "error", "message": "Request body must be a JSON object"}, status=400)

        host = data.get("host", "")
        if host is None:
            host = ""
        if not isinstance(host, str):
            # A non-string host would later break URL building in the other routes.
            return reply({"status": "error", "message": "'host' must be a string"}, status=400)

        config = load_config()
        config["host"] = host.strip()
        save_config(config)
        return reply({"status": "success", "message": "Settings saved"})
    except Exception as e:
        traceback.print_exc()
        return reply({"status": "error", "message": str(e)}, status=500)
@PromptServer.instance.routes.post("/publisher/publish")
async def publish_workflow_handler(request):
    """Forward a workflow to the configured host's ``/api/workflow`` endpoint.

    Expects a JSON object body with non-empty "workflow" and "name" fields.
    The handler POSTs ``{"name": ..., "workflow": ...}`` as JSON to
    ``<host>/api/workflow``.

    Responses:
        200 success, with the target's response text in "details", when the
            target answers with any 2xx status.
        400 for a malformed body, missing fields, or no configured host.
        500 when the target answers with a non-2xx status (its body is in
            "details") or when the request itself fails.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    }

    def reply(payload, status=200):
        return web.json_response(payload, status=status, headers=cors_headers)

    try:
        try:
            data = await request.json()
        except ValueError:
            # Empty or malformed body is a client error, not a server failure.
            return reply({"status": "error", "message": "Request body must be valid JSON"}, status=400)
        if not isinstance(data, dict):
            return reply({"status": "error", "message": "Request body must be a JSON object"}, status=400)

        workflow = data.get("workflow")
        workflow_name = data.get("name")
        if not workflow or not workflow_name:
            return reply({"status": "error", "message": "Missing workflow data or name"}, status=400)

        host = load_config().get("host") or ""
        if not isinstance(host, str) or not host.strip():
            return reply({"status": "error", "message": "Host not configured"}, status=400)
        full_url = host.strip().rstrip('/') + '/api/workflow'

        payload = {"name": workflow_name, "workflow": workflow}
        headers = {'Content-Type': 'application/json'}
        async with aiohttp.ClientSession() as session:
            async with session.post(full_url, json=payload, headers=headers) as upstream:
                response_text = await upstream.text()
                # Any 2xx (e.g. 200, 201, 202, 204) means the target accepted it.
                if 200 <= upstream.status < 300:
                    return reply({
                        "status": "success",
                        "message": "Workflow published successfully!",
                        "details": response_text,
                    })
                return reply({
                    "status": "error",
                    "message": f"Failed to publish workflow. Target API returned status {upstream.status}",
                    "details": response_text,
                }, status=500)
    except Exception as e:
        traceback.print_exc()
        return reply({"status": "error", "message": str(e)}, status=500)
@PromptServer.instance.routes.get("/publisher/workflows")
async def get_workflows_from_server(request):
    """Fetch the workflow list from the configured host and group it by name.

    Calls ``GET <host>/api/workflow``. The target is assumed to return a JSON
    list of objects carrying "name" and "workflow" keys (inferred from how
    the items are read; confirm against the target API). The handler responds
    with the mapping built by _group_workflow_versions.

    Responses:
        200 {base_name: [{"version": ..., "workflow": ...}, ...], ...}
        400 if no host is configured.
        500 if the target answers non-200, returns something other than a
            JSON list, or the request fails.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    }

    def reply(payload, status=200):
        return web.json_response(payload, status=status, headers=cors_headers)

    try:
        host = load_config().get("host") or ""
        if not isinstance(host, str) or not host.strip():
            return reply({"status": "error", "message": "Host not configured"}, status=400)
        full_url = host.strip().rstrip('/') + '/api/workflow'

        async with aiohttp.ClientSession() as session:
            async with session.get(full_url) as upstream:
                if upstream.status != 200:
                    response_text = await upstream.text()
                    return reply({
                        "status": "error",
                        "message": f"Failed to fetch workflows. Target API returned status {upstream.status}",
                        "details": response_text,
                    }, status=500)
                workflows = await upstream.json()

        if not isinstance(workflows, list):
            return reply({
                "status": "error",
                "message": "Unexpected response from target API: expected a JSON list of workflows",
            }, status=500)
        return reply(_group_workflow_versions(workflows))
    except Exception as e:
        traceback.print_exc()
        return reply({"status": "error", "message": str(e)}, status=500)


def _group_workflow_versions(workflows):
    """Group workflow records by base name, newest version first.

    A name of the form ``"<base> [20YYMMDDHHMMSS]"`` (14 digits starting with
    "20") is grouped under ``<base>`` (stripped), with the digits as its
    version. Any other name is its own group with version "N/A". Records
    without a usable string name are grouped as "Unnamed Workflow", and
    non-dict records are skipped.

    Returns:
        dict: {base_name: [{"version": str, "workflow": <record's "workflow">}]}
        where each list is sorted by version descending, with "N/A" entries last.
    """
    version_pattern = re.compile(r"^(.*) \[(20\d{12})\]$")
    grouped = {}
    for wf in workflows:
        if not isinstance(wf, dict):
            continue  # skip malformed entries instead of failing the whole list
        name = wf.get("name")
        match = version_pattern.match(name) if isinstance(name, str) else None
        if match:
            base_name = match.group(1).strip()
            version = match.group(2)
        else:
            # Legacy/unversioned workflow: the whole name is the base name.
            base_name = name if isinstance(name, str) and name else "Unnamed Workflow"
            version = "N/A"
        grouped.setdefault(base_name, []).append({
            "version": version,
            "workflow": wf.get("workflow"),
        })
    for versions in grouped.values():
        # Versions are fixed-width digit strings, so string order matches
        # chronological order. The boolean in the key keeps "N/A" entries last.
        # In plain string order "N/A" > any digit string and would come first.
        versions.sort(key=lambda v: (v["version"] != "N/A", v["version"]), reverse=True)
    return grouped
@PromptServer.instance.routes.post("/publisher/workflow/delete")
async def delete_workflow_version(request):
    """Forward a delete request for one workflow version to the target API.

    Expects a JSON object body ``{"name": "<full workflow name>"}``. The
    handler sends ``DELETE <host>/api/workflow/<percent-encoded name>``.

    Responses:
        200 success when the target answers with any 2xx status.
        400 for a malformed body, missing/invalid name, or no configured host.
        The target's own status, with its body in "details", for other answers.
        500 if the request itself fails.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    }

    def reply(payload, status=200):
        return web.json_response(payload, status=status, headers=cors_headers)

    try:
        try:
            data = await request.json()
        except ValueError:
            # Empty or malformed body is a client error, not a server failure.
            return reply({"status": "error", "message": "Request body must be valid JSON"}, status=400)
        if not isinstance(data, dict):
            return reply({"status": "error", "message": "Request body must be a JSON object"}, status=400)

        workflow_name = data.get("name")
        if not workflow_name:
            return reply({"status": "error", "message": "Missing workflow name"}, status=400)
        if not isinstance(workflow_name, str) or workflow_name in (".", ".."):
            # "." / ".." would be normalised away as dot path segments and hit
            # a different target URL.
            return reply({"status": "error", "message": "Invalid workflow name"}, status=400)

        host = load_config().get("host") or ""
        if not isinstance(host, str) or not host.strip():
            return reply({"status": "error", "message": "Host not configured"}, status=400)
        base_url = host.strip().rstrip('/') + '/api/workflow'
        # safe="" also encodes "/", so a name containing slashes stays a
        # single path segment instead of addressing another endpoint.
        delete_url = f"{base_url}/{urllib.parse.quote(workflow_name, safe='')}"

        async with aiohttp.ClientSession() as session:
            async with session.delete(delete_url) as upstream:
                # 204 No Content is a common success reply to DELETE.
                if 200 <= upstream.status < 300:
                    return reply({"status": "success", "message": "Workflow version deleted"})
                details = await upstream.text()
                return reply({
                    "status": "error",
                    "message": f"Target API failed to delete. Status: {upstream.status}",
                    "details": details,
                }, status=upstream.status)
    except Exception as e:
        traceback.print_exc()
        return reply({"status": "error", "message": str(e)}, status=500)
# OPTIONS handler answering CORS preflight requests for every /publisher/* path.
@PromptServer.instance.routes.options("/publisher/{path:.*}")
async def handle_options(request):
    """Answer a CORS preflight request for any /publisher/* route.

    Returns an empty response that allows any origin plus the listed methods
    and request headers. Browsers may cache it for 24 hours.
    NOTE(review): the wildcard origin lets any site send JSON POSTs here
    (including settings changes); confirm this is intended.
    """
    one_day_seconds = 24 * 60 * 60
    preflight_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
        'Access-Control-Max-Age': str(one_day_seconds),  # "86400"
    }
    return web.Response(headers=preflight_headers)
# -----------------
# ComfyUI registration
# -----------------
# Tell ComfyUI that this package ships a web directory of JS files for the
# front-end extension.
WEB_DIRECTORY = "js"

# No graph nodes are registered; the plugin works through its HTTP routes and
# JS extension. To expose the placeholder node, map "WorkflowPublisher" to
# WorkflowPublisherNode here, with the display name "Workflow Publisher (UI)".
NODE_CLASS_MAPPINGS = dict()
NODE_DISPLAY_NAME_MAPPINGS = dict()

print("✅ Loaded Workflow Publisher Node")