# Source file: ComfyUI-CustomNode/nodes/union_llm_node.py
# -*- coding:utf-8 -*-
"""
File union_llm_node.py
Author silence
Date 2025/9/5
"""
import os
import requests
import base64
import mimetypes
import torch
import numpy as np
from PIL import Image
import folder_paths

# NOTE(review): appears unused within this file — possibly shared state
# consumed by other nodes in the package; confirm before removing.
tensor_to_file_map = {}
class LLMUionNode:
    """Aggregated LLM node for ComfyUI.

    Sends a text prompt — optionally accompanied by a video or image — to one
    of several LLM backends. Gemini models are routed to an internal
    video/image analysis API (selected by ``environment``); all other models
    go through an OpenAI-compatible chat-completions gateway.
    """

    # Models offered in the node's dropdown.
    MODELS = ['gemini-2.5-flash', 'gemini-2.5-pro', "gpt-4o-1120", "gpt-4.1"]
    # Deployment environments for the Gemini analysis API.
    ENVIRONMENTS = ["prod", "dev", "test"]
    # Base URL per environment for the Gemini analysis API.
    ENV_URLS = {
        "prod": 'https://bowongai-prod--text-video-agent-fastapi-app.modal.run',
        "dev": 'https://bowongai-dev--text-video-agent-fastapi-app.modal.run',
        "test": 'https://bowongai-test--text-video-agent-fastapi-app.modal.run'
    }

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "model_name": (s.MODELS,),
                "prompt": ("STRING", {"multiline": True, "default": "详细描述这个视频"}),
            },
            "optional": {
                "video_input": ("*",),
                "image": ("IMAGE",),
                "environment": (s.ENVIRONMENTS,),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "execute"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def tensor_to_pil(self, tensor):
        """Convert the first image of a ComfyUI IMAGE batch to a PIL image.

        Assumes the ComfyUI IMAGE convention: a float tensor in [0, 1] with
        shape (batch, height, width, channels) — TODO confirm for all callers.
        Returns None when *tensor* is None.
        """
        if tensor is None:
            return None
        image_np = tensor[0].cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)
        return Image.fromarray(image_np)

    def save_pil_to_temp(self, pil_image):
        """Save *pil_image* as a PNG in ComfyUI's temp directory.

        Returns the absolute path of the written file.
        """
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_image", output_dir)
        # FIX: restore the garbled filename placeholder — use the name
        # returned by get_save_image_path, matching the surviving counter part.
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.png")
        pil_image.save(filepath, 'PNG')
        return filepath

    # --- API handlers: both take a resolved media file path (or None) ---
    def handler_google_analytics(self, prompt: str, model_id: str, media_file_path: str, base_url: str, timeout: int):
        """POST the prompt (and optional media file) to the Gemini analysis API.

        Sends multipart/form-data to ``{base_url}/api/llm/google/analysis``.
        Returns the response's 'data' field on success; on failure returns a
        human-readable error string instead of raising.
        """
        headers = {'accept': 'application/json'}
        files = {'prompt': (None, prompt), 'model_id': (None, model_id)}
        if media_file_path and os.path.exists(media_file_path):
            mime_type = mimetypes.guess_type(media_file_path)[0] or 'application/octet-stream'
            # FIX: read the bytes under a context manager instead of handing an
            # open file object to requests, which leaked the file descriptor.
            with open(media_file_path, 'rb') as media_file:
                files['img_file'] = (os.path.basename(media_file_path), media_file.read(), mime_type)
        try:
            response = requests.post(f'{base_url}/api/llm/google/analysis', headers=headers, files=files, timeout=timeout)
            response.raise_for_status()
            resp_json = response.json()
            result = resp_json.get('data') if resp_json else None
            return result or f"API返回成功但没有有效的 'data' 内容。 响应: {response.text}"
        except requests.RequestException as e:
            return f"Error calling Gemini API: {str(e)}"

    def handler_other_llm(self, model_name: str, prompt: str, media_path: str, timeout: int):
        """Call the OpenAI-compatible chat gateway for non-Gemini models.

        The optional media file is inlined as a base64 data URL in an
        ``image_url`` content part. Returns the first choice's message content
        on success; on failure returns a human-readable error string.
        """
        messages_content = [{"type": "text", "text": prompt}]
        if media_path and os.path.exists(media_path):
            try:
                with open(media_path, "rb") as media_file:
                    base64_media = base64.b64encode(media_file.read()).decode('utf-8')
                mime_type = mimetypes.guess_type(media_path)[0] or "application/octet-stream"
                data_url = f"data:{mime_type};base64,{base64_media}"
                messages_content.append({"type": "image_url", "image_url": {"url": data_url}})
            except Exception as e:
                return f"Error encoding media file: {str(e)}"
        json_payload = {"model": model_name, "messages": [{"role": "user", "content": messages_content}], "temperature": 0.7, "max_tokens": 4096}
        try:
            resp = requests.post("https://gateway.bowong.cc/chat/completions", headers={"Content-Type": "application/json", "Authorization": "Bearer auth-bowong7777"}, json=json_payload, timeout=timeout)
            resp.raise_for_status()
            resp_json = resp.json()
            if 'choices' in resp_json and resp_json['choices']:
                return resp_json['choices'][0]['message']['content']
            else:
                return f'Call LLM failed: {resp_json.get("error", {}).get("message", "LLM API returned no choices.")}'
        except requests.RequestException as e:
            return f"Error calling other LLM API: {str(e)}"

    def execute(self, model_name: str, prompt: str, environment: str = "prod",
                video_input: object = None, image: torch.Tensor = None, timeout=300):
        """Resolve the media input to a file path and dispatch to an API handler.

        Media resolution priority: video_input (object with save_to(), or a
        filename string relative to ComfyUI's 'input' folder) > image tensor >
        text-only. Always returns a one-element tuple of strings, including on
        error, so downstream nodes never receive an exception.
        """
        base_url = self.ENV_URLS.get(environment, self.ENV_URLS["prod"])
        media_path = None
        # Priority 1: video_input.
        if video_input is not None:
            # Some upstream nodes deliver the video wrapped in a list/tuple.
            unwrapped_input = video_input[0] if isinstance(video_input, (list, tuple)) and video_input else video_input
            # Case A: a video object exposing save_to().
            if hasattr(unwrapped_input, 'save_to'):
                try:
                    output_dir = folder_paths.get_temp_directory()
                    (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_video", output_dir)
                    # FIX: restore the garbled filename placeholder (see save_pil_to_temp).
                    temp_video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")
                    print(f"检测到视频对象,使用 save_to() 保存到: {temp_video_path}")
                    unwrapped_input.save_to(temp_video_path)
                    if os.path.exists(temp_video_path):
                        media_path = temp_video_path
                    else:
                        return (f"错误: 调用 save_to() 后文件未成功创建。",)
                except Exception as e:
                    return (f"调用 save_to() 时出错: {e}",)
            # Case B: a plain string, treated as a filename in the input folder.
            elif isinstance(unwrapped_input, str):
                filename = unwrapped_input
                # FIX: restore the garbled filename placeholder in the log/error strings.
                print(f"检测到字符串输入,作为文件名处理: '{filename}'")
                full_path = folder_paths.get_full_path("input", filename)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{filename}'",)
        # Priority 2: image tensor (only when no video_input was given).
        elif image is not None:
            print("检测到图像输入, 正在保存为临时文件...")
            pil_image = self.tensor_to_pil(image)
            media_path = self.save_pil_to_temp(pil_image)
        # Priority 3: text-only mode.
        else:
            print("未提供媒体文件, 以纯文本模式运行。")
        if media_path:
            print(f"成功解析媒体文件路径: {media_path}")
        # Dispatch: Gemini models use the internal analysis API, others the gateway.
        model_name = model_name.strip()
        if model_name.startswith('gemini'):
            result = self.handler_google_analytics(prompt, model_name, media_path, base_url=base_url, timeout=timeout)
        else:
            result = self.handler_other_llm(model_name, prompt, media_path, timeout=timeout)
        return (result,)
# NODE_CLASS_MAPPINGS = { "LLMUionNode": LLMUionNode }
# NODE_DISPLAY_NAME_MAPPINGS = { "LLMUionNode": "聚合LLM节点(视频/图像)" }