ComfyUI-CustomNode/nodes/union_llm_node.py

241 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding:utf-8 -*-
"""
File union_llm_node.py
Author silence
Date 2025/9/5
"""
import os
import requests
import base64
import mimetypes
import torch
import httpx
import numpy as np
from PIL import Image
import folder_paths
# Optional dependency guard: scipy is only required when an AUDIO input
# arrives as a raw waveform tensor (see save_audio_tensor_to_temp, which
# writes it out with scipy.io.wavfile). The availability check there relies
# on the name 'wavfile' being absent from globals() when this import fails,
# so only a warning is printed here and the node still loads.
try:
    import scipy.io.wavfile as wavfile
except ImportError:
    print("------------------------------------------------------------------------------------")
    print("Scipy 库未安装, 请运行: pip install scipy")
    print("------------------------------------------------------------------------------------")
def handler_google_analytics(prompt: str, model_id: str, media_file_path: str, base_url: str, timeout: int):
    """POST the prompt plus optional media to the Gemini analysis endpoint.

    Args:
        prompt: User prompt text, sent as a multipart form field.
        model_id: Gemini model identifier (e.g. 'gemini-2.5-flash').
        media_file_path: Local file path, a 'gs:' URI, or None/empty for
            text-only requests.
        base_url: Environment-specific API root (see LLMUionNode.ENV_URLS).
        timeout: Request timeout in seconds.

    Returns:
        The 'data' field of the JSON response on success, otherwise a
        human-readable error string (this function never raises to callers).
    """
    headers = {'accept': 'application/json'}
    files = {'prompt': (None, prompt), 'model_id': (None, model_id)}
    if media_file_path and os.path.exists(media_file_path):
        # Read the bytes up front inside a context manager so the handle is
        # closed deterministically (the original passed an open file object
        # into `files` that was never closed — a file-handle leak).
        mime_type = mimetypes.guess_type(media_file_path)[0] or 'application/octet-stream'
        with open(media_file_path, 'rb') as media_fh:
            files['img_file'] = (os.path.basename(media_file_path), media_fh.read(), mime_type)
    if bool(media_file_path) and media_file_path.startswith("gs:"):
        # Google Cloud Storage URIs are forwarded as-is for server-side fetch.
        files['img_url'] = (None, media_file_path)
    try:
        response = requests.post(f'{base_url}/api/llm/google/analysis', headers=headers, files=files,
                                 timeout=timeout)
        response.raise_for_status()
        resp_json = response.json()
        result = resp_json.get('data') if resp_json else None
        return result or f"API返回成功但没有有效的 'data' 内容。 响应: {response.text}"
    except requests.RequestException as e:
        return f"Error calling Gemini API: {str(e)}"
