# -*- coding: utf-8 -*-
"""
File:   video_agent.py
Author: charon
Date:   2025/9/4 23:01
"""
|
||
import io
|
||
import re
|
||
import time
|
||
|
||
import numpy as np
|
||
import requests
|
||
from PIL import Image
|
||
|
||
try:
    # Prefer loguru whenever it is installed: it needs no configuration.
    from loguru import logger
except ImportError:
    # Fallback to the standard library: timestamped format at INFO level.
    # (logging.basicConfig does nothing if the root logger already has handlers.)
    import logging

    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    logger = logging.getLogger("VideoAPINode_Final")
    # One-time hint telling the user how to get the loguru output instead.
    print("提示: loguru 未安装,使用内置logging。建议安装以获得更好的日志体验: pip install loguru")
|
||
|
||
|
||
def fetch_and_process_models():
    """Fetch the available video models from the remote API and build lookup tables.

    Two model lists are queried: the regular video models and the
    first/last-frame ("frame") models. For each list the environments are
    tried in order (prod, dev, test); the first one that answers with a
    2xx status is used.

    Returns:
        dict with keys:
            "configs": tech name -> raw model config dict from the API.
                Frame models are stored under "frame/<model_name>".
            "full_display_list": display names ("<description> (<tech_name>)")
                ordered t2v, both, i2v, frame models, then anything else;
                ["错误:无法加载模型"] when nothing could be loaded.
            "display_to_tech_name": display name -> tech name.
            "temp_list_for_sorting": (sort_key, display_name) pairs used to
                build "full_display_list".

    Network and payload failures are logged, never raised, so a caller that
    runs this at import time does not fail when the API is unreachable.
    """
    base_urls = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run",
    }
    # Dict insertion order fixes the fallback order: prod -> dev -> test.
    video_urls = {env: f"{base}/api/custom/model/list?category=video" for env, base in base_urls.items()}
    frame_urls = {env: f"{base}/api/custom/extend/model/list" for env, base in base_urls.items()}
    request_headers = {'accept': 'application/json'}

    model_data = {
        "configs": {},
        "full_display_list": [],
        "display_to_tech_name": {},
        "temp_list_for_sorting": [],
    }

    # Dropdown ordering: t2v first, then "both", then i2v; frame models after those; anything else last.
    mode_sort_keys = {"t2v": 0, "both": 1, "i2v": 2}
    frame_sort_key = 3
    unknown_sort_key = 99

    def fetch_first_ok(urls):
        """Return the first 2xx response among ``urls`` (tried in order), or None if all fail."""
        for url in urls.values():
            logger.info(f"start request config from:{url}")
            try:
                response = requests.get(url, timeout=10, headers=request_headers)
                response.raise_for_status()
            except requests.RequestException as e:
                # Try the next environment, but keep a trace of why this one failed.
                logger.warning(f"请求模型列表失败 {url}: {e}")
                continue
            logger.debug(f"config response:{response.text}")
            return response
        return None

    def process_response(response, is_frame_api_source=False):
        """Register every model listed in ``response`` into ``model_data``.

        Raises:
            ValueError: if the payload lacks a truthy "status" or a list "data"
                (``response.json()`` may also raise on a non-JSON body).
        """
        data = response.json()
        if not isinstance(data, dict) or not data.get("status") or not isinstance(data.get("data"), list):
            msg = data.get('msg', '未知错误') if isinstance(data, dict) else '未知错误'
            raise ValueError(f"API响应格式错误: {msg}")

        for model in data["data"]:
            if not isinstance(model, dict):
                continue
            original_tech_name = model.get("model_name")
            if not original_tech_name:
                continue

            tech_name = f"frame/{original_tech_name}" if is_frame_api_source else original_tech_name
            # `or` rather than a .get() default so a null/empty description also falls back.
            description = model.get("description") or tech_name
            display_name = f"{description} ({tech_name})"

            is_new_entry = display_name not in model_data["display_to_tech_name"]
            model_data["configs"][tech_name] = model
            model_data["display_to_tech_name"][display_name] = tech_name
            if not is_new_entry:
                # Same model listed twice: keep the latest config, don't duplicate the dropdown entry.
                continue

            if is_frame_api_source:
                sort_key = frame_sort_key
            else:
                mode = model.get("mode")
                # isinstance guard: a non-string mode (e.g. a list) must not break the dict lookup.
                sort_key = mode_sort_keys.get(mode, unknown_sort_key) if isinstance(mode, str) else unknown_sort_key
            model_data["temp_list_for_sorting"].append((sort_key, display_name))

    sources = (
        ("常规模型加载失败", video_urls, False),
        ("首尾帧模型加载失败", frame_urls, True),
    )
    for error_label, urls, is_frame_api_source in sources:
        try:
            response = fetch_first_ok(urls)
            if response is None:
                logger.error(f"{error_label}: 所有环境均请求失败")
                continue
            process_response(response, is_frame_api_source=is_frame_api_source)
        except Exception as e:  # boundary: a bad payload must not abort the whole load
            logger.error(f"{error_label}: {e}")

    # list.sort is stable, so models keep API order within the same sort key.
    model_data["temp_list_for_sorting"].sort(key=lambda item: item[0])
    model_data["full_display_list"] = [name for _, name in model_data["temp_list_for_sorting"]]

    if not model_data["full_display_list"]:
        model_data["full_display_list"] = ["错误:无法加载模型"]

    return model_data
