168 lines
7.6 KiB
Python
168 lines
7.6 KiB
Python
# -*- coding:utf-8 -*-
|
||
"""
|
||
File union_llm_node.py
|
||
Author silence
|
||
Date 2025/9/5
|
||
"""
|
||
import os
|
||
import requests
|
||
import base64
|
||
import mimetypes
|
||
import torch
|
||
import numpy as np
|
||
from PIL import Image
|
||
import folder_paths
|
||
|
||
# Module-level mapping, created empty at import time. The LLMUionNode class in
# this module never reads or writes it.
tensor_to_file_map = dict()
|
||
|
||
|
||
class LLMUionNode:
    """Aggregated LLM node for ComfyUI.

    Sends a text prompt, optionally together with one media file (a video
    resolved from ``video_input``, otherwise the first image of ``image``),
    to one of two backends:

    * model names starting with ``gemini`` go to the ``/api/llm/google/analysis``
      endpoint of the selected environment (``ENV_URLS``);
    * every other model goes to an OpenAI-style chat-completions gateway
      (``LLM_GATEWAY_URL``).

    The reply text is returned as a 1-tuple. HTTP errors, malformed responses
    and media-resolution problems are returned as the output text instead of
    being raised, so the message is visible in the workflow.
    """

    MODELS = ['gemini-2.5-flash', 'gemini-2.5-pro', "gpt-4o-1120", "gpt-4.1"]

    ENVIRONMENTS = ["prod", "dev", "test"]

    # Base URL of the text-video-agent service per environment; only the
    # Gemini handler uses it. Unknown environments fall back to "prod".
    ENV_URLS = {
        "prod": 'https://bowongai-prod--text-video-agent-fastapi-app.modal.run',
        "dev": 'https://bowongai-dev--text-video-agent-fastapi-app.modal.run',
        "test": 'https://bowongai-test--text-video-agent-fastapi-app.modal.run'
    }

    # OpenAI-style chat-completions endpoint used for all non-Gemini models.
    LLM_GATEWAY_URL = "https://gateway.bowong.cc/chat/completions"
    # Bearer token for the gateway. It can be overridden through an environment
    # variable so the credential can be rotated without editing code; the
    # fallback keeps the previously hard-coded value working.
    # NOTE(review): this secret should not live in the source tree at all.
    LLM_GATEWAY_TOKEN = os.environ.get("BOWONG_LLM_GATEWAY_TOKEN", "auth-bowong7777")

    # Fallback error text when the gateway returns neither choices nor a usable error.
    _NO_CHOICES_MESSAGE = "LLM API returned no choices."

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs (ComfyUI input-spec dict)."""
        return {
            "required": {
                "model_name": (cls.MODELS,),
                "prompt": ("STRING", {"multiline": True, "default": "详细描述这个视频"}),
            },
            "optional": {
                # "*" accepts any upstream type; see _resolve_video_input for
                # which values are understood.
                "video_input": ("*",),
                "image": ("IMAGE",),
                "environment": (cls.ENVIRONMENTS,),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "execute"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    # ------------------------------------------------------------------ #
    # Media helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _tensor_to_uint8(tensor):
        """Convert the first item of a batch-first float tensor to a uint8 array.

        Values are scaled by 255 and clipped to [0, 255], so out-of-range
        inputs saturate instead of wrapping around.
        Assumes the ComfyUI IMAGE layout (B, H, W, C) with floats nominally
        in [0, 1]; only index 0 of the batch is used.
        """
        # detach/float make .numpy() work for grad-tracking or half/bf16 tensors.
        image_np = tensor[0].detach().cpu().float().numpy()
        return np.clip(image_np * 255.0, 0, 255).astype(np.uint8)

    def tensor_to_pil(self, tensor):
        """Return the first image of ``tensor`` as a PIL image, or None if ``tensor`` is None."""
        if tensor is None:
            return None
        return Image.fromarray(self._tensor_to_uint8(tensor))

    def save_pil_to_temp(self, pil_image):
        """Save ``pil_image`` as a PNG in ComfyUI's temp directory and return its path."""
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_image", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.png")
        pil_image.save(filepath, 'PNG')
        return filepath

    @staticmethod
    def _unwrap_video_input(video_input):
        """Return the first element of a non-empty list/tuple, else the value unchanged."""
        if isinstance(video_input, (list, tuple)) and video_input:
            return video_input[0]
        return video_input

    def _resolve_video_input(self, video_input):
        """Turn ``video_input`` into a local file path.

        Returns:
            ``(media_path, error)``. ``error`` is a message string when
            resolution failed and the node should stop. ``(None, None)`` means
            the value is of an unsupported type and was ignored.
        """
        unwrapped_input = self._unwrap_video_input(video_input)

        # Video objects exposing save_to(path): write them to a temp .mp4.
        if hasattr(unwrapped_input, 'save_to'):
            try:
                output_dir = folder_paths.get_temp_directory()
                (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_video", output_dir)
                temp_video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")
                print(f"检测到视频对象,使用 save_to() 保存到: {temp_video_path}")
                unwrapped_input.save_to(temp_video_path)
            except Exception as e:
                # save_to is provided by an arbitrary upstream object, so any
                # failure is reported to the user rather than crashing the run.
                return None, f"调用 save_to() 时出错: {e}"
            if os.path.exists(temp_video_path):
                return temp_video_path, None
            return None, "错误: 调用 save_to() 后文件未成功创建。"

        # Plain strings are treated as file names inside ComfyUI's input folder.
        if isinstance(unwrapped_input, str):
            filename = unwrapped_input
            print(f"检测到字符串输入,作为文件名处理: '{filename}'")
            full_path = folder_paths.get_full_path("input", filename)
            if full_path and os.path.exists(full_path):
                return full_path, None
            return None, f"错误: 无法在 'input' 文件夹中找到文件 '{filename}'。"

        # Unsupported type: report it instead of silently dropping the input,
        # and let the caller fall back to the image input.
        print(f"警告: 不支持的 video_input 类型 {type(unwrapped_input).__name__}, 已忽略。")
        return None, None

    # ------------------------------------------------------------------ #
    # API handlers (they receive a local file path, or None)
    # ------------------------------------------------------------------ #
    def handler_google_analytics(self, prompt: str, model_id: str, media_file_path: str, base_url: str, timeout: int):
        """POST the prompt (and optional media file) to the Gemini analysis endpoint.

        The request is multipart/form-data. The media file is attached under
        the form field ``img_file`` whatever its type.

        Returns:
            The response's ``data`` field, or an explanatory string when the
            request fails or ``data`` is missing/empty.
        """
        headers = {'accept': 'application/json'}
        files = {'prompt': (None, prompt), 'model_id': (None, model_id)}
        media_file = None
        try:
            if media_file_path and os.path.exists(media_file_path):
                mime_type = mimetypes.guess_type(media_file_path)[0] or 'application/octet-stream'
                media_file = open(media_file_path, 'rb')
                files['img_file'] = (os.path.basename(media_file_path), media_file, mime_type)
            response = requests.post(f'{base_url}/api/llm/google/analysis', headers=headers, files=files, timeout=timeout)
            response.raise_for_status()
            # ValueError covers non-JSON bodies on requests versions whose
            # JSONDecodeError is not a RequestException.
            resp_json = response.json()
        except (requests.RequestException, ValueError) as e:
            return f"Error calling Gemini API: {str(e)}"
        finally:
            # Always release the upload handle (it was previously leaked).
            if media_file is not None:
                media_file.close()

        result = resp_json.get('data') if isinstance(resp_json, dict) else None
        return result or f"API返回成功,但没有有效的 'data' 内容。 响应: {response.text}"

    @classmethod
    def _extract_error_message(cls, resp_json):
        """Pull a human-readable error out of a gateway response without raising.

        Accepts ``{"error": {"message": ...}}``, ``{"error": "<text>"}`` or
        anything else, in which case the generic no-choices text is returned.
        """
        if not isinstance(resp_json, dict):
            return cls._NO_CHOICES_MESSAGE
        error = resp_json.get("error")
        if isinstance(error, dict):
            return error.get("message", cls._NO_CHOICES_MESSAGE)
        if error:
            return str(error)
        return cls._NO_CHOICES_MESSAGE

    def handler_other_llm(self, model_name: str, prompt: str, media_path: str, timeout: int):
        """Send the prompt (and optional media) to the OpenAI-style chat gateway.

        The media file is inlined as a base64 ``data:`` URL inside an
        ``image_url`` content part.
        NOTE(review): video files are sent the same way; whether the gateway
        accepts non-image data URLs is not visible here.

        Returns:
            The first choice's message content, or an explanatory error string.
        """
        messages_content = [{"type": "text", "text": prompt}]
        if media_path and os.path.exists(media_path):
            try:
                with open(media_path, "rb") as media_file:
                    base64_media = base64.b64encode(media_file.read()).decode('utf-8')
                mime_type = mimetypes.guess_type(media_path)[0] or "application/octet-stream"
                data_url = f"data:{mime_type};base64,{base64_media}"
                messages_content.append({"type": "image_url", "image_url": {"url": data_url}})
            except Exception as e:
                # Deliberately broad: read/encode failures (including memory
                # errors on very large files) are surfaced as the node output.
                return f"Error encoding media file: {str(e)}"

        json_payload = {"model": model_name, "messages": [{"role": "user", "content": messages_content}], "temperature": 0.7, "max_tokens": 4096}
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.LLM_GATEWAY_TOKEN}"}
        try:
            resp = requests.post(self.LLM_GATEWAY_URL, headers=headers, json=json_payload, timeout=timeout)
            resp.raise_for_status()
            resp_json = resp.json()
        except (requests.RequestException, ValueError) as e:
            return f"Error calling other LLM API: {str(e)}"

        choices = resp_json.get('choices') if isinstance(resp_json, dict) else None
        if choices:
            try:
                return choices[0]['message']['content']
            except (KeyError, IndexError, TypeError):
                return f'Call LLM failed: unexpected response format: {resp_json}'
        return f'Call LLM failed: {self._extract_error_message(resp_json)}'

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #
    def execute(self, model_name: str, prompt: str, environment: str = "prod",
                video_input: object = None, image: torch.Tensor = None, timeout=300):
        """Resolve the optional media input and dispatch the prompt to a backend.

        Media priority: ``video_input`` first. If it yields no file (for
        example, an unsupported type), ``image`` is used. With neither, the
        prompt is sent as text only.

        Returns:
            A 1-tuple containing the reply text or an error message.
        """
        base_url = self.ENV_URLS.get(environment, self.ENV_URLS["prod"])
        media_path = None

        # Priority 1: video_input.
        if video_input is not None:
            media_path, error = self._resolve_video_input(video_input)
            if error is not None:
                return (error,)

        # Priority 2: image, used whenever video_input produced no file.
        if media_path is None and image is not None:
            print("检测到图像输入, 正在保存为临时文件...")
            pil_image = self.tensor_to_pil(image)
            media_path = self.save_pil_to_temp(pil_image)

        # Priority 3: text-only mode.
        if media_path:
            print(f"成功解析媒体文件路径: {media_path}")
        else:
            print("未提供媒体文件, 以纯文本模式运行。")

        # Dispatch to the matching API handler.
        model_name = model_name.strip()
        if model_name.startswith('gemini'):
            result = self.handler_google_analytics(prompt, model_name, media_path, base_url=base_url, timeout=timeout)
        else:
            result = self.handler_other_llm(model_name, prompt, media_path, timeout=timeout)

        return (result,)
|
||
|
||
|
||
# NODE_CLASS_MAPPINGS = { "LLMUionNode": LLMUionNode }
|
||
# NODE_DISPLAY_NAME_MAPPINGS = { "LLMUionNode": "聚合LLM节点(视频/图像)" } |