class LLMUionNode:
    """Unified multimodal LLM node for ComfyUI.

    Routes a text prompt plus at most one optional media input (video, image,
    audio, or a URL) either to the Gemini analysis service (for 'gemini-*'
    models) or to an OpenAI-compatible chat-completions gateway (for other
    models), and returns the model's text answer.
    """

    MODELS = ['gemini-2.5-flash', 'gemini-2.5-pro', "gpt-4o-1120", "gpt-4.1"]
    ENVIRONMENTS = ["prod", "dev", "test"]
    ENV_URLS = {
        "prod": 'https://bowongai-prod--text-video-agent-fastapi-app.modal.run',
        "dev": 'https://bowongai-dev--text-video-agent-fastapi-app.modal.run',
        "test": 'https://bowongai-test--text-video-agent-fastapi-app.modal.run'
    }

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-schema declaration: media inputs are all optional and
        # mutually exclusive in practice (execute() checks video, then image,
        # then audio, then url — first match wins).
        return {
            "required": {
                "model_name": (s.MODELS,),
                "prompt": ("STRING", {"multiline": True, "default": "", "placeholder": "请输入提示词"}),
            },
            "optional": {
                "video": ("*",),
                "image": ("IMAGE",),
                "audio": ("AUDIO",),
                "url": ("STRING", {"multiline": True, "default": None, "placeholder": "【可选】输入要分析的链接"}),
                "environment": (s.ENVIRONMENTS,),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "execute"

    def tensor_to_pil(self, tensor):
        """Convert the first image of a ComfyUI IMAGE batch to a PIL image.

        Assumes `tensor` is a (batch, H, W, C) float tensor scaled to [0, 1]
        (the ComfyUI IMAGE convention). Returns None when `tensor` is None.
        """
        if tensor is None:
            return None
        image_np = tensor[0].cpu().numpy()
        # Clip before the uint8 cast: float values marginally outside [0, 1]
        # would otherwise wrap around and corrupt pixels.
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
        return Image.fromarray(image_np)

    def save_pil_to_temp(self, pil_image):
        """Save a PIL image as a PNG in ComfyUI's temp dir; return its path."""
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_image", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.png")
        pil_image.save(filepath, 'PNG')
        return filepath

    def save_link_file(self, link_url: str, is_google: bool = False):
        """Resolve a URL input to something the backend can consume.

        'gs:' URIs are passed through untouched for Google models (the server
        fetches them itself); anything else is downloaded to a temp file and
        the local path is returned.
        """
        def download_file(url):
            # NOTE(review): the suffix is taken from the raw URL, so a query
            # string ("...mp4?token=x") would leak into the extension — TODO
            # harden with urllib.parse if such links appear in practice.
            suffix = url.rsplit('.', 1)[-1]
            response = httpx.get(url, timeout=120)
            # Fail fast on HTTP errors instead of silently saving an error
            # page as if it were media.
            response.raise_for_status()
            output_dir = folder_paths.get_temp_directory()
            (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_image",
                                                                                             output_dir)
            filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.{suffix}")
            with open(filepath, 'wb') as f:
                f.write(response.content)
            return filepath

        link_url = link_url.strip()
        if is_google and link_url.startswith("gs:"):
            return link_url
        return download_file(link_url)

    def save_audio_tensor_to_temp(self, waveform_tensor, sample_rate):
        """Write an AUDIO waveform tensor to a temp .wav file via scipy.

        Raises:
            ImportError: when scipy was not importable at module load time
                (detected through 'wavfile' missing from module globals).
        """
        if 'wavfile' not in globals():
            raise ImportError("Scipy 库未安装。请在您的 ComfyUI 环境中运行 'pip install scipy' 来启用此功能。")
        waveform_np = waveform_tensor.cpu().numpy()
        # ComfyUI AUDIO waveforms are (batch, channels, samples); drop the
        # batch dim, then transpose to (samples, channels) as wavfile expects.
        if waveform_np.ndim == 3:
            waveform_np = waveform_np[0]
        waveform_np = waveform_np.T
        # Scale [-1, 1] float samples to 16-bit PCM.
        waveform_int16 = np.int16(waveform_np * 32767)
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_audio", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.wav")
        wavfile.write(filepath, sample_rate, waveform_int16)
        print(f"音频张量已使用 Scipy 保存到临时文件: {filepath}")
        return filepath

    def handler_other_llm(self, model_name: str, prompt: str, media_path: str, timeout: int):
        """Call the OpenAI-compatible gateway for non-Gemini models.

        Any local media file is inlined as a base64 data URL in an
        'image_url' content part. Returns the first choice's message content
        on success, otherwise a human-readable error string (never raises).
        """
        messages_content = [{"type": "text", "text": prompt}]
        if media_path and os.path.exists(media_path):
            try:
                with open(media_path, "rb") as media_file:
                    base64_media = base64.b64encode(media_file.read()).decode('utf-8')
                mime_type = mimetypes.guess_type(media_path)[0] or "application/octet-stream"
                data_url = f"data:{mime_type};base64,{base64_media}"
                messages_content.append({"type": "image_url", "image_url": {"url": data_url}})
            except Exception as e:
                return f"Error encoding media file: {str(e)}"
        json_payload = {"model": model_name, "messages": [{"role": "user", "content": messages_content}],
                        "temperature": 0.7, "max_tokens": 4096}
        try:
            resp = requests.post("https://gateway.bowong.cc/chat/completions",
                                 headers={"Content-Type": "application/json",
                                          "Authorization": "Bearer auth-bowong7777"}, json=json_payload,
                                 timeout=timeout)
            resp.raise_for_status()
            resp_json = resp.json()
            if 'choices' in resp_json and resp_json['choices']:
                return resp_json['choices'][0]['message']['content']
            else:
                return f'Call LLM failed: {resp_json.get("error", {}).get("message", "LLM API returned no choices.")}'
        except requests.RequestException as e:
            return f"Error calling other LLM API: {str(e)}"

    def execute(self, model_name: str, prompt: str, environment: str = "prod",
                video: object = None, image: torch.Tensor = None, audio: object = None,
                url: str = None,
                timeout=300):
        """Node entry point: resolve the media input to a local path (or gs:
        URI) and dispatch to the matching backend.

        Returns a one-tuple with the model's text answer, or with an error
        message string on recoverable failures. Raises ValueError when a
        non-Gemini model is given video/audio input.
        """
        base_url = self.ENV_URLS.get(environment, self.ENV_URLS["prod"])
        media_path = None
        # `url` defaults to None — guard before strip (the original called
        # url.strip() unconditionally and crashed on text-only invocations).
        url = url.strip() if url else None
        if video is not None:
            if 'gemini' not in model_name:
                raise ValueError(f'{model_name}暂不支持视频分析,\n请使用gemini-2.5-flash或者gemini-2.5-pro')
            print('多模态处理视频输入...')
            # Some upstream nodes deliver the video wrapped in a list/tuple.
            unwrapped_input = video[0] if isinstance(video, (list, tuple)) and video else video
            if hasattr(unwrapped_input, 'save_to'):
                # VIDEO objects expose save_to(); persist to a temp .mp4.
                try:
                    output_dir = folder_paths.get_temp_directory()
                    (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path("llm_temp_video",
                                                                                                     output_dir)
                    temp_video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")
                    unwrapped_input.save_to(temp_video_path)
                    if os.path.exists(temp_video_path):
                        media_path = temp_video_path
                    else:
                        return (f"错误: 调用 save_to() 后文件未成功创建。",)
                except Exception as e:
                    return (f"调用 save_to() 时出错: {e}",)
            elif isinstance(unwrapped_input, str):
                # A bare string is treated as a filename in the input folder.
                full_path = folder_paths.get_full_path("input", unwrapped_input)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{unwrapped_input}'",)
        elif image is not None:
            print('多模态处理图片输出...')
            pil_image = self.tensor_to_pil(image)
            media_path = self.save_pil_to_temp(pil_image)
        elif audio is not None:
            if 'gemini' not in model_name:
                raise ValueError(f'{model_name}暂不支持音频分析,\n请使用gemini-2.5-flash或者gemini-2.5-pro')
            print("多模态处理音频输入...")
            audio_info = audio[0] if isinstance(audio, (list, tuple)) and audio else audio
            if isinstance(audio_info, dict) and 'filename' in audio_info:
                filename = audio_info['filename']
                print(f"从音频对象中找到 'filename': '{filename}'")
                full_path = folder_paths.get_full_path("input", filename)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{filename}'",)
            elif isinstance(audio_info, dict) and 'waveform' in audio_info and 'sample_rate' in audio_info:
                print("从音频对象中找到 'waveform' 数据,正在使用 Scipy 保存为临时文件...")
                try:
                    media_path = self.save_audio_tensor_to_temp(audio_info['waveform'], audio_info['sample_rate'])
                except Exception as e:
                    return (f"错误: 保存音频张量时出错: {e}",)
            elif isinstance(audio_info, str):
                print(f"检测到音频输入为字符串,作为文件名处理: '{audio_info}'")
                full_path = folder_paths.get_full_path("input", audio_info)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{audio_info}'",)
            else:
                return (f"错误: 不支持的音频输入格式或结构。收到类型: {type(audio_info)}",)
        elif url:
            model_name = model_name.strip()
            is_google = model_name.startswith('gemini')
            media_path = self.save_link_file(link_url=url, is_google=is_google)
        else:
            print("纯文本运行llm")
        if media_path:
            print(f"成功解析媒体文件路径: {media_path}")
        model_name = model_name.strip()
        if model_name.startswith('gemini'):
            result = handler_google_analytics(prompt, model_name, media_path, base_url=base_url, timeout=timeout)
        else:
            result = self.handler_other_llm(model_name, prompt, media_path, timeout=timeout)
        return (result,)
# Node registration: ComfyUI discovers custom nodes through these two
# module-level dicts (internal node name -> class, and -> display label).
NODE_CLASS_MAPPINGS = {"LLMUionNode": LLMUionNode}
NODE_DISPLAY_NAME_MAPPINGS = {"LLMUionNode": "聚合LLM节点(视频/图像/音频)"}