refactor: 更新ComfyWorkflow类的构造函数,简化工作流数据处理逻辑,并在相关API调用中应用新结构

This commit is contained in:
iHeyTang 2025-08-21 14:59:43 +08:00
parent 3c4c689532
commit fc265570d3
6 changed files with 66 additions and 51 deletions

View File

@ -258,12 +258,12 @@ class WorkflowQueueManager:
if not workflow_run:
raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
workflow_data = json.loads(workflow_run["workflow_json"])
request_data = json.loads(workflow_run["request_data"])
workflow_data = json.loads(workflow_run.workflow_json)
request_data = json.loads(workflow_run.request_data)
workflow = ComfyWorkflow(workflow_run.workflow_name, workflow_data)
# 执行工作流
result = await _execute_prompt_on_server(
workflow_data, request_data, server, workflow_run_id
workflow, request_data, server, workflow_run_id
)
# 保存处理后的结果到数据库
@ -307,19 +307,19 @@ class WorkflowQueueManager:
result = {
"id": workflow_run_id,
"status": workflow_run["status"],
"created_at": workflow_run["created_at"],
"started_at": workflow_run["started_at"],
"completed_at": workflow_run["completed_at"],
"server_url": workflow_run["server_url"],
"error_message": workflow_run["error_message"],
"status": workflow_run.status,
"created_at": workflow_run.created_at,
"started_at": workflow_run.started_at,
"completed_at": workflow_run.completed_at,
"server_url": workflow_run.server_url,
"error_message": workflow_run.error_message,
"nodes": nodes,
}
# 如果任务完成,从数据库获取结果
if workflow_run["status"] == "completed" and workflow_run.get("result"):
if workflow_run.status == "completed" and workflow_run.result:
try:
result["result"] = json.loads(workflow_run["result"])
result["result"] = json.loads(workflow_run.result)
except (json.JSONDecodeError, TypeError):
result["result"] = None
@ -369,7 +369,7 @@ queue_manager = WorkflowQueueManager(monitor_interval=5)
async def _execute_prompt_on_server(
workflow_data: dict,
workflow: ComfyWorkflow,
request_data: dict,
server: ComfyUIServerInfo,
workflow_run_id: str,
@ -381,8 +381,7 @@ async def _execute_prompt_on_server(
client_id = str(uuid.uuid4())
# 构建prompt
flow = ComfyWorkflow(workflow_data)
prompt = await flow.build_prompt(server, request_data)
prompt = await workflow.build_prompt(server, request_data)
# 更新工作流运行状态记录prompt_id和client_id
await update_workflow_run_status(
@ -394,7 +393,7 @@ async def _execute_prompt_on_server(
)
# 提交到ComfyUI
prompt_id = await _queue_prompt(workflow_data, prompt, client_id, server.http_url)
prompt_id = await _queue_prompt(workflow, prompt, client_id, server.http_url)
# 更新prompt_id
await update_workflow_run_status(
@ -407,13 +406,13 @@ async def _execute_prompt_on_server(
# 获取执行结果,现在支持节点级别的状态跟踪
results = await _get_execution_results(
workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id
workflow, prompt_id, client_id, server.ws_url, workflow_run_id
)
return results
async def _queue_prompt(
workflow: dict, prompt: dict, client_id: str, http_url: str
workflow: ComfyWorkflow, prompt: dict, client_id: str, http_url: str
) -> str:
"""通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。"""
for node_id in prompt:
@ -423,7 +422,7 @@ async def _queue_prompt(
"client_id": client_id,
"extra_data": {
"api_key_comfy_org": "",
"extra_pnginfo": {"workflow": workflow},
"extra_pnginfo": {"workflow": workflow.workflow_data.model_dump()},
},
}
async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
@ -443,7 +442,7 @@ async def _queue_prompt(
async def _get_execution_results(
workflow_data: Dict,
workflow: ComfyWorkflow,
prompt_id: str,
client_id: str,
ws_url: str,
@ -511,17 +510,17 @@ async def _get_execution_results(
node = next(
(
x
for x in workflow_data["nodes"]
if str(x["id"]) == node_id
for x in workflow.workflow_data.nodes
if str(x.id) == node_id
),
None,
)
if (
node
and node.get("title", "")
and node["title"].startswith(API_OUTPUT_PREFIX)
and node.title
and node.title.startswith(API_OUTPUT_PREFIX)
):
title = node["title"].replace(API_OUTPUT_PREFIX, "")
title = node.title.replace(API_OUTPUT_PREFIX, "")
aggregated_outputs[title] = output_data
logger.info(
f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"

View File

@ -16,6 +16,7 @@ API_OUTPUT_PREFIX = "OUTPUT_"
class ComfyWorkflowNodeWidget(BaseModel):
"""节点控件定义"""
name: str
@ -52,17 +53,20 @@ class ComfyWorkflowNodeProperties(BaseModel):
class ComfyAPIFieldSpec(BaseModel):
"""API字段规范"""
type: str
widget_name: str
class ComfyAPIOutputSpec(BaseModel):
"""API输出规范"""
output_name: str
class ComfyAPINodeSpec(BaseModel):
"""API节点规范"""
node_id: str
class_type: str
inputs: Optional[Dict[str, ComfyAPIFieldSpec]] = {}
@ -71,12 +75,14 @@ class ComfyAPINodeSpec(BaseModel):
class ComfyAPISpec(BaseModel):
"""API规范定义"""
inputs: Dict[str, ComfyAPINodeSpec] = {}
outputs: Dict[str, ComfyAPINodeSpec] = {}
class ComfyJSONSchemaProperty(BaseModel):
"""JSON Schema属性定义"""
type: str
title: Optional[str] = None
description: Optional[str] = None
@ -91,6 +97,7 @@ class ComfyJSONSchemaProperty(BaseModel):
class ComfyJSONSchemaNode(BaseModel):
"""JSON Schema节点定义"""
type: str = "object"
title: str
description: str
@ -101,6 +108,7 @@ class ComfyJSONSchemaNode(BaseModel):
class ComfyJSONSchema(BaseModel):
"""JSON Schema定义"""
schema: str = "http://json-schema.org/draft-07/schema#"
type: str = "object"
title: str = "ComfyUI Workflow Input Schema"
@ -160,13 +168,13 @@ class ComfyWorkflow:
ComfyUI工作流处理器类提供面向对象的工作流管理和处理能力
"""
def __init__(self, workflow_data: dict, workflow_name: str = None):
def __init__(self, workflow_name: str, workflow_data: dict):
"""
初始化工作流实例
Args:
workflow_data: 工作流数据
workflow_name: 工作流名称
workflow_data: 工作流数据
"""
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
self.workflow_name = workflow_name
@ -262,10 +270,7 @@ class ComfyWorkflow:
# 如果节点名不在inputs中创建新的节点条目
if node_name not in inputs:
inputs[node_name] = ComfyAPINodeSpec(
node_id=node_id,
class_type=node.type,
inputs={},
outputs={}
node_id=node_id, class_type=node.type, inputs={}, outputs={}
)
if node.inputs:
@ -280,7 +285,7 @@ class ComfyWorkflow:
# 直接使用widget_name作为字段名
inputs[node_name].inputs[widget_name] = ComfyAPIFieldSpec(
type=self._get_param_type(a_input.type or "STRING"),
widget_name=a_input.widget.name
widget_name=a_input.widget.name,
)
elif title.startswith(API_OUTPUT_PREFIX):
@ -290,10 +295,7 @@ class ComfyWorkflow:
# 如果节点名不在outputs中创建新的节点条目
if node_name not in outputs:
outputs[node_name] = ComfyAPINodeSpec(
node_id=node_id,
class_type=node.type,
inputs={},
outputs={}
node_id=node_id, class_type=node.type, inputs={}, outputs={}
)
if node.outputs:
@ -340,11 +342,15 @@ class ComfyWorkflow:
# 获取节点的当前值作为默认值
# 创建空的link_map因为输入节点通常没有连接
empty_link_map = {}
node_current_inputs = self._get_node_inputs(node, empty_link_map)
node_current_inputs = self._get_node_inputs(
node, empty_link_map
)
current_value = node_current_inputs.get(widget_name)
# 生成字段的 JSON Schema 定义,包含默认值
field_schema = self._generate_field_schema_model(a_input, current_value)
field_schema = self._generate_field_schema_model(
a_input, current_value
)
node_properties[widget_name] = field_schema
# 如果字段是必需的添加到required列表
@ -357,7 +363,7 @@ class ComfyWorkflow:
title=f"输入节点: {node_name}",
description=f"节点类型: {node.type}",
properties=node_properties,
required=node_required
required=node_required,
)
# 如果节点有必需字段,将整个节点标记为必需
@ -405,7 +411,10 @@ class ComfyWorkflow:
target_widget_index = -1
for i, widget_input in enumerate(widget_inputs):
if widget_input.widget and widget_input.widget.name == widget_name_to_patch:
if (
widget_input.widget
and widget_input.widget.name == widget_name_to_patch
):
target_widget_index = i
break
@ -416,7 +425,9 @@ class ComfyWorkflow:
continue
# 确保 `widgets_values` 存在且为列表
if not target_node.widgets_values or not isinstance(target_node.widgets_values, list):
if not target_node.widgets_values or not isinstance(
target_node.widgets_values, list
):
# 如果不存在或格式错误根据widget数量创建一个占位符列表
target_node.widgets_values = [None] * len(widget_inputs)
@ -447,7 +458,9 @@ class ComfyWorkflow:
patched_workflow_data.nodes = list(nodes_map.values())
# 对workflow_data中的nodes进行排序按照id从小到大
patched_workflow_data.nodes = sorted(patched_workflow_data.nodes, key=lambda x: x.id)
patched_workflow_data.nodes = sorted(
patched_workflow_data.nodes, key=lambda x: x.id
)
# 转换为字典格式以保持向后兼容
return patched_workflow_data.model_dump()
@ -655,7 +668,9 @@ class ComfyWorkflow:
else:
return "string"
def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict:
def _get_node_inputs(
self, node: Union[ComfyWorkflowNode, dict], link_map: dict
) -> dict:
"""
获取节点的输入字段值
@ -681,8 +696,8 @@ class ComfyWorkflow:
# 1. 处理控件输入 (widgets)
widgets_values = node_model.widgets_values or []
widget_cursor = 0
for input_config in (node_model.inputs or []):
for input_config in node_model.inputs or []:
# 如果是widget并且有对应的widgets_values
if input_config.widget and widget_cursor < len(widgets_values):
widget_name = input_config.widget.name
@ -696,7 +711,7 @@ class ComfyWorkflow:
widget_cursor += 1
# 2. 处理连接输入 (links)
for input_config in (node_model.inputs or []):
for input_config in node_model.inputs or []:
if input_config.link is not None:
link_id = input_config.link
if link_id in link_map:

View File

@ -219,11 +219,11 @@ async def update_workflow_run_node_status(
await session.commit()
async def get_workflow_run(workflow_run_id: str) -> Optional[dict]:
async def get_workflow_run(workflow_run_id: str) -> Optional[WorkflowRun]:
"""获取工作流运行记录"""
async with AsyncSessionLocal() as session:
workflow_run = await session.get(WorkflowRun, workflow_run_id)
return workflow_run.to_dict() if workflow_run else None
return workflow_run
async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]:

View File

@ -40,7 +40,7 @@ async def run_workflow(
raise HTTPException(status_code=404, detail=detail)
workflow = json.loads(workflow_data["workflow_json"])
flow = ComfyWorkflow(workflow, workflow_name)
flow = ComfyWorkflow(workflow_name, workflow)
# 提交到队列
workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data)

View File

@ -33,7 +33,7 @@ async def model_with_multi_dress(
raise HTTPException(status_code=404, detail=detail)
workflow = json.loads(workflow_data["workflow_json"])
flow = ComfyWorkflow(workflow, workflow_name)
flow = ComfyWorkflow(workflow_name, workflow)
# 将请求拆分为多个请求
batch_data = _convert(data)

View File

@ -92,8 +92,9 @@ async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = Non
try:
workflow = json.loads(workflow_data["workflow_json"])
flow = ComfyWorkflow(workflow)
flow = ComfyWorkflow(base_name, workflow)
return {
"name": flow.workflow_name,
"workflow": flow.workflow_data,
"api_spec": flow.get_api_spec(),
"inputs_json_schema": flow.get_inputs_json_schema(),