# Source: ComfyUI-CustomNode/nodes/fetch_task_result.py
# (146 lines, 5.5 KiB, Python; header captured from the repository web view:
#  "Raw Permalink Blame History")
#
# Note from the web view: this file contains ambiguous Unicode characters that
# might be confused with other characters. If this is intentional, the warning
# can safely be ignored; the viewer's Escape button reveals them.

import comfy.utils
import folder_paths
import time
import requests
import torch
import numpy as np
from PIL import Image
import base64
from io import BytesIO
import json
from urllib.parse import urlparse
import os
class FetchTaskResult:
    """ComfyUI node that polls a task-status endpoint and returns its media.

    The node repeatedly GETs ``{base_url}/api/custom/task/status`` for a task
    until the response's ``status`` field is a boolean (task finished), the
    overall ``timeout`` elapses, or an unexpected error occurs. String items in
    the response's ``data`` list are split into image tensors (image URLs and
    base64 payloads) and video URLs.
    """

    # Map each selectable environment to its API base URL.
    ENV_URLS = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
    }

    # Per-request socket timeouts in seconds. Without them a single stalled
    # connection could block the node indefinitely, ignoring `timeout`.
    REQUEST_TIMEOUT = 30
    DOWNLOAD_TIMEOUT = 60

    # URL path extensions used to classify result items.
    IMAGE_EXTS = frozenset({".png", ".jpg", ".jpeg", ".bmp", ".webp"})
    VIDEO_EXTS = frozenset({".mp4", ".webm", ".mkv", ".avi", ".mov"})

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for ComfyUI."""
        return {
            "required": {
                # Dropdown built from the ENV_URLS keys ("prod", "test", "dev").
                "env": (list(cls.ENV_URLS.keys()),),
                "task_id": ("STRING", {"default": ""}),
                "interval": ("INT", {"default": 2, "min": 1, "max": 60}),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 3600}),
            },
        }

    RETURN_TYPES = ("IMAGE", "STRING", "STRING")
    RETURN_NAMES = ("images", "video_urls", "raw_response")
    FUNCTION = "execute"
    CATEGORY = "不忘科技-自定义节点🚩/utils/获取结果"

    def execute(self, env, task_id, interval, timeout):
        """Poll the task-status endpoint until the task finishes, then return its media.

        Args:
            env: Key of ``ENV_URLS`` selecting the API base URL.
            task_id: ID of the task to poll; surrounding whitespace is ignored.
            interval: Seconds to wait between polls.
            timeout: Overall polling budget in seconds.

        Returns:
            ``(images, video_urls, raw_response)``: an image batch tensor (an
            empty ``(0, 64, 64, 3)`` tensor when no images were found), the
            video URLs joined by newlines, and the full JSON response
            pretty-printed.

        Raises:
            ValueError: ``task_id`` is empty, or the API reports ``status`` False.
            TimeoutError: the task did not finish within ``timeout`` seconds.
        """
        # Resolve the base URL for the selected environment.
        base_url = self.ENV_URLS[env]
        if isinstance(task_id, str):
            task_id = task_id.strip()
        if not task_id:
            raise ValueError("Task ID 不能为空 (Task ID cannot be empty)")

        headers = {}  # Add auth or other headers here if the API requires them.
        status_url = f"{base_url}/api/custom/task/status"
        params = {"task_id": task_id}
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            try:
                print(f"[{env}] 正在轮询: {base_url}/api/custom/task/status?task_id={task_id}")
                # Never let a single request outlive the overall deadline.
                request_timeout = max(1.0, min(self.REQUEST_TIMEOUT, deadline - time.monotonic()))
                response = requests.get(
                    status_url, params=params, headers=headers, timeout=request_timeout)
                response.raise_for_status()
                data_ = response.json()
                print(f'原始响应结果:{data_}')
                api_status = data_.get('status')
                # A boolean status means the task has finished (successfully or
                # not); any other value means "still running, poll again".
                if isinstance(api_status, bool):
                    if not api_status:
                        msg = data_.get("msg")
                        raise ValueError(str(msg) if msg else f"任务 {task_id} 失败,API 未返回错误信息。")
                    print(f"任务 {task_id} 成功完成。正在分流处理媒体...")
                    image_tensors, video_urls = self.dispatch_media(data_.get('data', []))
                    final_images = self._batch_images(image_tensors)
                    final_urls = "\n".join(video_urls)
                    raw_response_str = json.dumps(data_, indent=2, ensure_ascii=False)
                    print(f"处理完成: {len(image_tensors)} 个图像, {len(video_urls)} 个视频URL。")
                    return final_images, final_urls, raw_response_str
                print(f"任务未完成。API返回状态: {api_status}。将在 {interval} 秒后重试...")
            except requests.exceptions.RequestException as e:
                # Transient network/HTTP errors are retried until the deadline.
                print(f"请求 API 失败: {e}. {interval} 秒后重试...")
            except Exception as e:
                print(f"处理任务时发生未知错误: {e}")
                raise
            # Sleep, but never past the deadline.
            remaining = deadline - time.monotonic()
            if remaining > 0:
                time.sleep(min(interval, remaining))

        raise TimeoutError(f"轮询任务 {task_id} 超时 ({timeout} 秒)。")

    def tensor_from_pil(self, img_pil):
        """Convert a PIL image to a float32 tensor scaled to [0, 1].

        A leading batch axis of size 1 is added; for the RGB images passed in
        by ``dispatch_media`` the result is ``(1, H, W, 3)``.
        """
        return torch.from_numpy(np.array(img_pil).astype(np.float32) / 255.0)[None,]

    def _batch_images(self, image_tensors):
        """Concatenate per-image tensors into one batch along dim 0.

        ``torch.cat`` requires equal spatial sizes, so any tensor whose
        height/width differs from the first one is bilinearly resized to match.
        Assumes ``(1, H, W, C)`` tensors as produced by ``tensor_from_pil``.
        An empty input yields an empty ``(0, 64, 64, 3)`` float32 tensor.
        """
        if not image_tensors:
            return torch.empty(0, 64, 64, 3, dtype=torch.float32)
        target_hw = tuple(image_tensors[0].shape[1:3])
        batch = []
        for tensor in image_tensors:
            if tuple(tensor.shape[1:3]) != target_hw:
                # interpolate() expects channels-first (N, C, H, W).
                tensor = torch.nn.functional.interpolate(
                    tensor.permute(0, 3, 1, 2), size=target_hw,
                    mode="bilinear", align_corners=False,
                ).permute(0, 2, 3, 1)
            batch.append(tensor)
        return torch.cat(batch, dim=0)

    @staticmethod
    def _strip_data_uri(item):
        """Return the payload after the comma of a ``data:`` URI, else *item* unchanged."""
        if item.startswith("data:"):
            _, sep, payload = item.partition(",")
            if sep:
                return payload
        return item

    def dispatch_media(self, data):
        """Split result items into image tensors and video URLs.

        Each string item is handled best-effort; failures are printed and the
        item is skipped:
        - http(s) URLs are classified by path extension. Image URLs are
          downloaded and converted to tensors, video URLs are returned as-is,
          and other URLs are skipped.
        - Any other string is decoded as a base64 image, optionally wrapped in
          a ``data:`` URI.

        Args:
            data: The response's ``data`` field. Anything but a list yields no media.

        Returns:
            ``(image_tensors, video_urls)`` as two lists.
        """
        if not isinstance(data, list):
            return [], []
        image_tensors = []
        video_urls = []
        for item in data:
            if not isinstance(item, str):
                continue
            # Option A: an http(s) URL, classified by its path extension
            # (urlparse drops any query string such as signed-URL tokens).
            if item.startswith(('http://', 'https://')):
                try:
                    ext = os.path.splitext(urlparse(item).path)[1].lower()
                    if ext in self.IMAGE_EXTS:
                        print(" -> 识别到图片URL正在下载和处理...")
                        response = requests.get(item, timeout=self.DOWNLOAD_TIMEOUT)
                        response.raise_for_status()
                        img = Image.open(BytesIO(response.content)).convert("RGB")
                        image_tensors.append(self.tensor_from_pil(img))
                    elif ext in self.VIDEO_EXTS:
                        print(" -> 识别到视频URL直接返回链接。")
                        video_urls.append(item)
                    else:
                        print(f" -> 识别到未知类型URL '{item}',已跳过。")
                except Exception as e:
                    print(f" -> 处理URL时出错: {e}")
            else:
                # Option B: a base64-encoded image, possibly a data: URI.
                try:
                    print(" -> 尝试作为 Base64 图片解码...")
                    img_data = base64.b64decode(self._strip_data_uri(item))
                    img = Image.open(BytesIO(img_data)).convert("RGB")
                    image_tensors.append(self.tensor_from_pil(img))
                except Exception:
                    print(" -> 解码失败,该项不是有效的媒体。")
        return image_tensors, video_urls
# Node registry (ComfyUI custom-node convention): node id -> implementing class.
NODE_CLASS_MAPPINGS = dict(FetchTaskResult=FetchTaskResult)

# Display title for each registered node id.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FetchTaskResult="获取生成结果 (图片/视频链接)",
)