ComfyUI-WorkflowPublisher/workflow_service/main.py

213 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uvicorn
from fastapi import FastAPI, Request, HTTPException, Path
from fastapi.responses import JSONResponse
from typing import Optional, List, Dict, Any, Set
import json, uuid, os, shutil, aiohttp, database, workflow_parser, comfyui_client, s3_client
from config import settings
# FastAPI application instance; the route decorators in this module register their handlers on it.
app = FastAPI(title="ComfyUI Workflow Service & Management API")
@app.on_event("startup")
async def startup_event():
    """Prepare the service on application startup.

    Initialises the database first, then makes sure the configured ComfyUI
    input and output directories exist (existing directories are left as-is).
    """
    await database.init_db()

    input_dir = settings.COMFYUI_INPUT_DIR
    os.makedirs(input_dir, exist_ok=True)

    output_dir = settings.COMFYUI_OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)
# --- Section 1: Workflow management API (unchanged) ---
# ... (code identical to the previous revision) ...
# URL prefix under which the workflow management routes are registered.
BASE_MANAGEMENT_PATH = "/api/workflow"
@app.post(BASE_MANAGEMENT_PATH, status_code=200)
async def publish_workflow_endpoint(request: Request):
    """Store a workflow definition under a name.

    The request body must be a JSON object with two keys:

    * ``name``     - non-empty string identifying the workflow.
    * ``workflow`` - the workflow definition; it is serialised with
      ``json.dumps`` before being handed to ``database.save_workflow``.

    Returns:
        A ``JSONResponse`` with ``status == "success"`` and a confirmation
        message.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, or lacks a usable ``name`` / ``workflow``; 500 when the
            workflow cannot be saved.
    """
    # Client-side problems are reported as 400 instead of being folded into
    # the generic 500 "Failed to save workflow" error.
    try:
        data = await request.json()
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"Request body is not valid JSON: {e}") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    name, wf_json = data.get("name"), data.get("workflow")
    if not isinstance(name, str) or not name:
        raise HTTPException(status_code=400, detail="Field 'name' is required and must be a non-empty string.")
    if wf_json is None:
        # Without this check a missing workflow would be stored as the string "null".
        raise HTTPException(status_code=400, detail="Field 'workflow' is required.")

    try:
        await database.save_workflow(name, json.dumps(wf_json))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}") from e

    return JSONResponse(
        content={"status": "success", "message": f"Workflow '{name}' published."},
        status_code=200,
    )
@app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
async def get_all_workflows_endpoint():
    """Return whatever ``database.get_all_workflows()`` yields.

    Raises:
        HTTPException: 500 if the database call raises.
    """
    try:
        workflows = await database.get_all_workflows()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {exc}")
    return workflows
@app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
async def delete_workflow_endpoint(workflow_name: str = Path(..., title="...")):
    """Delete the workflow identified by ``workflow_name``.

    The path parameter uses the ``:path`` converter, so the name may contain
    slashes.

    Returns:
        ``{"status": "deleted", "name": workflow_name}`` when
        ``database.delete_workflow`` reports success.

    Raises:
        HTTPException: 404 when ``database.delete_workflow`` returns a falsy
            result; 500 when the database call itself raises.
    """
    try:
        success = await database.delete_workflow(workflow_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}") from e

    # The 404 is raised outside the try block on purpose: HTTPException is an
    # Exception subclass, so raising it inside the try would be caught by the
    # handler above and reported as a 500.
    if not success:
        raise HTTPException(status_code=404, detail="Workflow not found")
    return {"status": "deleted", "name": workflow_name}
# --- Section 2: Workflow execution API (core changes) ---
def get_files_in_dir(directory: str) -> Set[str]:
    """Collect the paths of all non-hidden files below ``directory``.

    The tree is walked recursively with ``os.walk``. Files whose name starts
    with a dot are skipped. Each returned path is the walk root joined with
    the file name. A directory that cannot be walked yields an empty set.
    """
    return {
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in names
        if not name.startswith('.')
    }
