import json
import os
import uuid
from typing import Optional, List, Dict, Any, Set

import aiohttp
import uvicorn
from fastapi import FastAPI, Request, HTTPException, Path
from fastapi.responses import JSONResponse

from workflow_service import comfyui_client
from workflow_service import database
from workflow_service import s3_client
from workflow_service import workflow_parser
from workflow_service.config import Settings

settings = Settings()
web_app = FastAPI(title="ComfyUI Workflow Service & Management API")


@web_app.on_event("startup")
async def startup_event():
    """Initialise the database and ensure the ComfyUI input/output directories exist."""
    await database.init_db()
    os.makedirs(settings.COMFYUI_INPUT_DIR, exist_ok=True)
    os.makedirs(settings.COMFYUI_OUTPUT_DIR, exist_ok=True)


# --- Section 1: Workflow management API (unchanged) ---
# ... (code identical to the previous version) ...
BASE_MANAGEMENT_PATH = "/api/workflow"


@web_app.post(BASE_MANAGEMENT_PATH, status_code=200)
async def publish_workflow_endpoint(request: Request):
    """Store a workflow under the name given in the JSON request body.

    The body is expected to be a JSON object with the keys ``name`` and
    ``workflow``; the workflow is serialised with ``json.dumps`` before being
    handed to ``database.save_workflow``.

    Raises:
        HTTPException: 400 when ``name`` or ``workflow`` is missing from the
            body, 500 when reading the body or saving the workflow fails.
    """
    try:
        data = await request.json()
        if not isinstance(data, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
        name, wf_json = data.get("name"), data.get("workflow")
        if not name or wf_json is None:
            raise HTTPException(status_code=400, detail="Both 'name' and 'workflow' are required.")
        await database.save_workflow(name, json.dumps(wf_json))
        return JSONResponse(
            content={"status": "success", "message": f"Workflow '{name}' published."},
            status_code=200)
    except HTTPException:
        # Deliberate HTTP errors must keep their own status code rather than
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}") from e


@web_app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
async def get_all_workflows_endpoint():
    """Return whatever ``database.get_all_workflows`` yields, or a 500 if that call fails."""
    try:
        return await database.get_all_workflows()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {e}") from e


@web_app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
async def delete_workflow_endpoint(workflow_name: str = Path(..., title="...")):
    """Delete the workflow called *workflow_name*.

    Raises:
        HTTPException: 404 when ``database.delete_workflow`` reports that
            nothing was deleted, 500 when the database call itself fails.
    """
    try:
        success = await database.delete_workflow(workflow_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}") from e
    # The 404 is raised outside the try block so that it is not caught by the
    # generic handler and turned into a 500.
    if not success:
        raise HTTPException(status_code=404, detail="Workflow not found")
    return {"status": "deleted", "name": workflow_name}


# --- Section 2:
# Workflow execution API (core changes) ---
def get_files_in_dir(directory: str) -> Set[str]:
    """Return the joined paths of every non-hidden file found by walking *directory*.

    Files whose name starts with ``.`` are skipped.
    """
    return {
        os.path.join(root, filename)
        for root, _, files in os.walk(directory)
        for filename in files
        if not filename.startswith('.')
    }


async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
    """Stream the response body of *url* into *save_path* in 8 KiB chunks.

    ``response.raise_for_status()`` is called first, so an HTTP error status
    raises before anything is written.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            while True:
                chunk = await response.content.read(8192)
                if not chunk:
                    break
                f.write(chunk)


async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Helper: upload a file to S3 and return the URL."""
    s3_object_name = f"outputs/{base_name}/{uuid.uuid4()}_{os.path.basename(file_path)}"
    return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, s3_object_name)


def _file_extension_from_url(url: str, default: str = '.dat') -> str:
    """Return the extension of the last path segment of *url*, or *default* if it has none.

    Only the URL path is inspected, so slashes or dots inside the query string
    or fragment cannot affect the result.
    """
    from urllib.parse import urlparse

    original_filename = os.path.basename(urlparse(url).path)
    _, file_extension = os.path.splitext(original_filename)
    return file_extension or default


