227 lines
10 KiB
Python
227 lines
10 KiB
Python
import aiohttp
|
||
from workflow_service import comfyui_client
|
||
from workflow_service import database
|
||
import json
|
||
import os
|
||
from workflow_service import s3_client
|
||
import uuid
|
||
from workflow_service import workflow_parser
|
||
from typing import Optional, List, Dict, Any, Set
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI, Request, HTTPException, Path
|
||
from fastapi.responses import JSONResponse
|
||
|
||
from workflow_service.config import Settings
|
||
# Application-wide configuration (paths, S3 bucket, etc.) loaded via the
# project's Settings object.
settings = Settings()

# FastAPI app exposing workflow management, execution, and spec endpoints.
web_app = FastAPI(title="ComfyUI Workflow Service & Management API")
|
||
|
||
|
||
@web_app.on_event("startup")
async def startup_event():
    """Initialize the database and make sure the ComfyUI I/O directories exist."""
    await database.init_db()
    os.makedirs(settings.COMFYUI_INPUT_DIR, exist_ok=True)
    os.makedirs(settings.COMFYUI_OUTPUT_DIR, exist_ok=True)
|
||
|
||
|
||
# --- Section 1: Workflow management API (unchanged) ---
# ... (identical to the previous revision) ...
# Common URL prefix for all workflow-management endpoints below.
BASE_MANAGEMENT_PATH = "/api/workflow"
|
||
|
||
|
||
@web_app.post(BASE_MANAGEMENT_PATH, status_code=200)
async def publish_workflow_endpoint(request: Request):
    """Publish (save) a workflow definition.

    Expects a JSON body with 'name' (str) and 'workflow' (dict/JSON).
    Returns a success message on 200; raises 400 when either field is
    missing, 500 when persistence fails.
    """
    data = await request.json()
    name, wf_json = data.get("name"), data.get("workflow")
    # Bug fix: previously a missing 'name' or 'workflow' was silently
    # persisted as None/null; reject incomplete payloads explicitly.
    if not name or wf_json is None:
        raise HTTPException(status_code=400, detail="Both 'name' and 'workflow' are required.")
    try:
        await database.save_workflow(name, json.dumps(wf_json))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save workflow: {e}")
    return JSONResponse(
        content={"status": "success", "message": f"Workflow '{name}' published."}, status_code=200)
|
||
|
||
|
||
@web_app.get(BASE_MANAGEMENT_PATH, response_model=List[dict])
async def get_all_workflows_endpoint():
    """Return every stored workflow as a list of dicts; 500 on storage errors."""
    try:
        workflows = await database.get_all_workflows()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get workflows: {e}")
    return workflows
|
||
|
||
|
||
@web_app.delete(f"{BASE_MANAGEMENT_PATH}/{{workflow_name:path}}")
async def delete_workflow_endpoint(workflow_name: str = Path(..., title="...")):
    """Delete a workflow by name.

    Returns a deletion receipt on success, 404 when the workflow does not
    exist, and 500 on storage errors.
    """
    try:
        success = await database.delete_workflow(workflow_name)
        if success:
            return {"status": "deleted", "name": workflow_name}
        raise HTTPException(status_code=404, detail="Workflow not found")
    except HTTPException:
        # Bug fix: the 404 above was previously caught by the broad handler
        # below and converted into a 500. Let HTTP errors pass through as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete workflow: {e}")
|
||
|
||
|
||
# --- Section 2: Workflow execution API (core changes) ---
|
||
def get_files_in_dir(directory: str) -> Set[str]:
    """Recursively collect the full paths of all non-hidden files under *directory*."""
    collected: Set[str] = set()
    for current_root, _dirs, names in os.walk(directory):
        collected.update(
            os.path.join(current_root, name)
            for name in names
            if not name.startswith('.')
        )
    return collected
