fix
This commit is contained in:
parent
1af6b42573
commit
7830d2e02d
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -7,7 +8,7 @@ import uuid
|
|||
import websockets
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
from typing import Dict, Any, Optional, List, Set, Union
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
|
|
@ -309,7 +310,9 @@ async def execute_prompt_on_server(
|
|||
client_id = str(uuid.uuid4())
|
||||
|
||||
# 应用请求数据到工作流
|
||||
patched_workflow = patch_workflow(workflow_data, api_spec, request_data)
|
||||
patched_workflow = await patch_workflow(
|
||||
workflow_data, api_spec, request_data, server
|
||||
)
|
||||
|
||||
# 转换为prompt格式
|
||||
prompt = convert_workflow_to_prompt_api_format(patched_workflow)
|
||||
|
|
@ -438,10 +441,11 @@ def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
|
|||
return spec
|
||||
|
||||
|
||||
def patch_workflow(
|
||||
async def patch_workflow(
|
||||
workflow_data: dict,
|
||||
api_spec: dict[str, dict[str, Any]],
|
||||
request_data: dict[str, Any],
|
||||
server: ComfyUIServer,
|
||||
) -> dict:
|
||||
"""
|
||||
将request_data中的参数值,patch到workflow_data中。并返回修改后的workflow_data。
|
||||
|
|
@ -501,6 +505,8 @@ def patch_workflow(
|
|||
|
||||
# 在正确的位置上更新值
|
||||
try:
|
||||
if target_node.get("type") == "LoadImage":
|
||||
value = await upload_image_to_comfy(value, server)
|
||||
target_node["widgets_values"][target_widget_index] = target_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(
|
||||
|
|
@ -779,3 +785,35 @@ async def _get_execution_results_legacy(
|
|||
raise e
|
||||
|
||||
return aggregated_outputs
|
||||
|
||||
|
||||
async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
|
||||
"""
|
||||
上传文件到服务器。
|
||||
file 是一个http链接
|
||||
返回一个名称
|
||||
"""
|
||||
file_name = file_path.split("/")[-1]
|
||||
file_data = await download_file(file_path)
|
||||
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
|
||||
form_data.add_field("type", "input")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{server.http_url}/api/upload/image", data=form_data
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
return result["name"]
|
||||
|
||||
|
||||
async def download_file(src: str) -> bytes:
|
||||
"""
|
||||
下载文件到本地。
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(src) as response:
|
||||
response.raise_for_status()
|
||||
return await response.read()
|
||||
|
|
|
|||
|
|
@ -15,11 +15,6 @@ from pydantic import BaseModel
|
|||
from workflow_service import comfyui_client
|
||||
from workflow_service import database
|
||||
from workflow_service.utils.s3_client import upload_file_to_s3
|
||||
from workflow_service.comfyui_client import (
|
||||
ComfyUIExecutionError,
|
||||
queue_manager,
|
||||
submit_workflow_to_queue,
|
||||
)
|
||||
from workflow_service.config import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
|
@ -197,8 +192,11 @@ async def run_workflow(
|
|||
api_spec = comfyui_client.parse_api_spec(workflow)
|
||||
|
||||
# 提交到队列
|
||||
workflow_run_id = await submit_workflow_to_queue(
|
||||
workflow_name=workflow_name, workflow_data=workflow, api_spec=api_spec, request_data=data
|
||||
workflow_run_id = await comfyui_client.submit_workflow_to_queue(
|
||||
workflow_name=workflow_name,
|
||||
workflow_data=workflow,
|
||||
api_spec=api_spec,
|
||||
request_data=data,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
|
|
@ -220,7 +218,7 @@ async def get_run_status(workflow_run_id: str):
|
|||
获取工作流执行状态。
|
||||
"""
|
||||
try:
|
||||
status = await queue_manager.get_task_status(workflow_run_id)
|
||||
status = await comfyui_client.queue_manager.get_task_status(workflow_run_id)
|
||||
return status
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
||||
|
|
@ -232,13 +230,13 @@ async def get_metrics():
|
|||
获取队列状态概览。
|
||||
"""
|
||||
try:
|
||||
pending_count = len(queue_manager.pending_tasks)
|
||||
running_count = len(queue_manager.running_tasks)
|
||||
pending_count = len(comfyui_client.queue_manager.pending_tasks)
|
||||
running_count = len(comfyui_client.queue_manager.running_tasks)
|
||||
|
||||
return {
|
||||
"pending_tasks": pending_count,
|
||||
"running_tasks": running_count,
|
||||
"total_servers": len(queue_manager.running_tasks),
|
||||
"total_servers": len(comfyui_client.queue_manager.running_tasks),
|
||||
"queue_manager_status": "active",
|
||||
}
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Reference in New Issue