refactor: 更新ComfyWorkflow类的构造函数,简化工作流数据处理逻辑,并在相关API调用中应用新结构
This commit is contained in:
parent
3c4c689532
commit
fc265570d3
|
|
@ -258,12 +258,12 @@ class WorkflowQueueManager:
|
|||
if not workflow_run:
|
||||
raise Exception(f"找不到工作流运行记录: {workflow_run_id}")
|
||||
|
||||
workflow_data = json.loads(workflow_run["workflow_json"])
|
||||
request_data = json.loads(workflow_run["request_data"])
|
||||
|
||||
workflow_data = json.loads(workflow_run.workflow_json)
|
||||
request_data = json.loads(workflow_run.request_data)
|
||||
workflow = ComfyWorkflow(workflow_run.workflow_name, workflow_data)
|
||||
# 执行工作流
|
||||
result = await _execute_prompt_on_server(
|
||||
workflow_data, request_data, server, workflow_run_id
|
||||
workflow, request_data, server, workflow_run_id
|
||||
)
|
||||
|
||||
# 保存处理后的结果到数据库
|
||||
|
|
@ -307,19 +307,19 @@ class WorkflowQueueManager:
|
|||
|
||||
result = {
|
||||
"id": workflow_run_id,
|
||||
"status": workflow_run["status"],
|
||||
"created_at": workflow_run["created_at"],
|
||||
"started_at": workflow_run["started_at"],
|
||||
"completed_at": workflow_run["completed_at"],
|
||||
"server_url": workflow_run["server_url"],
|
||||
"error_message": workflow_run["error_message"],
|
||||
"status": workflow_run.status,
|
||||
"created_at": workflow_run.created_at,
|
||||
"started_at": workflow_run.started_at,
|
||||
"completed_at": workflow_run.completed_at,
|
||||
"server_url": workflow_run.server_url,
|
||||
"error_message": workflow_run.error_message,
|
||||
"nodes": nodes,
|
||||
}
|
||||
|
||||
# 如果任务完成,从数据库获取结果
|
||||
if workflow_run["status"] == "completed" and workflow_run.get("result"):
|
||||
if workflow_run.status == "completed" and workflow_run.result:
|
||||
try:
|
||||
result["result"] = json.loads(workflow_run["result"])
|
||||
result["result"] = json.loads(workflow_run.result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
result["result"] = None
|
||||
|
||||
|
|
@ -369,7 +369,7 @@ queue_manager = WorkflowQueueManager(monitor_interval=5)
|
|||
|
||||
|
||||
async def _execute_prompt_on_server(
|
||||
workflow_data: dict,
|
||||
workflow: ComfyWorkflow,
|
||||
request_data: dict,
|
||||
server: ComfyUIServerInfo,
|
||||
workflow_run_id: str,
|
||||
|
|
@ -381,8 +381,7 @@ async def _execute_prompt_on_server(
|
|||
client_id = str(uuid.uuid4())
|
||||
|
||||
# 构建prompt
|
||||
flow = ComfyWorkflow(workflow_data)
|
||||
prompt = await flow.build_prompt(server, request_data)
|
||||
prompt = await workflow.build_prompt(server, request_data)
|
||||
|
||||
# 更新工作流运行状态,记录prompt_id和client_id
|
||||
await update_workflow_run_status(
|
||||
|
|
@ -394,7 +393,7 @@ async def _execute_prompt_on_server(
|
|||
)
|
||||
|
||||
# 提交到ComfyUI
|
||||
prompt_id = await _queue_prompt(workflow_data, prompt, client_id, server.http_url)
|
||||
prompt_id = await _queue_prompt(workflow, prompt, client_id, server.http_url)
|
||||
|
||||
# 更新prompt_id
|
||||
await update_workflow_run_status(
|
||||
|
|
@ -407,13 +406,13 @@ async def _execute_prompt_on_server(
|
|||
|
||||
# 获取执行结果,现在支持节点级别的状态跟踪
|
||||
results = await _get_execution_results(
|
||||
workflow_data, prompt_id, client_id, server.ws_url, workflow_run_id
|
||||
workflow, prompt_id, client_id, server.ws_url, workflow_run_id
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _queue_prompt(
|
||||
workflow: dict, prompt: dict, client_id: str, http_url: str
|
||||
workflow: ComfyWorkflow, prompt: dict, client_id: str, http_url: str
|
||||
) -> str:
|
||||
"""通过HTTP POST将工作流任务提交到指定的ComfyUI服务器。"""
|
||||
for node_id in prompt:
|
||||
|
|
@ -423,7 +422,7 @@ async def _queue_prompt(
|
|||
"client_id": client_id,
|
||||
"extra_data": {
|
||||
"api_key_comfy_org": "",
|
||||
"extra_pnginfo": {"workflow": workflow},
|
||||
"extra_pnginfo": {"workflow": workflow.workflow_data.model_dump()},
|
||||
},
|
||||
}
|
||||
async with aiohttp.ClientSession(timeout=ClientTimeout(total=90)) as session:
|
||||
|
|
@ -443,7 +442,7 @@ async def _queue_prompt(
|
|||
|
||||
|
||||
async def _get_execution_results(
|
||||
workflow_data: Dict,
|
||||
workflow: ComfyWorkflow,
|
||||
prompt_id: str,
|
||||
client_id: str,
|
||||
ws_url: str,
|
||||
|
|
@ -511,17 +510,17 @@ async def _get_execution_results(
|
|||
node = next(
|
||||
(
|
||||
x
|
||||
for x in workflow_data["nodes"]
|
||||
if str(x["id"]) == node_id
|
||||
for x in workflow.workflow_data.nodes
|
||||
if str(x.id) == node_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if (
|
||||
node
|
||||
and node.get("title", "")
|
||||
and node["title"].startswith(API_OUTPUT_PREFIX)
|
||||
and node.title
|
||||
and node.title.startswith(API_OUTPUT_PREFIX)
|
||||
):
|
||||
title = node["title"].replace(API_OUTPUT_PREFIX, "")
|
||||
title = node.title.replace(API_OUTPUT_PREFIX, "")
|
||||
aggregated_outputs[title] = output_data
|
||||
logger.info(
|
||||
f"收到节点 {node_id} 的输出 (Prompt ID: {prompt_id})"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ API_OUTPUT_PREFIX = "OUTPUT_"
|
|||
|
||||
class ComfyWorkflowNodeWidget(BaseModel):
|
||||
"""节点控件定义"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
|
|
@ -52,17 +53,20 @@ class ComfyWorkflowNodeProperties(BaseModel):
|
|||
|
||||
class ComfyAPIFieldSpec(BaseModel):
|
||||
"""API字段规范"""
|
||||
|
||||
type: str
|
||||
widget_name: str
|
||||
|
||||
|
||||
class ComfyAPIOutputSpec(BaseModel):
|
||||
"""API输出规范"""
|
||||
|
||||
output_name: str
|
||||
|
||||
|
||||
class ComfyAPINodeSpec(BaseModel):
|
||||
"""API节点规范"""
|
||||
|
||||
node_id: str
|
||||
class_type: str
|
||||
inputs: Optional[Dict[str, ComfyAPIFieldSpec]] = {}
|
||||
|
|
@ -71,12 +75,14 @@ class ComfyAPINodeSpec(BaseModel):
|
|||
|
||||
class ComfyAPISpec(BaseModel):
|
||||
"""API规范定义"""
|
||||
|
||||
inputs: Dict[str, ComfyAPINodeSpec] = {}
|
||||
outputs: Dict[str, ComfyAPINodeSpec] = {}
|
||||
|
||||
|
||||
class ComfyJSONSchemaProperty(BaseModel):
|
||||
"""JSON Schema属性定义"""
|
||||
|
||||
type: str
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
|
@ -91,6 +97,7 @@ class ComfyJSONSchemaProperty(BaseModel):
|
|||
|
||||
class ComfyJSONSchemaNode(BaseModel):
|
||||
"""JSON Schema节点定义"""
|
||||
|
||||
type: str = "object"
|
||||
title: str
|
||||
description: str
|
||||
|
|
@ -101,6 +108,7 @@ class ComfyJSONSchemaNode(BaseModel):
|
|||
|
||||
class ComfyJSONSchema(BaseModel):
|
||||
"""JSON Schema定义"""
|
||||
|
||||
schema: str = "http://json-schema.org/draft-07/schema#"
|
||||
type: str = "object"
|
||||
title: str = "ComfyUI Workflow Input Schema"
|
||||
|
|
@ -160,13 +168,13 @@ class ComfyWorkflow:
|
|||
ComfyUI工作流处理器类,提供面向对象的工作流管理和处理能力
|
||||
"""
|
||||
|
||||
def __init__(self, workflow_data: dict, workflow_name: str = None):
|
||||
def __init__(self, workflow_name: str, workflow_data: dict):
|
||||
"""
|
||||
初始化工作流实例
|
||||
|
||||
Args:
|
||||
workflow_data: 工作流数据
|
||||
workflow_name: 工作流名称
|
||||
workflow_data: 工作流数据
|
||||
"""
|
||||
self.workflow_data = ComfyWorkflowDataSpec.model_validate(workflow_data)
|
||||
self.workflow_name = workflow_name
|
||||
|
|
@ -262,10 +270,7 @@ class ComfyWorkflow:
|
|||
# 如果节点名不在inputs中,创建新的节点条目
|
||||
if node_name not in inputs:
|
||||
inputs[node_name] = ComfyAPINodeSpec(
|
||||
node_id=node_id,
|
||||
class_type=node.type,
|
||||
inputs={},
|
||||
outputs={}
|
||||
node_id=node_id, class_type=node.type, inputs={}, outputs={}
|
||||
)
|
||||
|
||||
if node.inputs:
|
||||
|
|
@ -280,7 +285,7 @@ class ComfyWorkflow:
|
|||
# 直接使用widget_name作为字段名
|
||||
inputs[node_name].inputs[widget_name] = ComfyAPIFieldSpec(
|
||||
type=self._get_param_type(a_input.type or "STRING"),
|
||||
widget_name=a_input.widget.name
|
||||
widget_name=a_input.widget.name,
|
||||
)
|
||||
|
||||
elif title.startswith(API_OUTPUT_PREFIX):
|
||||
|
|
@ -290,10 +295,7 @@ class ComfyWorkflow:
|
|||
# 如果节点名不在outputs中,创建新的节点条目
|
||||
if node_name not in outputs:
|
||||
outputs[node_name] = ComfyAPINodeSpec(
|
||||
node_id=node_id,
|
||||
class_type=node.type,
|
||||
inputs={},
|
||||
outputs={}
|
||||
node_id=node_id, class_type=node.type, inputs={}, outputs={}
|
||||
)
|
||||
|
||||
if node.outputs:
|
||||
|
|
@ -340,11 +342,15 @@ class ComfyWorkflow:
|
|||
# 获取节点的当前值作为默认值
|
||||
# 创建空的link_map,因为输入节点通常没有连接
|
||||
empty_link_map = {}
|
||||
node_current_inputs = self._get_node_inputs(node, empty_link_map)
|
||||
node_current_inputs = self._get_node_inputs(
|
||||
node, empty_link_map
|
||||
)
|
||||
current_value = node_current_inputs.get(widget_name)
|
||||
|
||||
# 生成字段的 JSON Schema 定义,包含默认值
|
||||
field_schema = self._generate_field_schema_model(a_input, current_value)
|
||||
field_schema = self._generate_field_schema_model(
|
||||
a_input, current_value
|
||||
)
|
||||
node_properties[widget_name] = field_schema
|
||||
|
||||
# 如果字段是必需的,添加到required列表
|
||||
|
|
@ -357,7 +363,7 @@ class ComfyWorkflow:
|
|||
title=f"输入节点: {node_name}",
|
||||
description=f"节点类型: {node.type}",
|
||||
properties=node_properties,
|
||||
required=node_required
|
||||
required=node_required,
|
||||
)
|
||||
|
||||
# 如果节点有必需字段,将整个节点标记为必需
|
||||
|
|
@ -405,7 +411,10 @@ class ComfyWorkflow:
|
|||
|
||||
target_widget_index = -1
|
||||
for i, widget_input in enumerate(widget_inputs):
|
||||
if widget_input.widget and widget_input.widget.name == widget_name_to_patch:
|
||||
if (
|
||||
widget_input.widget
|
||||
and widget_input.widget.name == widget_name_to_patch
|
||||
):
|
||||
target_widget_index = i
|
||||
break
|
||||
|
||||
|
|
@ -416,7 +425,9 @@ class ComfyWorkflow:
|
|||
continue
|
||||
|
||||
# 确保 `widgets_values` 存在且为列表
|
||||
if not target_node.widgets_values or not isinstance(target_node.widgets_values, list):
|
||||
if not target_node.widgets_values or not isinstance(
|
||||
target_node.widgets_values, list
|
||||
):
|
||||
# 如果不存在或格式错误,根据widget数量创建一个占位符列表
|
||||
target_node.widgets_values = [None] * len(widget_inputs)
|
||||
|
||||
|
|
@ -447,7 +458,9 @@ class ComfyWorkflow:
|
|||
patched_workflow_data.nodes = list(nodes_map.values())
|
||||
|
||||
# 对workflow_data中的nodes进行排序,按照id从小到大
|
||||
patched_workflow_data.nodes = sorted(patched_workflow_data.nodes, key=lambda x: x.id)
|
||||
patched_workflow_data.nodes = sorted(
|
||||
patched_workflow_data.nodes, key=lambda x: x.id
|
||||
)
|
||||
|
||||
# 转换为字典格式以保持向后兼容
|
||||
return patched_workflow_data.model_dump()
|
||||
|
|
@ -655,7 +668,9 @@ class ComfyWorkflow:
|
|||
else:
|
||||
return "string"
|
||||
|
||||
def _get_node_inputs(self, node: Union[ComfyWorkflowNode, dict], link_map: dict) -> dict:
|
||||
def _get_node_inputs(
|
||||
self, node: Union[ComfyWorkflowNode, dict], link_map: dict
|
||||
) -> dict:
|
||||
"""
|
||||
获取节点的输入字段值
|
||||
|
||||
|
|
@ -681,8 +696,8 @@ class ComfyWorkflow:
|
|||
# 1. 处理控件输入 (widgets)
|
||||
widgets_values = node_model.widgets_values or []
|
||||
widget_cursor = 0
|
||||
|
||||
for input_config in (node_model.inputs or []):
|
||||
|
||||
for input_config in node_model.inputs or []:
|
||||
# 如果是widget并且有对应的widgets_values
|
||||
if input_config.widget and widget_cursor < len(widgets_values):
|
||||
widget_name = input_config.widget.name
|
||||
|
|
@ -696,7 +711,7 @@ class ComfyWorkflow:
|
|||
widget_cursor += 1
|
||||
|
||||
# 2. 处理连接输入 (links)
|
||||
for input_config in (node_model.inputs or []):
|
||||
for input_config in node_model.inputs or []:
|
||||
if input_config.link is not None:
|
||||
link_id = input_config.link
|
||||
if link_id in link_map:
|
||||
|
|
|
|||
|
|
@ -219,11 +219,11 @@ async def update_workflow_run_node_status(
|
|||
await session.commit()
|
||||
|
||||
|
||||
async def get_workflow_run(workflow_run_id: str) -> Optional[dict]:
|
||||
async def get_workflow_run(workflow_run_id: str) -> Optional[WorkflowRun]:
|
||||
"""获取工作流运行记录"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
workflow_run = await session.get(WorkflowRun, workflow_run_id)
|
||||
return workflow_run.to_dict() if workflow_run else None
|
||||
return workflow_run
|
||||
|
||||
|
||||
async def get_workflow_run_nodes(workflow_run_id: str) -> List[dict]:
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ async def run_workflow(
|
|||
raise HTTPException(status_code=404, detail=detail)
|
||||
|
||||
workflow = json.loads(workflow_data["workflow_json"])
|
||||
flow = ComfyWorkflow(workflow, workflow_name)
|
||||
flow = ComfyWorkflow(workflow_name, workflow)
|
||||
|
||||
# 提交到队列
|
||||
workflow_run_id = await queue_manager.add_task(workflow=flow, request_data=data)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ async def model_with_multi_dress(
|
|||
raise HTTPException(status_code=404, detail=detail)
|
||||
|
||||
workflow = json.loads(workflow_data["workflow_json"])
|
||||
flow = ComfyWorkflow(workflow, workflow_name)
|
||||
flow = ComfyWorkflow(workflow_name, workflow)
|
||||
|
||||
# 将请求拆分为多个请求
|
||||
batch_data = _convert(data)
|
||||
|
|
|
|||
|
|
@ -92,8 +92,9 @@ async def get_one_workflow_endpoint(base_name: str, version: Optional[str] = Non
|
|||
|
||||
try:
|
||||
workflow = json.loads(workflow_data["workflow_json"])
|
||||
flow = ComfyWorkflow(workflow)
|
||||
flow = ComfyWorkflow(base_name, workflow)
|
||||
return {
|
||||
"name": flow.workflow_name,
|
||||
"workflow": flow.workflow_data,
|
||||
"api_spec": flow.get_api_spec(),
|
||||
"inputs_json_schema": flow.get_inputs_json_schema(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue