146 lines
5.5 KiB
Python
146 lines
5.5 KiB
Python
import comfy.utils
|
||
import folder_paths
|
||
import time
|
||
import requests
|
||
import torch
|
||
import numpy as np
|
||
from PIL import Image
|
||
import base64
|
||
from io import BytesIO
|
||
import json
|
||
from urllib.parse import urlparse
|
||
import os
|
||
|
||
|
||
class FetchTaskResult:
    """Node that polls a task-status API and returns the finished task's media.

    The node repeatedly calls ``<base_url>/api/custom/task/status`` for the
    given task id. Once the API reports a boolean ``status`` the polling
    stops: ``True`` means the task succeeded and its ``data`` list is split
    into image tensors and video URLs; ``False`` is reported as an error.
    Any non-boolean ``status`` is treated as "still running".
    """

    # 1. Mapping from the selectable environment name to the API base URL.
    ENV_URLS = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run"
    }

    # Upper bound in seconds for one HTTP request. Without it, requests waits
    # indefinitely on a stalled connection and the polling timeout never fires.
    _REQUEST_TIMEOUT = 30

    # URL path extensions (lower-case) used to classify returned media links.
    _IMAGE_EXTS = frozenset({'.png', '.jpg', '.jpeg', '.bmp', '.webp'})
    _VIDEO_EXTS = frozenset({'.mp4', '.webm', '.mkv', '.avi', '.mov'})

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: environment, task id and polling settings."""
        return {
            "required": {
                # List of the keys "prod", "test", "dev"; the original author
                # notes this is shown as a dropdown.
                "env": (list(s.ENV_URLS.keys()),),
                "task_id": ("STRING", {"default": ""}),
                "interval": ("INT", {"default": 2, "min": 1, "max": 60}),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 3600}),
            },
        }

    RETURN_TYPES = ("IMAGE", "STRING", "STRING")
    RETURN_NAMES = ("images", "video_urls", "raw_response")
    FUNCTION = "execute"

    CATEGORY = "不忘科技-自定义节点🚩/utils/获取结果"

    def execute(self, env, task_id, interval, timeout):
        """Poll the status endpoint until the task finishes or time runs out.

        Args:
            env: Key of ``ENV_URLS`` selecting the API base URL.
            task_id: Identifier of the task to query; must be non-empty.
            interval: Seconds to wait between polls.
            timeout: Total seconds to keep polling before giving up.

        Returns:
            A tuple ``(images, video_urls, raw_response)``: the image tensors
            concatenated along dim 0 (or an empty ``(0, 64, 64, 3)`` float32
            tensor when there are none), the video URLs joined by newlines,
            and the JSON response pretty-printed as a string.

        Raises:
            KeyError: If ``env`` is not a key of ``ENV_URLS``.
            ValueError: If ``task_id`` is empty or the API reports
                ``status == False``.
            TimeoutError: If the task does not finish within ``timeout``.

        Note:
            ``torch.cat`` requires all images to share the same height and
            width; a task returning differently sized images fails here.
        """
        # 4. Resolve the base URL from the selected environment.
        base_url = self.ENV_URLS[env]

        if not task_id:
            raise ValueError("Task ID 不能为空 (Task ID cannot be empty)")

        headers = {}  # Add request headers here if they become necessary.
        params = {'task_id': task_id}

        # A monotonic clock keeps the deadline immune to system clock changes.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                print(f"[{env}] 正在轮询: {base_url}/api/custom/task/status?task_id={task_id}")
                # Never let one request outlive the overall deadline by much,
                # but always give it at least a second to complete.
                request_timeout = max(
                    1.0,
                    min(float(self._REQUEST_TIMEOUT), deadline - time.monotonic()),
                )
                response = requests.get(
                    f'{base_url}/api/custom/task/status',
                    params=params,
                    headers=headers,
                    timeout=request_timeout,
                )
                response.raise_for_status()

                data_ = response.json()
                print(f'原始响应结果:{data_}')
                api_status = data_.get('status')
                data = data_.get('data', [])

                # Only a boolean status is final; anything else keeps polling.
                if isinstance(api_status, bool):
                    if not api_status:
                        # Fall back to a generic message so a missing "msg"
                        # does not surface as an unrelated KeyError.
                        error_message = data_.get(
                            'msg',
                            f"任务 {task_id} 失败,API 未返回错误信息 "
                            "(task failed without an error message)",
                        )
                        raise ValueError(f'{error_message}')
                    print(f"任务 {task_id} 成功完成。正在分流处理媒体...")

                    image_tensors, video_urls = self.dispatch_media(data)

                    if image_tensors:
                        final_images = torch.cat(image_tensors, dim=0)
                    else:
                        final_images = torch.empty(0, 64, 64, 3, dtype=torch.float32)
                    final_urls = "\n".join(video_urls)
                    raw_response_str = json.dumps(data_, indent=2, ensure_ascii=False)

                    print(f"处理完成: {len(image_tensors)} 个图像, {len(video_urls)} 个视频URL。")
                    return final_images, final_urls, raw_response_str

                print(f"任务未完成。API返回状态: {api_status}。将在 {interval} 秒后重试...")
                self._sleep_before_retry(interval, deadline)

            except requests.exceptions.RequestException as e:
                # Network and HTTP errors are considered transient: retry.
                print(f"请求 API 失败: {e}. {interval} 秒后重试...")
                self._sleep_before_retry(interval, deadline)
            except Exception as e:
                print(f"处理任务时发生未知错误: {e}")
                # Bare raise keeps the original traceback intact.
                raise

        raise TimeoutError(f"轮询任务 {task_id} 超时 ({timeout} 秒)。")

    @staticmethod
    def _sleep_before_retry(interval, deadline):
        """Sleep for ``interval`` seconds, but never past ``deadline``."""
        time.sleep(max(0.0, min(interval, deadline - time.monotonic())))

    def tensor_from_pil(self, img_pil):
        """Convert an image to a float32 tensor scaled by 1/255 with a leading batch axis."""
        return torch.from_numpy(np.array(img_pil).astype(np.float32) / 255.0)[None,]

    def dispatch_media(self, data):
        """Split the task's ``data`` entries into image tensors and video URLs.

        String entries starting with ``http://`` or ``https://`` are classified
        by the file extension of their URL path: images are downloaded and
        converted to tensors, video links are returned unchanged, and other
        URLs are skipped. Every other string is tried as a Base64-encoded
        image. Processing is best-effort: an entry that fails is logged and
        skipped so the remaining entries are still returned.

        Args:
            data: The ``data`` field of the API response. Anything that is not
                a list yields two empty lists; non-string items are ignored.

        Returns:
            A tuple ``(image_tensors, video_urls)`` of two lists.
        """
        if not isinstance(data, list):
            return [], []

        image_tensors = []
        video_urls = []

        for item in data:
            if not isinstance(item, str):
                continue

            # Option A: the entry is a URL.
            if item.startswith(('http://', 'https://')):
                try:
                    url_path = urlparse(item).path
                    ext = os.path.splitext(url_path)[1].lower()

                    if ext in self._IMAGE_EXTS:
                        print(" -> 识别到图片URL,正在下载和处理...")
                        # Bounded download so one dead link cannot hang the node.
                        response = requests.get(item, timeout=self._REQUEST_TIMEOUT)
                        response.raise_for_status()
                        img = Image.open(BytesIO(response.content)).convert("RGB")
                        image_tensors.append(self.tensor_from_pil(img))

                    elif ext in self._VIDEO_EXTS:
                        print(" -> 识别到视频URL,直接返回链接。")
                        video_urls.append(item)

                    else:
                        print(f" -> 识别到未知类型URL '{item}',已跳过。")

                except Exception as e:
                    # Deliberately broad: a single bad entry must not abort
                    # the whole result.
                    print(f" -> 处理URL时出错: {e}")
            else:
                # Option B: try the entry as a Base64-encoded image.
                try:
                    print(" -> 尝试作为 Base64 图片解码...")
                    img_data = base64.b64decode(item)
                    img = Image.open(BytesIO(img_data)).convert("RGB")
                    image_tensors.append(self.tensor_from_pil(img))
                except Exception:
                    print(" -> 解码失败,该项不是有效的媒体。")

        return image_tensors, video_urls
|
||
|
||
# Node identifier -> implementing class, exported at module level.
NODE_CLASS_MAPPINGS = dict(
    FetchTaskResult=FetchTaskResult,
)
|
||
|
||
# Node identifier -> display name, exported at module level.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    FetchTaskResult="获取生成结果 (图片/视频链接)",
)
|