@web_app.post("/api/run/{base_name}")
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
    """Run a stored workflow and return its outputs.

    Steps: load the workflow (specific *version* or the latest), download any
    ``UploadFile`` inputs given as URLs into the ComfyUI input directory, run
    the patched workflow, upload files that appeared in the ComfyUI output
    directory to S3, and collect the remaining node outputs.

    Returns:
        A dict that may contain ``output_files`` (list of S3 URLs) plus one
        entry per output parameter of the workflow's API spec that produced a
        value.

    Raises:
        HTTPException: 404 when the workflow is not found, 400 when an
            ``UploadFile`` parameter is not an ``http`` URL string, 500 when
            downloading such a file fails.
    """
    cleanup_paths = []
    try:
        # 1. Load the workflow and process the inputs (same as previous version)
        # ...
        if version:
            workflow_data = await database.get_workflow_by_version(base_name, version)
        else:
            workflow_data = await database.get_latest_workflow_by_base_name(base_name)
        if not workflow_data:
            raise HTTPException(status_code=404, detail=f"Workflow '{base_name}' not found.")

        workflow = json.loads(workflow_data['workflow_json'])
        api_spec = workflow_parser.parse_api_spec(workflow)
        request_data = {k.lower(): v for k, v in request_data_raw.items()}

        async with aiohttp.ClientSession() as session:
            for param_name, spec in api_spec["inputs"].items():
                if spec["type"] == "UploadFile" and param_name in request_data:
                    image_url = request_data[param_name]
                    if not isinstance(image_url, str) or not image_url.startswith('http'):
                        raise HTTPException(
                            status_code=400,
                            detail=f"Parameter '{param_name}' must be a valid URL.")
                    file_extension = _file_extension_from_url(image_url)
                    filename = f"api_download_{uuid.uuid4()}{file_extension}"
                    save_path = os.path.join(settings.COMFYUI_INPUT_DIR, filename)
                    # Register the path before downloading so that a partially
                    # written file is also removed when the download fails.
                    cleanup_paths.append(save_path)
                    try:
                        await download_file_from_url(session, image_url, save_path)
                    except Exception as e:
                        raise HTTPException(
                            status_code=500,
                            detail=f"Failed to download file for '{param_name}' from {image_url}. Error: {e}") from e
                    request_data[param_name] = filename

        # 2. Snapshot of the output directory before execution
        files_before = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)

        # 3. Patch, convert and execute (cache busting has moved into the client)
        patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
        prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)
        output_nodes = await comfyui_client.run_workflow(prompt_to_run)

        # 4. Snapshot after execution and compute the difference
        files_after = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)
        new_files = files_after - files_before

        # 5. [Core fix] Handle all outputs uniformly
        output_response = {}
        processed_files = set()  # basenames already handled via the filesystem snapshot

        # 5.1 First handle every newly created file
        if new_files:
            s3_urls = []
            for file_path in new_files:
                cleanup_paths.append(file_path)
                processed_files.add(os.path.basename(file_path))
                try:
                    s3_urls.append(await handle_file_upload(file_path, base_name))
                except Exception as e:
                    # Best effort: a failed upload is reported but does not abort the request.
                    print(f"Error uploading file {file_path} to S3: {e}")
            if s3_urls:
                output_response["output_files"] = s3_urls

        # 5.2 Then handle the non-file outputs returned over the WebSocket, and
        #     check whether a text output is actually a file path
        for final_param_name, spec in api_spec["outputs"].items():
            node_id = spec["node_id"]
            if node_id not in output_nodes:
                continue
            node_output = output_nodes[node_id]
            original_output_name = spec["output_name"]

            if original_output_name in node_output:
                output_value = node_output[original_output_name]
                # Unwrap lists: use the first element, or None for an empty list
                if isinstance(output_value, list):
                    output_value = output_value[0] if output_value else None

                # Check whether the text output is a file path that the snapshot did not catch
                if isinstance(output_value, str) and (
                        '.png' in output_value or '.jpg' in output_value
                        or '.mp4' in output_value or 'output' in output_value):
                    potential_filename = os.path.basename(output_value.replace('\\', '/'))
                    if potential_filename not in processed_files:
                        # This is a new file path; try to find it in the output directory
                        potential_path = os.path.join(settings.COMFYUI_OUTPUT_DIR, potential_filename)
                        if os.path.exists(potential_path):
                            print(f"Found extra file from text output: {potential_path}")
                            cleanup_paths.append(potential_path)
                            processed_files.add(potential_filename)
                            try:
                                s3_url = await handle_file_upload(potential_path, base_name)
                                # Add it to the output_files list as well
                                output_response.setdefault("output_files", []).append(s3_url)
                            except Exception as e:
                                print(f"Error uploading extra file {potential_path} to S3: {e}")
                            continue  # handled as a file; do not also emit it as a text output

                # Not a file, so emit it as an ordinary value
                output_response[final_param_name] = output_value
            elif "text" in node_output:
                output_response[final_param_name] = node_output["text"]

        return output_response
    finally:
        # Cleanup is unchanged: remove downloaded inputs and processed output files
        print(f"Cleaning up {len(cleanup_paths)} temporary files...")
        for path in cleanup_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
                    print(f" - Deleted: {path}")
            except Exception as e:
                print(f" - Error deleting {path}: {e}")


# --- Section 3: Workflow metadata/spec API (unchanged) ---
# ... (this section is identical to the previous version) ...
@web_app.get("/api/spec/{base_name}")
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
    """Return the parsed API spec of a stored workflow.

    Raises:
        HTTPException: 404 when the workflow (or the requested version) is not
            found, 500 when the stored JSON cannot be loaded or parsed.
    """
    # ...
    if version:
        workflow_data = await database.get_workflow_by_version(base_name, version)
    else:
        workflow_data = await database.get_latest_workflow_by_base_name(base_name)
    if not workflow_data:
        detail = f"Workflow '{base_name}'" + (f" with version '{version}'" if version else "") + " not found."
        raise HTTPException(status_code=404, detail=detail)
    try:
        workflow = json.loads(workflow_data['workflow_json'])
        return workflow_parser.parse_api_spec(workflow)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {e}") from e


@web_app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the ComfyUI Workflow Service API!"}


if __name__ == "__main__":
    uvicorn.run(web_app, host="127.0.0.1", port=18000)