async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
    """Stream the response body of a GET request for ``url`` into ``save_path``.

    ``raise_for_status()`` is called on the response before the destination
    file is opened, so an error status raises without creating the file.
    The body is then read in chunks of up to 8192 bytes until an empty chunk
    signals the end of the stream.
    """
    async with session.get(url) as resp:
        resp.raise_for_status()
        body = resp.content
        with open(save_path, 'wb') as out_file:
            data = await body.read(8192)
            while data:
                out_file.write(data)
                data = await body.read(8192)
async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Helper: upload a file to S3 and return the result of the upload call.

    The object key has the form ``outputs/<base_name>/<uuid4>_<file name>``,
    so repeated uploads of equally named files get distinct keys. The return
    value is whatever ``s3_client.upload_file_to_s3`` returns (expected to be
    the object's URL).
    """
    unique_name = f"{uuid.uuid4()}_{os.path.basename(file_path)}"
    object_key = f"outputs/{base_name}/{unique_name}"
    return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, object_key)
@app.post("/api/run/{base_name}")
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
    """Run a stored workflow and return its outputs.

    Steps:
      1. Load the workflow (the given ``version`` if set, otherwise the latest
         one for ``base_name``) and derive its API spec.
      2. For every ``UploadFile`` input present in the request, download the
         supplied URL into ``settings.COMFYUI_INPUT_DIR`` and substitute the
         local file name in the request data.
      3. Snapshot the output directory, run the patched workflow through
         ``comfyui_client.run_workflow`` and snapshot again; files that only
         appear in the second snapshot are treated as this run's outputs.
      4. Upload those files to S3 (returned under ``"output_files"``) and
         copy the remaining node outputs into the response.
      5. Always delete downloaded inputs and collected output files.

    Args:
        base_name: Base name of the workflow to execute.
        request_data_raw: JSON body; keys are lower-cased before being
            matched against the spec's input names.
        version: Optional workflow version.

    Returns:
        A dict mapping output parameter names to values, plus an optional
        ``"output_files"`` list of upload results.

    Raises:
        HTTPException: 404 if the workflow is not found, 400 if an
            ``UploadFile`` parameter is not an ``http...`` string, 500 if a
            download fails.
    """
    # NOTE(review): the before/after snapshot of the output directory is shared
    # by all requests; concurrent runs may pick up each other's files.
    cleanup_paths = []
    try:
        # 1. Load the workflow and process inputs.
        if version:
            workflow_data = await database.get_workflow_by_version(base_name, version)
        else:
            workflow_data = await database.get_latest_workflow_by_base_name(base_name)
        if not workflow_data:
            raise HTTPException(status_code=404, detail=f"Workflow '{base_name}' not found.")
        workflow = json.loads(workflow_data['workflow_json'])
        api_spec = workflow_parser.parse_api_spec(workflow)
        request_data = {k.lower(): v for k, v in request_data_raw.items()}

        async with aiohttp.ClientSession() as session:
            for param_name, spec in api_spec["inputs"].items():
                if spec["type"] != "UploadFile" or param_name not in request_data:
                    continue
                image_url = request_data[param_name]
                if not isinstance(image_url, str) or not image_url.startswith('http'):
                    raise HTTPException(
                        status_code=400, detail=f"Parameter '{param_name}' must be a valid URL.")
                # Keep only the extension of the last URL path segment; the
                # stored file gets a fresh UUID-based name.
                original_filename = image_url.split('/')[-1].split('?')[0]
                _, file_extension = os.path.splitext(original_filename)
                if not file_extension:
                    file_extension = '.dat'
                filename = f"api_download_{uuid.uuid4()}{file_extension}"
                save_path = os.path.join(settings.COMFYUI_INPUT_DIR, filename)
                # Register the path before downloading so that a partially
                # written file is also removed when the download fails.
                cleanup_paths.append(save_path)
                try:
                    await download_file_from_url(session, image_url, save_path)
                except Exception as e:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to download file for '{param_name}' from {image_url}. Error: {e}") from e
                request_data[param_name] = filename

        # 2. Snapshot before execution.
        files_before = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)

        # 3. Patch, convert and execute (per the original note, cache busting
        #    now lives in the client).
        patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
        prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)
        output_nodes = await comfyui_client.run_workflow(prompt_to_run)

        # 4. Snapshot after execution and compute the difference.
        files_after = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)
        new_files = files_after - files_before

        # 5. Handle all outputs in one place.
        output_response = {}
        processed_files = set()  # base names already handled via the snapshot diff

        # 5.1 Upload every newly generated file. Sorting makes the order of
        #     "output_files" deterministic (set iteration order is not).
        if new_files:
            s3_urls = []
            for file_path in sorted(new_files):
                cleanup_paths.append(file_path)
                processed_files.add(os.path.basename(file_path))
                try:
                    s3_urls.append(await handle_file_upload(file_path, base_name))
                except Exception as e:
                    # Best effort: a failed upload is logged and omitted.
                    print(f"Error uploading file {file_path} to S3: {e}")
            if s3_urls:
                output_response["output_files"] = s3_urls

        # 5.2 Copy the node outputs reported by the client, checking whether a
        #     text output is actually the path of a file not seen in 5.1.
        for final_param_name, spec in api_spec["outputs"].items():
            node_id = spec["node_id"]
            if node_id not in output_nodes:
                continue
            node_output = output_nodes[node_id]
            original_output_name = spec["output_name"]
            if original_output_name in node_output:
                output_value = node_output[original_output_name]
                # Unwrap lists: use the first element, or None when empty.
                if isinstance(output_value, list):
                    output_value = output_value[0] if output_value else None
                # Heuristic: a string mentioning a media extension or "output"
                # may be a file path.
                # NOTE(review): 'output' matches a lot of ordinary text; the
                # file-existence check below is what keeps this safe.
                if isinstance(output_value, str) and (
                        '.png' in output_value or '.jpg' in output_value
                        or '.mp4' in output_value or 'output' in output_value):
                    potential_filename = os.path.basename(output_value.replace('\\', '/'))
                    if potential_filename not in processed_files:
                        potential_path = os.path.join(settings.COMFYUI_OUTPUT_DIR, potential_filename)
                        # isfile (not exists): an empty base name or a
                        # directory must not be uploaded or queued for deletion.
                        if os.path.isfile(potential_path):
                            print(f"Found extra file from text output: {potential_path}")
                            cleanup_paths.append(potential_path)
                            processed_files.add(potential_filename)
                            try:
                                s3_url = await handle_file_upload(potential_path, base_name)
                                output_response.setdefault("output_files", []).append(s3_url)
                            except Exception as e:
                                print(f"Error uploading extra file {potential_path} to S3: {e}")
                            # Handled as a file; do not also emit it as text.
                            # NOTE(review): nesting of this `continue` was
                            # reconstructed from a flattened source - confirm.
                            continue
                # Not a file: return it as a plain value.
                output_response[final_param_name] = output_value
            elif "text" in node_output:
                # Fall back to the node's generic "text" output.
                output_response[final_param_name] = node_output["text"]
        return output_response
    finally:
        # Remove downloaded inputs and collected output files, success or not.
        print(f"Cleaning up {len(cleanup_paths)} temporary files...")
        for path in cleanup_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
                    print(f" - Deleted: {path}")
            except Exception as e:
                print(f" - Error deleting {path}: {e}")
# --- Section 3: Workflow metadata / spec API (unchanged) ---
# ... (this part of the code is identical to the previous revision) ...
@app.get("/api/spec/{base_name}")
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
    """Return the parsed API spec of a stored workflow.

    Looks up the given ``version`` when one is supplied, otherwise the latest
    workflow for ``base_name``.

    Raises:
        HTTPException: 404 when no workflow record is found; 500 when the
            stored JSON cannot be loaded or parsed into a spec.
    """
    if version:
        lookup = database.get_workflow_by_version(base_name, version)
    else:
        lookup = database.get_latest_workflow_by_base_name(base_name)
    record = await lookup

    if not record:
        version_note = f" with version '{version}'" if version else ""
        raise HTTPException(status_code=404, detail=f"Workflow '{base_name}'{version_note} not found.")

    try:
        parsed = json.loads(record['workflow_json'])
        spec = workflow_parser.parse_api_spec(parsed)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {exc}")
    return spec
@app.get("/")
def read_root():
    """Root endpoint: respond with a static welcome message."""
    greeting = "Welcome to the ComfyUI Workflow Service API!"
    return {"message": greeting}
# Script entry point: when run directly, serve the app with uvicorn on 127.0.0.1:18000.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=18000)