|
||
|
||
|
||
async def download_file_from_url(session: aiohttp.ClientSession, url: str, save_path: str):
    """Stream the resource at *url* into *save_path* in 8 KiB chunks.

    Raises an aiohttp client error for non-2xx responses.
    """
    async with session.get(url) as response:
        response.raise_for_status()
        # NOTE(review): blocking file I/O inside a coroutine — fine for small
        # files, but would stall the event loop for very large downloads.
        with open(save_path, 'wb') as out:
            while chunk := await response.content.read(8192):
                out.write(chunk)
|
||
|
||
|
||
async def handle_file_upload(file_path: str, base_name: str) -> str:
    """Helper: upload a local file to S3 under outputs/<base_name>/ and return its URL."""
    # A UUID prefix keeps repeated uploads of same-named files from colliding.
    object_key = "outputs/{}/{}_{}".format(base_name, uuid.uuid4(), os.path.basename(file_path))
    return await s3_client.upload_file_to_s3(file_path, settings.S3_BUCKET_NAME, object_key)
|
||
|
||
|
||
@web_app.post("/api/run/{base_name}")
async def execute_workflow_endpoint(base_name: str, request_data_raw: Dict[str, Any], version: Optional[str] = None):
    """Execute a stored workflow and return its outputs.

    Loads the workflow (latest version unless *version* is given), downloads
    any URL-typed inputs into the ComfyUI input directory, runs the prompt via
    comfyui_client, uploads newly created output files to S3, and returns a
    dict mapping output parameter names (plus 'output_files') to values.
    All temporary files are removed in the ``finally`` block.
    """
    cleanup_paths = []
    try:
        # 1. Fetch the workflow and process inputs (same as the previous revision)
        # ...
        if version:
            workflow_data = await database.get_workflow_by_version(base_name, version)
        else:
            workflow_data = await database.get_latest_workflow_by_base_name(base_name)
        if not workflow_data: raise HTTPException(status_code=404, detail=f"Workflow '{base_name}' not found.")
        workflow = json.loads(workflow_data['workflow_json'])
        api_spec = workflow_parser.parse_api_spec(workflow)
        # Normalize request keys to lowercase so matching against the spec is
        # case-insensitive.
        request_data = {k.lower(): v for k, v in request_data_raw.items()}
        async with aiohttp.ClientSession() as session:
            for param_name, spec in api_spec["inputs"].items():
                # UploadFile inputs arrive as URLs; download each one locally
                # and replace the URL with the local filename ComfyUI expects.
                if spec["type"] == "UploadFile" and param_name in request_data:
                    image_url = request_data[param_name]
                    if not isinstance(image_url, str) or not image_url.startswith('http'): raise HTTPException(
                        status_code=400, detail=f"Parameter '{param_name}' must be a valid URL.")
                    # Strip any query string before extracting the extension.
                    original_filename = image_url.split('/')[-1].split('?')[0];
                    _, file_extension = os.path.splitext(original_filename)
                    if not file_extension: file_extension = '.dat'
                    filename = f"api_download_{uuid.uuid4()}{file_extension}"
                    save_path = os.path.join(settings.COMFYUI_INPUT_DIR, filename)
                    try:
                        await download_file_from_url(session, image_url, save_path);
                        request_data[param_name] = filename;
                        cleanup_paths.append(save_path)
                    except Exception as e:
                        raise HTTPException(status_code=500,
                                            detail=f"Failed to download file for '{param_name}' from {image_url}. Error: {e}")

        # 2. Snapshot the output directory before execution
        # NOTE(review): the before/after snapshot diff assumes no concurrent
        # run writes to the shared output dir at the same time — confirm.
        files_before = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)

        # 3. Patch, convert and execute (cache-busting now lives in the client)
        patched_workflow = workflow_parser.patch_workflow(workflow, api_spec, request_data)
        prompt_to_run = workflow_parser.convert_workflow_to_prompt_api_format(patched_workflow)
        output_nodes = await comfyui_client.run_workflow(prompt_to_run)

        # 4. Snapshot after execution and diff to find newly created files
        files_after = get_files_in_dir(settings.COMFYUI_OUTPUT_DIR)
        new_files = files_after - files_before

        # 5. [core fix] Handle all outputs uniformly
        output_response = {}
        processed_files = set()  # files already handled via the filesystem snapshot

        # 5.1 First, upload every newly generated file
        if new_files:
            s3_urls = []
            for file_path in new_files:
                cleanup_paths.append(file_path)
                processed_files.add(os.path.basename(file_path))
                try:
                    s3_urls.append(await handle_file_upload(file_path, base_name))
                except Exception as e:
                    # Best-effort: a failed upload is logged, not fatal.
                    print(f"Error uploading file {file_path} to S3: {e}")
            if s3_urls: output_response["output_files"] = s3_urls

        # 5.2 Then handle non-file outputs returned over WebSocket, checking
        #     whether any text output is actually a file path
        for final_param_name, spec in api_spec["outputs"].items():
            node_id = spec["node_id"]
            if node_id in output_nodes:
                node_output = output_nodes[node_id]
                original_output_name = spec["output_name"]

                if original_output_name in node_output:
                    output_value = node_output[original_output_name]
                    # Unwrap single-element lists
                    if isinstance(output_value, list):
                        output_value = output_value[0] if output_value else None

                    # Check whether a text output is a file path we have not
                    # discovered via the snapshot diff yet
                    if isinstance(output_value, str) and (
                            '.png' in output_value or '.jpg' in output_value or '.mp4' in output_value or 'output' in output_value):
                        potential_filename = os.path.basename(output_value.replace('\\', '/'))
                        if potential_filename not in processed_files:
                            # A new file path — try to locate it in the output directory
                            potential_path = os.path.join(settings.COMFYUI_OUTPUT_DIR, potential_filename)
                            if os.path.exists(potential_path):
                                print(f"Found extra file from text output: {potential_path}")
                                cleanup_paths.append(potential_path)
                                processed_files.add(potential_filename)
                                try:
                                    s3_url = await handle_file_upload(potential_path, base_name)
                                    # Add it to the output_files list as well
                                    if "output_files" not in output_response: output_response["output_files"] = []
                                    output_response["output_files"].append(s3_url)
                                except Exception as e:
                                    print(f"Error uploading extra file {potential_path} to S3: {e}")
                                continue  # handled as a file; skip emitting it as text
                    output_response[final_param_name] = output_value
                elif "text" in node_output:
                    output_value = node_output["text"]
                    # Not a file: emit as a plain value
                    output_response[final_param_name] = output_value

        return output_response

    finally:
        # Cleanup logic unchanged: best-effort removal of downloaded inputs
        # and already-uploaded output files.
        print(f"Cleaning up {len(cleanup_paths)} temporary files...")
        for path in cleanup_paths:
            try:
                if os.path.exists(path): os.remove(path); print(f" - Deleted: {path}")
            except Exception as e:
                print(f" - Error deleting {path}: {e}")
|
||
|
||
|
||
# --- Section 3: Workflow metadata / spec API (unchanged) ---
# ... (this section is identical to the previous revision) ...
||
@web_app.get("/api/spec/{base_name}")
async def get_workflow_spec_endpoint(base_name: str, version: Optional[str] = None):
    """Return the parsed API spec of a workflow (latest unless *version* given)."""
    if version:
        workflow_data = await database.get_workflow_by_version(base_name, version)
    else:
        workflow_data = await database.get_latest_workflow_by_base_name(base_name)

    if not workflow_data:
        version_part = f" with version '{version}'" if version else ""
        raise HTTPException(status_code=404, detail=f"Workflow '{base_name}'" + version_part + " not found.")

    try:
        parsed = json.loads(workflow_data['workflow_json'])
        return workflow_parser.parse_api_spec(parsed)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse workflow specification: {e}")
|
||
|
||
|
||
@web_app.get("/")
def read_root():
    """Landing endpoint used as a simple liveness check."""
    return {"message": "Welcome to the ComfyUI Workflow Service API!"}
|
||
|
||
|
||
# Local development entry point; serves the app on localhost:18000.
if __name__ == "__main__":
    uvicorn.run(web_app, host="127.0.0.1", port=18000)
|