# -*- coding:utf-8 -*-
"""
File union_llm_node.py
Author silence
Date 2025/9/5
"""
import os
import requests
import base64
import mimetypes
import torch
import httpx
import numpy as np
from PIL import Image
import folder_paths

# scipy is optional: it is only needed when an AUDIO waveform tensor has to
# be serialized to a WAV file. Warn loudly at import time instead of failing,
# so the rest of the node keeps working without it.
try:
    import scipy.io.wavfile as wavfile
except ImportError:
    print("------------------------------------------------------------------------------------")
    print("Scipy 库未安装, 请运行: pip install scipy")
    print("------------------------------------------------------------------------------------")
def handler_google_analytics(prompt: str, model_id: str, media_file_path: str, base_url: str, timeout: int):
    """Call the backend's Gemini analysis endpoint with a prompt and optional media.

    Args:
        prompt: Text prompt forwarded to the model.
        model_id: Gemini model identifier (e.g. 'gemini-2.5-flash').
        media_file_path: Local file path, a 'gs:' URI, or empty/None for text-only.
        base_url: Environment-specific backend base URL.
        timeout: Request timeout in seconds.

    Returns:
        The 'data' field of the JSON response on success, otherwise a
        human-readable error string. Never raises on request failure.
    """
    headers = {'accept': 'application/json'}
    files = {'prompt': (None, prompt), 'model_id': (None, model_id)}
    if media_file_path and os.path.exists(media_file_path):
        # Read the payload inside a context manager: the original passed an
        # open file handle to requests and never closed it (handle leak).
        with open(media_file_path, 'rb') as fh:
            media_bytes = fh.read()
        files['img_file'] = (os.path.basename(media_file_path), media_bytes,
                             mimetypes.guess_type(media_file_path)[0] or 'application/octet-stream')
    if bool(media_file_path) and media_file_path.startswith("gs:"):
        # Google Cloud Storage URIs are not local files; forward them by reference.
        files['img_url'] = (None, media_file_path)
    try:
        response = requests.post(f'{base_url}/api/llm/google/analysis', headers=headers, files=files,
                                 timeout=timeout)
        response.raise_for_status()
        resp_json = response.json()
        result = resp_json.get('data') if resp_json else None
        return result or f"API返回成功,但没有有效的 'data' 内容。 响应: {response.text}"
    except requests.RequestException as e:
        return f"Error calling Gemini API: {str(e)}"
