fix
This commit is contained in:
parent
1af6b42573
commit
7830d2e02d
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -7,7 +8,7 @@ import uuid
|
||||||
import websockets
|
import websockets
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Any, Optional, List, Set
|
from typing import Dict, Any, Optional, List, Set, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import ClientTimeout
|
from aiohttp import ClientTimeout
|
||||||
|
|
@ -309,7 +310,9 @@ async def execute_prompt_on_server(
|
||||||
client_id = str(uuid.uuid4())
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 应用请求数据到工作流
|
# 应用请求数据到工作流
|
||||||
patched_workflow = patch_workflow(workflow_data, api_spec, request_data)
|
patched_workflow = await patch_workflow(
|
||||||
|
workflow_data, api_spec, request_data, server
|
||||||
|
)
|
||||||
|
|
||||||
# 转换为prompt格式
|
# 转换为prompt格式
|
||||||
prompt = convert_workflow_to_prompt_api_format(patched_workflow)
|
prompt = convert_workflow_to_prompt_api_format(patched_workflow)
|
||||||
|
|
@ -438,10 +441,11 @@ def parse_api_spec(workflow_data: dict) -> Dict[str, Dict[str, Any]]:
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
def patch_workflow(
|
async def patch_workflow(
|
||||||
workflow_data: dict,
|
workflow_data: dict,
|
||||||
api_spec: dict[str, dict[str, Any]],
|
api_spec: dict[str, dict[str, Any]],
|
||||||
request_data: dict[str, Any],
|
request_data: dict[str, Any],
|
||||||
|
server: ComfyUIServer,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
将request_data中的参数值,patch到workflow_data中。并返回修改后的workflow_data。
|
将request_data中的参数值,patch到workflow_data中。并返回修改后的workflow_data。
|
||||||
|
|
@ -501,6 +505,8 @@ def patch_workflow(
|
||||||
|
|
||||||
# 在正确的位置上更新值
|
# 在正确的位置上更新值
|
||||||
try:
|
try:
|
||||||
|
if target_node.get("type") == "LoadImage":
|
||||||
|
value = await upload_image_to_comfy(value, server)
|
||||||
target_node["widgets_values"][target_widget_index] = target_type(value)
|
target_node["widgets_values"][target_widget_index] = target_type(value)
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -779,3 +785,35 @@ async def _get_execution_results_legacy(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return aggregated_outputs
|
return aggregated_outputs
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_image_to_comfy(file_path: str, server: ComfyUIServer) -> str:
|
||||||
|
"""
|
||||||
|
上传文件到服务器。
|
||||||
|
file 是一个http链接
|
||||||
|
返回一个名称
|
||||||
|
"""
|
||||||
|
file_name = file_path.split("/")[-1]
|
||||||
|
file_data = await download_file(file_path)
|
||||||
|
|
||||||
|
form_data = aiohttp.FormData()
|
||||||
|
form_data.add_field("image", file_data, filename=file_name, content_type="image/*")
|
||||||
|
form_data.add_field("type", "input")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{server.http_url}/api/upload/image", data=form_data
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
result = await response.json()
|
||||||
|
return result["name"]
|
||||||
|
|
||||||
|
|
||||||
|
async def download_file(src: str) -> bytes:
|
||||||
|
"""
|
||||||
|
下载文件到本地。
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(src) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
return await response.read()
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,6 @@ from pydantic import BaseModel
|
||||||
from workflow_service import comfyui_client
|
from workflow_service import comfyui_client
|
||||||
from workflow_service import database
|
from workflow_service import database
|
||||||
from workflow_service.utils.s3_client import upload_file_to_s3
|
from workflow_service.utils.s3_client import upload_file_to_s3
|
||||||
from workflow_service.comfyui_client import (
|
|
||||||
ComfyUIExecutionError,
|
|
||||||
queue_manager,
|
|
||||||
submit_workflow_to_queue,
|
|
||||||
)
|
|
||||||
from workflow_service.config import Settings
|
from workflow_service.config import Settings
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
@ -197,8 +192,11 @@ async def run_workflow(
|
||||||
api_spec = comfyui_client.parse_api_spec(workflow)
|
api_spec = comfyui_client.parse_api_spec(workflow)
|
||||||
|
|
||||||
# 提交到队列
|
# 提交到队列
|
||||||
workflow_run_id = await submit_workflow_to_queue(
|
workflow_run_id = await comfyui_client.submit_workflow_to_queue(
|
||||||
workflow_name=workflow_name, workflow_data=workflow, api_spec=api_spec, request_data=data
|
workflow_name=workflow_name,
|
||||||
|
workflow_data=workflow,
|
||||||
|
api_spec=api_spec,
|
||||||
|
request_data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
@ -220,7 +218,7 @@ async def get_run_status(workflow_run_id: str):
|
||||||
获取工作流执行状态。
|
获取工作流执行状态。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
status = await queue_manager.get_task_status(workflow_run_id)
|
status = await comfyui_client.queue_manager.get_task_status(workflow_run_id)
|
||||||
return status
|
return status
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")
|
||||||
|
|
@ -232,13 +230,13 @@ async def get_metrics():
|
||||||
获取队列状态概览。
|
获取队列状态概览。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
pending_count = len(queue_manager.pending_tasks)
|
pending_count = len(comfyui_client.queue_manager.pending_tasks)
|
||||||
running_count = len(queue_manager.running_tasks)
|
running_count = len(comfyui_client.queue_manager.running_tasks)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"pending_tasks": pending_count,
|
"pending_tasks": pending_count,
|
||||||
"running_tasks": running_count,
|
"running_tasks": running_count,
|
||||||
"total_servers": len(queue_manager.running_tasks),
|
"total_servers": len(comfyui_client.queue_manager.running_tasks),
|
||||||
"queue_manager_status": "active",
|
"queue_manager_status": "active",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue