ComfyUI-CustomNode/nodes/video_agent.py

265 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
File video_agent.py
Author charon
Date 2025/9/4 23:01
"""
import io
import re
import time
import numpy as np
import requests
from PIL import Image
try:
from loguru import logger
except ImportError:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("VideoAPINode_Final")
print("提示: loguru 未安装使用内置logging。建议安装以获得更好的日志体验: pip install loguru")
def fetch_and_process_models():
    """Fetch the video-model catalogues from the remote API and build lookup tables.

    Two catalogues are queried: the regular video models and the first/last-frame
    ("frame") models. For each catalogue the environments are tried in order
    (prod, dev, test) and the first successful response is used. Falling back
    across environments means the catalogue may come from a different
    environment than the one a task is later submitted to.

    Frame models are namespaced with a ``frame/`` prefix so they cannot collide
    with regular models of the same name.

    Returns:
        dict with keys:
            ``configs``: tech name -> raw model entry from the API.
            ``full_display_list``: display names ordered t2v, both, i2v, frame,
                then unknown modes; a single error placeholder if nothing loaded.
            ``display_to_tech_name``: display name -> tech name.
            ``temp_list_for_sorting``: the (sort_key, display_name) pairs used to
                build ``full_display_list``.

    Network and API failures are never raised. They are logged, and the affected
    catalogue is skipped.
    """
    video_urls = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=video",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=video",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=video",
    }
    frame_urls = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run/api/custom/extend/model/list",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run/api/custom/extend/model/list",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run/api/custom/extend/model/list",
    }
    model_data = {
        "configs": {},
        "full_display_list": [],
        "display_to_tech_name": {},
        "temp_list_for_sorting": [],
    }
    # Dropdown order: lower key first; frame models after i2v, unknown modes last.
    mode_sort_keys = {"t2v": 0, "both": 1, "i2v": 2}
    frame_sort_key = 3
    unknown_sort_key = 99

    def fetch_first_ok(urls):
        """Return the first successful (2xx) response from *urls*, tried in order, or None."""
        for url in urls.values():
            logger.info(f"start request config from:{url}")
            try:
                response = requests.get(url, timeout=10, headers={'accept': 'application/json'})
                response.raise_for_status()
            except requests.RequestException as exc:
                # Network/HTTP failure for this environment: try the next one.
                logger.warning(f"config request failed: {url} -> {exc}")
                continue
            logger.debug(f"config response:{response.text}")
            return response
        return None

    def process_response(response, is_frame_api_source=False):
        """Register every model listed in *response* into ``model_data``.

        Raises:
            ValueError: if the payload is not an object with a truthy ``status``
                and a ``data`` field.
        """
        data = response.json()
        if not isinstance(data, dict) or not data.get("status") or "data" not in data:
            msg = data.get('msg', '未知错误') if isinstance(data, dict) else '未知错误'
            raise ValueError(f"API响应格式错误: {msg}")
        for model in data["data"] or []:
            if not isinstance(model, dict):
                continue
            original_tech_name = model.get("model_name")
            if not original_tech_name:
                continue
            tech_name = f"frame/{original_tech_name}" if is_frame_api_source else original_tech_name
            # `or` (not a .get default) so a null/empty description also falls back.
            description = model.get("description") or tech_name
            display_name = f"{description} ({tech_name})"
            model_data["configs"][tech_name] = model
            model_data["display_to_tech_name"][display_name] = tech_name
            if is_frame_api_source:
                sort_key = frame_sort_key
            else:
                sort_key = mode_sort_keys.get(model.get("mode"), unknown_sort_key)
            model_data["temp_list_for_sorting"].append((sort_key, display_name))

    sources = (
        (video_urls, False, "常规模型加载失败"),
        (frame_urls, True, "首尾帧模型加载失败"),
    )
    for urls, is_frame, failure_label in sources:
        try:
            response = fetch_first_ok(urls)
            if response is None:
                raise RuntimeError("所有环境的接口均请求失败")
            process_response(response, is_frame_api_source=is_frame)
        except Exception as e:
            # Best-effort: one catalogue failing must not prevent the other from loading.
            logger.error(f"{failure_label}: {e}")

    # list.sort is stable, so models keep API order within the same sort key.
    model_data["temp_list_for_sorting"].sort(key=lambda item: item[0])
    model_data["full_display_list"] = [name for _, name in model_data["temp_list_for_sorting"]]
    if not model_data["full_display_list"]:
        model_data["full_display_list"] = ["错误:无法加载模型"]
    return model_data
# Model catalogue, loaded once at import time (fetch_and_process_models performs
# blocking HTTP requests). VideoSubmitNode reads it for its model dropdown and
# per-model configs; it is not refreshed until this module is imported again.
MODEL_DATA = fetch_and_process_models()
class VideoSubmitNode:
    """ComfyUI node that submits a video-generation task to the text-video-agent API.

    The model dropdown and per-model configs come from the module-level
    ``MODEL_DATA`` catalogue. Models whose tech name starts with ``frame/`` go to
    the first/last-frame endpoint. All other models go to the regular video
    endpoint, where the model's ``mode`` (``t2v``, ``i2v`` or ``both``) decides
    whether ``head_image`` is ignored, required or optional.

    The node never raises. On success it outputs whatever the API returned in
    its ``data`` field; on failure it outputs a string starting with ``错误: ``.
    """

    MODEL_DATA = MODEL_DATA
    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("data",)
    CATEGORY = "不忘科技-自定义节点🚩/api/视频生成"
    FUNCTION = "submit_task"

    # Base URL per environment; unknown environments fall back to prod.
    _ENV_BASE_URLS = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run",
    }
    # Namespace prefix that fetch_and_process_models adds to first/last-frame models.
    _FRAME_PREFIX = "frame/"
    # Display names are built as "<description> (<tech_name>)". Capture only the
    # trailing parenthesised group so parentheses inside a description are skipped.
    _TECH_NAME_SUFFIX_RE = re.compile(r'\(([^()]*)\)\s*$')
    # Request timeouts in seconds; without one a stalled request blocks the node forever.
    _UPLOAD_TIMEOUT = 60
    _SUBMIT_TIMEOUT = 90

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for ComfyUI."""
        return {
            "required": {
                "model_name_display": (cls.MODEL_DATA["full_display_list"],),
                "prompt": ("STRING", {"multiline": True, "default": "", "placeholder": "请输入提示词"}),
                "aspect_ratio": ("STRING", {"multiline": False, "default": "9:16"}),
                "duration": ("STRING", {"multiline": False, "default": "5"}),
                "resolution": ("STRING", {"multiline": False, "default": "720p"}),
                "environment": (["prod", "dev", "test"], {"default": "prod"}),
            },
            "optional": {
                "head_image": ("IMAGE", {"description": "首帧图片"}),
                "tail_image": ("IMAGE", {"description": "尾帧图片"}),
            }
        }

    def _get_base_url_and_tech_name(self, environment, model_name_display):
        """Resolve the API base URL for *environment* and the tech name for a dropdown entry.

        The tech name is looked up in the catalogue first. If the entry is unknown
        (e.g. a workflow saved against an older model list), it is parsed from the
        trailing "(...)" of the display name, or the display name is used as-is.
        """
        base_url = self._ENV_BASE_URLS.get(environment, self._ENV_BASE_URLS["prod"])
        tech_name = self.MODEL_DATA["display_to_tech_name"].get(model_name_display)
        if not tech_name:
            match = self._TECH_NAME_SUFFIX_RE.search(model_name_display)
            tech_name = match.group(1) if match else model_name_display
        logger.info(f"模型: '{model_name_display}' -> '{tech_name}'")
        return base_url, tech_name

    @staticmethod
    def _tensor_to_png_buffer(tensor_img):
        """Encode the first image of an IMAGE batch as an in-memory PNG, rewound to offset 0.

        Assumes *tensor_img* is a torch tensor of shape (batch, H, W, C) with float
        values in [0, 1] (TODO confirm). Only index 0 of the batch is used.
        """
        img_np = np.clip(255. * tensor_img[0].cpu().numpy(), 0, 255).astype(np.uint8)
        buffer = io.BytesIO()
        Image.fromarray(img_np).save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def _upload_file_2cdn(self, tensor_img, base_url: str):
        """Upload the first image of *tensor_img* as PNG and return the API's ``data`` field.

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: when the API reports failure.
        """
        files = {'file': (f'{time.time_ns()}.png', self._tensor_to_png_buffer(tensor_img), 'image/png')}
        response = requests.post(f'{base_url}/api/file/upload/s3', headers={'accept': 'application/json'},
                                 files=files, timeout=self._UPLOAD_TIMEOUT)
        response.raise_for_status()
        resp_json = response.json()
        if resp_json.get('status'):
            return resp_json.get('data')
        raise ValueError(resp_json.get('msg', '上传文件失败'))

    def _handler_base_video_task(self, prompt, model_name, aspect_ratio, duration, resolution, base_url,
                                 head_image=None):
        """Submit a regular (text- or image-to-video) task as multipart form data.

        *head_image*, when given, is sent as the ``img_file`` part.
        Returns the API's ``data`` field; raises ValueError when the API reports failure.
        """
        headers = {'accept': 'application/json'}
        # (None, value) tuples make requests send plain form fields inside the multipart body.
        files = {'prompt': (None, prompt), 'model_name': (None, model_name), 'duration': (None, duration),
                 'resolution': (None, resolution), 'aspect_ratio': (None, aspect_ratio),
                 'webhook_flag': (None, 'false')}
        if head_image is not None:
            files['img_file'] = (f'{time.time_ns()}.png', self._tensor_to_png_buffer(head_image), 'image/png')
        api_endpoint = f'{base_url}/api/custom/video/submit/task'
        response = requests.post(api_endpoint, headers=headers, files=files, timeout=self._SUBMIT_TIMEOUT)
        response.raise_for_status()
        resp_json = response.json()
        if resp_json.get('status'):
            return resp_json.get('data')
        error_msg = resp_json.get('msg', '未知API错误')
        raise ValueError(f"API返回失败: {error_msg}")

    def _handler_frame_video_task(self, prompt, model_name, aspect_ratio, duration, resolution, base_url, head_image,
                                  tail_image):
        """Upload both frames, then submit a first/last-frame video task.

        Returns the API's ``data`` field; raises RuntimeError when the API reports failure.
        """
        # Strip only the leading namespace; str.replace would also remove
        # "frame/" appearing elsewhere in the model name.
        if model_name.startswith(self._FRAME_PREFIX):
            model_name_for_api = model_name[len(self._FRAME_PREFIX):]
        else:
            model_name_for_api = model_name
        head_img_url = self._upload_file_2cdn(head_image, base_url)
        tail_img_url = self._upload_file_2cdn(tail_image, base_url)
        data = {'prompt': prompt, 'head_img_url': head_img_url, 'tail_img_url': tail_img_url,
                'model_name': model_name_for_api, 'duration': duration, 'aspect_ratio': aspect_ratio,
                'resolution': resolution, 'webhook_flag': 'false'}
        response = requests.post(f'{base_url}/api/custom/extend/frame/submit/task',
                                 headers={'accept': 'application/json'}, data=data,
                                 timeout=self._SUBMIT_TIMEOUT)
        response.raise_for_status()
        resp_json = response.json()
        if resp_json.get('status'):
            return resp_json.get('data')
        raise RuntimeError(resp_json.get('msg', '任务失败'))

    @staticmethod
    def _validate_and_correct_parameter(tech_name, param_name, user_value, supported_values):
        """Return *user_value* if the model supports it, else the first supported value (with a warning).

        An empty/missing *supported_values* means "anything goes".
        """
        if not supported_values:
            return user_value
        # The node input is a STRING, while the model config may list numbers
        # (schema not visible here; confirm against the API). Comparing as text
        # keeps e.g. "10" from being rejected when 10 is supported.
        if str(user_value) in {str(value) for value in supported_values}:
            return user_value
        default_value = supported_values[0]
        logger.warning(
            f"参数警告!模型 '{tech_name}' 不支持 '{param_name}': '{user_value}'，"
            f"已自动替换为支持的默认值: '{default_value}'。支持的选项: {supported_values}"
        )
        return default_value

    def submit_task(self, model_name_display, prompt, aspect_ratio, duration, resolution, environment, head_image=None,
                    tail_image=None):
        """Node entry point: validate inputs against the model config and submit the task.

        Returns:
            A 1-tuple holding the API's ``data`` field on success, or an
            ``"错误: ..."`` string on any failure (exceptions are logged, not raised).
        """
        try:
            base_url, tech_name = self._get_base_url_and_tech_name(environment, model_name_display)
            model_config = self.MODEL_DATA["configs"].get(tech_name)
            if not model_config:
                raise ValueError(f"无法找到模型 '{tech_name}' 的配置。")
            final_dur = self._validate_and_correct_parameter(
                tech_name, "时长", duration, model_config.get("supported_duration", []))

            if tech_name.startswith(self._FRAME_PREFIX):
                if head_image is None or tail_image is None:
                    raise ValueError("您选择了[首尾帧]模型,必须同时提供 'head_image' 和 'tail_image' 输入。")
                result = self._handler_frame_video_task(prompt, tech_name, aspect_ratio, final_dur, resolution,
                                                        base_url, head_image, tail_image)
            else:
                image_to_pass = None
                true_model_mode = model_config.get('mode')
                if true_model_mode == 'i2v':
                    if head_image is None:
                        raise ValueError("您选择了[图]模型,必须提供 'head_image' 输入。")
                    image_to_pass = head_image
                elif true_model_mode == 't2v':
                    if head_image is not None:
                        logger.warning("您选择了[文]模型,连接的'head_image'将被忽略。")
                elif true_model_mode == 'both':
                    image_to_pass = head_image
                result = self._handler_base_video_task(prompt, tech_name, aspect_ratio, final_dur, resolution,
                                                       base_url, image_to_pass)
            return (result,)
        except Exception as e:
            # Deliberate node-boundary catch: report the error as the output string.
            logger.exception(f"任务处理失败: {e}")
            return (f"错误: {str(e)}",)
# NODE_CLASS_MAPPINGS = {
# "VideoSubmitNode": VideoSubmitNode,
# }
# NODE_DISPLAY_NAME_MAPPINGS = {
# "VideoSubmitNode": "统一视频生成节点",
# }