class LLMUionNode:
    """Aggregated multimodal LLM node.

    Routes a prompt plus at most one optional media input (video, image,
    audio, or URL — checked in that priority order) either to the Gemini
    analysis backend (gemini-* models) or to an OpenAI-compatible
    chat-completions gateway (all other models), returning the reply text.
    """

    # Models selectable in the UI; gemini-* models go through the Google
    # analysis endpoint, the rest through the generic gateway.
    MODELS = ['gemini-2.5-flash', 'gemini-2.5-pro', "gpt-4o-1120", "gpt-4.1"]

    # Deployment environments and their backend base URLs.
    ENVIRONMENTS = ["prod", "dev", "test"]
    ENV_URLS = {
        "prod": 'https://bowongai-prod--text-video-agent-fastapi-app.modal.run',
        "dev": 'https://bowongai-dev--text-video-agent-fastapi-app.modal.run',
        "test": 'https://bowongai-test--text-video-agent-fastapi-app.modal.run'
    }

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's input sockets/widgets for ComfyUI."""
        return {
            "required": {
                "model_name": (s.MODELS,),
                "prompt": ("STRING", {"multiline": True, "default": "", "placeholder": "请输入提示词"}),
            },
            "optional": {
                "video": ("*",),
                "image": ("IMAGE",),
                "audio": ("AUDIO",),
                "url": ("STRING", {"multiline": True, "default": None, "placeholder": "【可选】输入要分析的链接"}),
                "environment": (s.ENVIRONMENTS,),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "execute"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def tensor_to_pil(self, tensor):
        """Convert the first image of a ComfyUI IMAGE batch tensor to PIL.

        Assumes the usual ComfyUI layout (batch, H, W, C) with float values
        in [0, 1] — TODO confirm against callers. None passes through.
        """
        if tensor is None:
            return None
        image_np = tensor[0].cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)
        return Image.fromarray(image_np)

    def save_pil_to_temp(self, pil_image):
        """Save a PIL image as PNG in ComfyUI's temp dir; return its path."""
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
            "llm_temp_image", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.png")
        pil_image.save(filepath, 'PNG')
        return filepath

    def save_link_file(self, link_url: str, is_google: bool = False):
        """Resolve a URL input to a path the handlers can use.

        'gs:' URIs are passed through untouched for the Gemini backend;
        anything else is downloaded into ComfyUI's temp directory.
        """
        def download_file(url):
            # Keep the URL's extension so MIME sniffing works downstream.
            suffix = url.rsplit('.', 1)[-1]
            response = httpx.get(url, timeout=120)
            # Fail fast on a broken link instead of saving an error page.
            response.raise_for_status()
            output_dir = folder_paths.get_temp_directory()
            (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
                "llm_temp_image", output_dir)
            filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.{suffix}")
            with open(filepath, 'wb') as f:
                f.write(response.content)
            return filepath

        link_url = link_url.strip()
        if is_google and link_url.startswith("gs:"):
            return link_url
        return download_file(link_url)

    def save_audio_tensor_to_temp(self, waveform_tensor, sample_rate):
        """Write an AUDIO waveform tensor to a 16-bit WAV temp file.

        Expects samples in [-1, 1]; a leading batch dimension is dropped.
        Raises ImportError when scipy was not importable at module load.
        """
        if 'wavfile' not in globals():
            raise ImportError("Scipy 库未安装。请在您的 ComfyUI 环境中运行 'pip install scipy' 来启用此功能。")
        waveform_np = waveform_tensor.cpu().numpy()
        if waveform_np.ndim == 3:
            waveform_np = waveform_np[0]  # drop the batch dimension
        # (channels, samples) -> (samples, channels), the layout scipy expects.
        waveform_np = waveform_np.T
        waveform_int16 = np.int16(waveform_np * 32767)
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
            "llm_temp_audio", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.wav")
        wavfile.write(filepath, sample_rate, waveform_int16)
        print(f"音频张量已使用 Scipy 保存到临时文件: {filepath}")
        return filepath

    def handler_other_llm(self, model_name: str, prompt: str, media_path: str, timeout: int):
        """Call the OpenAI-compatible gateway for non-Gemini models.

        An optional local media file is inlined as a base64 data URL in an
        'image_url' content part. Returns the reply text or an error string;
        never raises on request failure.
        """
        messages_content = [{"type": "text", "text": prompt}]
        if media_path and os.path.exists(media_path):
            try:
                with open(media_path, "rb") as media_file:
                    base64_media = base64.b64encode(media_file.read()).decode('utf-8')
                mime_type = mimetypes.guess_type(media_path)[0] or "application/octet-stream"
                data_url = f"data:{mime_type};base64,{base64_media}"
                messages_content.append({"type": "image_url", "image_url": {"url": data_url}})
            except Exception as e:
                return f"Error encoding media file: {str(e)}"

        json_payload = {"model": model_name, "messages": [{"role": "user", "content": messages_content}],
                        "temperature": 0.7, "max_tokens": 4096}
        try:
            # NOTE(review): gateway token is hard-coded; consider moving it to config.
            resp = requests.post("https://gateway.bowong.cc/chat/completions",
                                 headers={"Content-Type": "application/json",
                                          "Authorization": "Bearer auth-bowong7777"}, json=json_payload,
                                 timeout=timeout)
            resp.raise_for_status()
            resp_json = resp.json()
            if 'choices' in resp_json and resp_json['choices']:
                return resp_json['choices'][0]['message']['content']
            else:
                return f'Call LLM failed: {resp_json.get("error", {}).get("message", "LLM API returned no choices.")}'
        except requests.RequestException as e:
            return f"Error calling other LLM API: {str(e)}"

    def execute(self, model_name: str, prompt: str, environment: str = "prod",
                video: object = None, image: torch.Tensor = None, audio: object = None,
                url: str = None,
                timeout=300):
        """Node entry point.

        Resolves the (single) media input to a local path, then dispatches
        to the Gemini or generic handler. Returns a 1-tuple with the reply
        text; hard input errors are returned as 1-tuple error strings.
        """
        base_url = self.ENV_URLS.get(environment, self.ENV_URLS["prod"])
        media_path = None
        # url's widget default is None; guard before strip (the original
        # crashed with AttributeError when no url was connected).
        url = url.strip() if url else None
        if video is not None:
            if 'gemini' not in model_name:
                raise ValueError(f'{model_name}暂不支持视频分析,\n请使用gemini-2.5-flash或者gemini-2.5-pro')
            print('多模态处理视频输入...')
            # Some upstream nodes deliver the video wrapped in a list/tuple.
            unwrapped_input = video[0] if isinstance(video, (list, tuple)) and video else video
            if hasattr(unwrapped_input, 'save_to'):
                # VIDEO-like object: materialize it as an mp4 in the temp dir.
                try:
                    output_dir = folder_paths.get_temp_directory()
                    (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
                        "llm_temp_video", output_dir)
                    temp_video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")
                    unwrapped_input.save_to(temp_video_path)
                    if os.path.exists(temp_video_path):
                        media_path = temp_video_path
                    else:
                        return (f"错误: 调用 save_to() 后文件未成功创建。",)
                except Exception as e:
                    return (f"调用 save_to() 时出错: {e}",)
            elif isinstance(unwrapped_input, str):
                # Plain string: treat as a filename inside the input folder.
                full_path = folder_paths.get_full_path("input", unwrapped_input)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{unwrapped_input}'。",)
        elif image is not None:
            print('多模态处理图片输出...')
            pil_image = self.tensor_to_pil(image)
            media_path = self.save_pil_to_temp(pil_image)
        elif audio is not None:
            if 'gemini' not in model_name:
                raise ValueError(f'{model_name}暂不支持音频分析,\n请使用gemini-2.5-flash或者gemini-2.5-pro')
            print("多模态处理音频输入...")
            audio_info = audio[0] if isinstance(audio, (list, tuple)) and audio else audio

            if isinstance(audio_info, dict) and 'filename' in audio_info:
                # Dict with a filename: resolve it against the input folder.
                filename = audio_info['filename']
                print(f"从音频对象中找到 'filename': '{filename}'")
                full_path = folder_paths.get_full_path("input", filename)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{filename}'。",)
            elif isinstance(audio_info, dict) and 'waveform' in audio_info and 'sample_rate' in audio_info:
                # Raw tensor audio: serialize to a temp WAV via scipy.
                print("从音频对象中找到 'waveform' 数据,正在使用 Scipy 保存为临时文件...")
                try:
                    media_path = self.save_audio_tensor_to_temp(audio_info['waveform'], audio_info['sample_rate'])
                except Exception as e:
                    return (f"错误: 保存音频张量时出错: {e}",)
            elif isinstance(audio_info, str):
                print(f"检测到音频输入为字符串,作为文件名处理: '{audio_info}'")
                full_path = folder_paths.get_full_path("input", audio_info)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{audio_info}'。",)
            else:
                return (f"错误: 不支持的音频输入格式或结构。收到类型: {type(audio_info)}",)
        elif url:
            # gs: links are only meaningful to the Gemini backend.
            is_google = model_name.strip().startswith('gemini')
            media_path = self.save_link_file(link_url=url, is_google=is_google)
        else:
            print("纯文本运行llm")

        if media_path:
            print(f"成功解析媒体文件路径: {media_path}")
        model_name = model_name.strip()
        if model_name.startswith('gemini'):
            result = handler_google_analytics(prompt, model_name, media_path, base_url=base_url, timeout=timeout)
        else:
            result = self.handler_other_llm(model_name, prompt, media_path, timeout=timeout)

        return (result,)
# Node registration: expose the class to ComfyUI and give it a display name.
NODE_CLASS_MAPPINGS = {"LLMUionNode": LLMUionNode}
NODE_DISPLAY_NAME_MAPPINGS = {"LLMUionNode": "聚合LLM节点(视频/图像/音频)"}
|