|
||
|
||
|
||
# Evaluated once, at module import time: this performs blocking HTTP requests
# (10 s timeout per request). The node's model dropdown is built from this
# snapshot, so models added on the server only appear after this module is
# re-imported.
MODEL_DATA = fetch_and_process_models()
|
||
|
||
|
||
class VideoSubmitNode:
    """Node that submits a video-generation task to the text-video-agent API.

    The model dropdown is populated from the module-level MODEL_DATA. Depending
    on the selected model, the task is routed to:
      * the first/last-frame endpoint (tech names prefixed with "frame/"),
        which requires both ``head_image`` and ``tail_image``; or
      * the regular endpoint, where the model's "mode" (t2v / i2v / both)
        decides whether ``head_image`` is required, ignored or optional.

    ``submit_task`` never raises: it returns a 1-tuple holding the API's
    ``data`` field on success, or an error string prefixed with "错误: ".
    """

    MODEL_DATA = MODEL_DATA

    RETURN_TYPES = ("STRING",)

    RETURN_NAMES = ("data",)

    CATEGORY = "不忘科技-自定义节点🚩/api/视频生成"

    FUNCTION = "submit_task"

    # API base URL per environment; unknown environment names fall back to "prod".
    ENV_BASE_URLS = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run",
    }

    # Timeout (seconds) for every outgoing HTTP request, so a stalled server cannot hang the node forever.
    REQUEST_TIMEOUT = 90

    # Tech-name prefix marking models served by the first/last-frame endpoint.
    FRAME_PREFIX = "frame/"

    # Trailing "(tech_name)" of a display name built as "<description> (<tech_name>)".
    _TECH_NAME_RE = re.compile(r'\(([^()]*)\)\s*$')

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs; the model choices come from MODEL_DATA["full_display_list"]."""
        return {
            "required": {
                "model_name_display": (cls.MODEL_DATA["full_display_list"],),
                "prompt": ("STRING", {"multiline": True, "default": "","placeholder": "请输入提示词"}),
                "aspect_ratio": ("STRING", {"multiline": False, "default": "9:16"}),
                "duration": ("STRING", {"multiline": False, "default": "5"}),
                "resolution": ("STRING", {"multiline": False, "default": "720p"}),
                "environment": (["prod", "dev", "test"], {"default": "prod"}),
            },
            "optional": {
                "head_image": ("IMAGE", {"description": "首帧图片"}),
                "tail_image": ("IMAGE", {"description": "尾帧图片"}),
            }
        }

    @staticmethod
    def _tensor_to_png_buffer(tensor_img):
        """Encode the first image of ``tensor_img`` as PNG into a rewound BytesIO.

        Assumes a batch-first image tensor with (H, W, C) per image and float
        values nominally in [0, 1] (TODO confirm against the caller). Values are
        scaled by 255 and clipped to the uint8 range. Only index 0 of the batch
        is encoded.
        """
        img_np = np.clip(255. * tensor_img[0].cpu().numpy(), 0, 255).astype(np.uint8)
        buffer = io.BytesIO()
        Image.fromarray(img_np).save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def _get_base_url_and_tech_name(self, environment, model_name_display):
        """Resolve the API base URL and the model tech name for a dropdown selection.

        The tech name is looked up in MODEL_DATA["display_to_tech_name"]. If the
        display name is unknown (e.g. a saved workflow with a stale label), the
        text inside the trailing parentheses is used, and failing that the
        display name itself.
        """
        base_url = self.ENV_BASE_URLS.get(environment, self.ENV_BASE_URLS["prod"])
        tech_name = self.MODEL_DATA["display_to_tech_name"].get(model_name_display)
        if not tech_name:
            # Use the LAST parenthesised group: the description may itself contain
            # parentheses, while the tech name is always appended at the end.
            match = self._TECH_NAME_RE.search(model_name_display)
            tech_name = match.group(1) if match else model_name_display
        logger.info(f"模型: '{model_name_display}' -> '{tech_name}'")
        return base_url, tech_name

    def _upload_file_2cdn(self, tensor_img, base_url: str):
        """Upload the first image of ``tensor_img`` as PNG and return the API's ``data`` field.

        The returned value is passed on as an image URL by the frame handler
        (presumably the uploaded file's URL; verify against the API).

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: when the API reports a falsy ``status``.
        """
        files = {'file': (f'{time.time_ns()}.png', self._tensor_to_png_buffer(tensor_img), 'image/png')}
        response = requests.post(f'{base_url}/api/file/upload/s3', headers={'accept': 'application/json'},
                                 files=files, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        resp_json = response.json()
        if resp_json.get('status'):
            return resp_json.get('data')
        raise ValueError(resp_json.get('msg', '上传文件失败'))

    def _handler_base_video_task(self, prompt, model_name, aspect_ratio, duration, resolution, base_url,
                                 head_image=None):
        """Submit a regular (t2v / i2v / both) task as multipart form data.

        When ``head_image`` is given, its first image is sent as ``img_file``.

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: when the API reports a falsy ``status``.
        """
        headers = {'accept': 'application/json'}
        # (None, value) tuples make requests send plain form fields inside the multipart body.
        payload = {'prompt': (None, prompt), 'model_name': (None, model_name), 'duration': (None, duration),
                   'resolution': (None, resolution), 'aspect_ratio': (None, aspect_ratio),
                   'webhook_flag': (None, 'false')}
        files = {}
        if head_image is not None:
            files['img_file'] = (f'{time.time_ns()}.png', self._tensor_to_png_buffer(head_image), 'image/png')
        files.update(payload)

        api_endpoint = f'{base_url}/api/custom/video/submit/task'
        response = requests.post(api_endpoint, headers=headers, files=files, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        resp_json = response.json()
        if resp_json.get('status'):
            return resp_json.get('data')
        error_msg = resp_json.get('msg', '未知API错误')
        raise ValueError(f"API返回失败: {error_msg}")

    def _handler_frame_video_task(self, prompt, model_name, aspect_ratio, duration, resolution, base_url, head_image,
                                  tail_image):
        """Upload the head/tail frames, then submit a first/last-frame task.

        Raises:
            requests.HTTPError: on a non-2xx response.
            ValueError: if an image upload is rejected by the API.
            RuntimeError: when the submit endpoint reports a falsy ``status``.
        """
        # Strip only the leading marker; the API expects the bare model name.
        if model_name.startswith(self.FRAME_PREFIX):
            model_name_for_api = model_name[len(self.FRAME_PREFIX):]
        else:
            model_name_for_api = model_name
        head_img_url = self._upload_file_2cdn(head_image, base_url)
        tail_img_url = self._upload_file_2cdn(tail_image, base_url)
        data = {'prompt': prompt, 'head_img_url': head_img_url, 'tail_img_url': tail_img_url,
                'model_name': model_name_for_api, 'duration': duration, 'aspect_ratio': aspect_ratio,
                'resolution': resolution, 'webhook_flag': 'false'}
        response = requests.post(f'{base_url}/api/custom/extend/frame/submit/task',
                                 headers={'accept': 'application/json'}, data=data, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        resp_json = response.json()
        if resp_json.get('status'):
            return resp_json.get('data')
        raise RuntimeError(resp_json.get('msg', '任务失败'))

    @staticmethod
    def _validate_parameter(tech_name, param_name, user_value, supported_values):
        """Return ``user_value`` if the model supports it, else the first supported value.

        An empty or missing ``supported_values`` means "no restriction". Besides
        an exact match, values are also compared as strings, because node inputs
        are strings while the API list may hold numbers (e.g. "5" vs 5;
        assumption, verify against the API schema).
        """
        if not supported_values:
            return user_value
        if user_value in supported_values or str(user_value) in [str(v) for v in supported_values]:
            return user_value
        default_value = supported_values[0]
        logger.warning(
            f"参数警告!模型 '{tech_name}' 不支持 '{param_name}': '{user_value}'。"
            f"已自动替换为支持的默认值: '{default_value}'。支持的选项: {supported_values}"
        )
        return default_value

    def submit_task(self, model_name_display, prompt, aspect_ratio, duration, resolution, environment, head_image=None,
                    tail_image=None):
        """Node entry point: validate inputs, route to the right endpoint, return ``(result,)``.

        Only ``duration`` is checked against the model's "supported_duration";
        aspect ratio and resolution are forwarded unchanged. Any exception is
        logged and turned into an ``("错误: ...",)`` result instead of raised.
        """
        try:
            base_url, tech_name = self._get_base_url_and_tech_name(environment, model_name_display)
            model_config = self.MODEL_DATA["configs"].get(tech_name)
            if not model_config:
                raise ValueError(f"无法找到模型 '{tech_name}' 的配置。")

            final_ar = aspect_ratio
            final_res = resolution
            final_dur = self._validate_parameter(tech_name, "时长", duration,
                                                 model_config.get("supported_duration", []))

            if tech_name.startswith(self.FRAME_PREFIX):
                if head_image is None or tail_image is None:
                    raise ValueError("您选择了[首尾帧]模型,必须同时提供 'head_image' 和 'tail_image' 输入。")
                result = self._handler_frame_video_task(prompt, tech_name, final_ar, final_dur, final_res, base_url,
                                                        head_image, tail_image)
            else:
                image_to_pass = None
                true_model_mode = model_config.get('mode')
                if true_model_mode == 'i2v':
                    if head_image is None:
                        raise ValueError("您选择了[图]模型,必须提供 'head_image' 输入。")
                    image_to_pass = head_image
                elif true_model_mode == 't2v':
                    if head_image is not None:
                        logger.warning("您选择了[文]模型,连接的'head_image'将被忽略。")
                elif true_model_mode == 'both':
                    image_to_pass = head_image
                result = self._handler_base_video_task(prompt, tech_name, final_ar, final_dur, final_res, base_url,
                                                       image_to_pass)

            return (result,)

        except Exception as e:
            # Node boundary: report the failure as the output string, keep the traceback in the log.
            logger.exception(f"任务处理失败: {e}")
            return (f"错误: {str(e)}",)
|
||
|
||
|
||
# NODE_CLASS_MAPPINGS = {
|
||
# "VideoSubmitNode": VideoSubmitNode,
|
||
# }
|
||
# NODE_DISPLAY_NAME_MAPPINGS = {
|
||
# "VideoSubmitNode": "统一视频生成节点",
|
||
# }
|