# -*- coding:utf-8 -*-
"""
File union_llm_node.py
Author silence
Date 2025/9/5

ComfyUI node that routes multimodal inputs (video / image / audio / URL) to
either a Gemini analysis endpoint or an OpenAI-compatible chat endpoint.
"""
import os
import base64
import mimetypes

import requests
import httpx
import numpy as np
import torch
from PIL import Image

import folder_paths

try:
    import scipy.io.wavfile as wavfile
except ImportError:
    # Audio-tensor saving is optional; warn loudly but keep the node importable.
    print("------------------------------------------------------------------------------------")
    print("Scipy 库未安装, 请运行: pip install scipy")
    print("------------------------------------------------------------------------------------")


def handler_google_analytics(prompt: str, model_id: str, media_file_path: str, base_url: str, timeout: int):
    """Send a prompt plus an optional media file (or gs:// URI) to the Gemini
    analysis endpoint and return the text result.

    Args:
        prompt: User prompt text.
        model_id: Gemini model identifier (e.g. 'gemini-2.5-flash').
        media_file_path: Local file path, 'gs:' URI, or empty/None for text-only.
        base_url: Environment-specific service base URL.
        timeout: Request timeout in seconds.

    Returns:
        The 'data' payload on success, otherwise a human-readable error string
        (this function never raises on network errors).
    """
    headers = {'accept': 'application/json'}
    files = {'prompt': (None, prompt), 'model_id': (None, model_id)}
    file_handle = None  # kept so the multipart source can be closed after the request
    try:
        if media_file_path and os.path.exists(media_file_path):
            # Fix: the original opened the file without ever closing it (fd leak).
            file_handle = open(media_file_path, 'rb')
            mime_type = mimetypes.guess_type(media_file_path)[0] or 'application/octet-stream'
            files['img_file'] = (os.path.basename(media_file_path), file_handle, mime_type)
        if media_file_path and media_file_path.startswith("gs:"):
            # Google Cloud Storage URIs are passed through by reference, not uploaded.
            files['img_url'] = (None, media_file_path)
        response = requests.post(f'{base_url}/api/llm/google/analysis',
                                 headers=headers, files=files, timeout=timeout)
        response.raise_for_status()
        resp_json = response.json()
        result = resp_json.get('data') if resp_json else None
        return result or f"API返回成功,但没有有效的 'data' 内容。 响应: {response.text}"
    except requests.RequestException as e:
        return f"Error calling Gemini API: {str(e)}"
    finally:
        if file_handle is not None:
            file_handle.close()


class LLMUionNode:
    """Aggregated multimodal LLM node: accepts video, image, audio or a URL and
    dispatches to Gemini (multimodal) or an OpenAI-compatible gateway."""

    MODELS = ['gemini-2.5-flash', 'gemini-2.5-pro', "gpt-4o-1120", "gpt-4.1"]
    ENVIRONMENTS = ["prod", "dev", "test"]
    ENV_URLS = {
        "prod": 'https://bowongai-prod--text-video-agent-fastapi-app.modal.run',
        "dev": 'https://bowongai-dev--text-video-agent-fastapi-app.modal.run',
        "test": 'https://bowongai-test--text-video-agent-fastapi-app.modal.run',
    }

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (s.MODELS,),
                "prompt": ("STRING", {"multiline": True, "default": "", "placeholder": "请输入提示词"}),
            },
            "optional": {
                "video": ("*",),
                "image": ("IMAGE",),
                "audio": ("AUDIO",),
                "url": ("STRING", {"multiline": True, "default": None, "placeholder": "【可选】输入要分析的链接"}),
                "environment": (s.ENVIRONMENTS,),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    FUNCTION = "execute"

    def tensor_to_pil(self, tensor):
        """Convert a ComfyUI IMAGE tensor (batch, H, W, C, floats in [0,1])
        to a PIL image; only the first batch element is used."""
        if tensor is None:
            return None
        image_np = tensor[0].cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)
        return Image.fromarray(image_np)

    def save_pil_to_temp(self, pil_image):
        """Save a PIL image as a PNG in ComfyUI's temp directory; return the path."""
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
            "llm_temp_image", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.png")
        pil_image.save(filepath, 'PNG')
        return filepath

    def save_link_file(self, link_url: str, is_google: bool = False):
        """Resolve a URL input to something the backend can consume.

        gs:// URIs are returned as-is for Gemini; anything else is downloaded
        into the temp directory and the local path is returned.
        """
        def download_file(url):
            # Drop any query string before guessing the extension
            # (e.g. 'https://x/a.mp4?token=..' -> 'mp4').
            suffix = url.split('?')[0].rsplit('.', 1)[-1]
            response = httpx.get(url, timeout=120)
            output_dir = folder_paths.get_temp_directory()
            (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
                "llm_temp_image", output_dir)
            filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.{suffix}")
            with open(filepath, 'wb') as f:
                f.write(response.content)
            return filepath

        link_url = link_url.strip()
        if is_google and link_url.startswith("gs:"):
            return link_url
        else:
            return download_file(link_url)

    def save_audio_tensor_to_temp(self, waveform_tensor, sample_rate):
        """Write a ComfyUI AUDIO waveform tensor to a temp WAV file via scipy.

        Raises:
            ImportError: if scipy was not importable at module load time.
        """
        if 'wavfile' not in globals():
            raise ImportError("Scipy 库未安装。请在您的 ComfyUI 环境中运行 'pip install scipy' 来启用此功能。")
        waveform_np = waveform_tensor.cpu().numpy()
        if waveform_np.ndim == 3:  # (batch, channels, samples) -> first batch element
            waveform_np = waveform_np[0]
        waveform_np = waveform_np.T  # scipy expects (samples, channels)
        # assumes float waveform in [-1, 1] — scale to 16-bit PCM
        waveform_int16 = np.int16(waveform_np * 32767)
        output_dir = folder_paths.get_temp_directory()
        (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
            "llm_temp_audio", output_dir)
        filepath = os.path.join(full_output_folder, f"{filename}_{counter:05}.wav")
        wavfile.write(filepath, sample_rate, waveform_int16)
        print(f"音频张量已使用 Scipy 保存到临时文件: {filepath}")
        return filepath

    def handler_other_llm(self, model_name: str, prompt: str, media_path: str, timeout: int):
        """Call the OpenAI-compatible gateway with the prompt and an optional
        base64-inlined media file; return the completion text or an error string."""
        messages_content = [{"type": "text", "text": prompt}]
        if media_path and os.path.exists(media_path):
            try:
                with open(media_path, "rb") as media_file:
                    base64_media = base64.b64encode(media_file.read()).decode('utf-8')
                mime_type = mimetypes.guess_type(media_path)[0] or "application/octet-stream"
                data_url = f"data:{mime_type};base64,{base64_media}"
                messages_content.append({"type": "image_url", "image_url": {"url": data_url}})
            except Exception as e:
                return f"Error encoding media file: {str(e)}"
        json_payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": messages_content}],
            "temperature": 0.7,
            "max_tokens": 4096,
        }
        try:
            resp = requests.post(
                "https://gateway.bowong.cc/chat/completions",
                headers={"Content-Type": "application/json", "Authorization": "Bearer auth-bowong7777"},
                json=json_payload, timeout=timeout)
            resp.raise_for_status()
            resp_json = resp.json()
            if 'choices' in resp_json and resp_json['choices']:
                return resp_json['choices'][0]['message']['content']
            else:
                return f'Call LLM failed: {resp_json.get("error", {}).get("message", "LLM API returned no choices.")}'
        except requests.RequestException as e:
            return f"Error calling other LLM API: {str(e)}"

    def execute(self, model_name: str, prompt: str, environment: str = "prod",
                video: object = None, image: torch.Tensor = None, audio: object = None,
                url: str = None, timeout: int = 300):
        """Node entry point: resolve whichever media input is present to a local
        path (or gs:// URI), then dispatch to Gemini or the generic gateway.

        Returns a single-element tuple with the model's text answer or an
        error description.
        """
        base_url = self.ENV_URLS.get(environment, self.ENV_URLS["prod"])
        media_path = None
        # Fix: 'url' defaults to None; the original called url.strip()
        # unconditionally and crashed with AttributeError when no URL was given.
        url = url.strip() if url else None

        if video is not None:
            if 'gemini' not in model_name:
                raise ValueError(f'{model_name}暂不支持视频分析,\n请使用gemini-2.5-flash或者gemini-2.5-pro')
            print('多模态处理视频输入...')
            unwrapped_input = video[0] if isinstance(video, (list, tuple)) and video else video
            if hasattr(unwrapped_input, 'save_to'):
                # VideoInput-like object: persist it to a temp .mp4 ourselves.
                try:
                    output_dir = folder_paths.get_temp_directory()
                    (full_output_folder, filename, counter, _, _) = folder_paths.get_save_image_path(
                        "llm_temp_video", output_dir)
                    temp_video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")
                    unwrapped_input.save_to(temp_video_path)
                    if os.path.exists(temp_video_path):
                        media_path = temp_video_path
                    else:
                        return (f"错误: 调用 save_to() 后文件未成功创建。",)
                except Exception as e:
                    return (f"调用 save_to() 时出错: {e}",)
            elif isinstance(unwrapped_input, str):
                # Plain filename relative to ComfyUI's input folder.
                full_path = folder_paths.get_full_path("input", unwrapped_input)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{unwrapped_input}'。",)

        elif image is not None:
            print('多模态处理图片输入...')
            pil_image = self.tensor_to_pil(image)
            media_path = self.save_pil_to_temp(pil_image)

        elif audio is not None:
            if 'gemini' not in model_name:
                raise ValueError(f'{model_name}暂不支持音频分析,\n请使用gemini-2.5-flash或者gemini-2.5-pro')
            print("多模态处理音频输入...")
            audio_info = audio[0] if isinstance(audio, (list, tuple)) and audio else audio
            if isinstance(audio_info, dict) and 'filename' in audio_info:
                filename = audio_info['filename']
                print(f"从音频对象中找到 'filename': '{filename}'")
                full_path = folder_paths.get_full_path("input", filename)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{filename}'。",)
            elif isinstance(audio_info, dict) and 'waveform' in audio_info and 'sample_rate' in audio_info:
                print("从音频对象中找到 'waveform' 数据,正在使用 Scipy 保存为临时文件...")
                try:
                    media_path = self.save_audio_tensor_to_temp(audio_info['waveform'], audio_info['sample_rate'])
                except Exception as e:
                    return (f"错误: 保存音频张量时出错: {e}",)
            elif isinstance(audio_info, str):
                print(f"检测到音频输入为字符串,作为文件名处理: '{audio_info}'")
                full_path = folder_paths.get_full_path("input", audio_info)
                if full_path and os.path.exists(full_path):
                    media_path = full_path
                else:
                    return (f"错误: 无法在 'input' 文件夹中找到文件 '{audio_info}'。",)
            else:
                return (f"错误: 不支持的音频输入格式或结构。收到类型: {type(audio_info)}",)

        elif url:
            model_name = model_name.strip()
            is_google = model_name.startswith('gemini')
            media_path = self.save_link_file(link_url=url, is_google=is_google)
        else:
            print("纯文本运行llm")

        if media_path:
            print(f"成功解析媒体文件路径: {media_path}")

        model_name = model_name.strip()
        if model_name.startswith('gemini'):
            result = handler_google_analytics(prompt, model_name, media_path,
                                              base_url=base_url, timeout=timeout)
        else:
            result = self.handler_other_llm(model_name, prompt, media_path, timeout=timeout)
        return (result,)


# Node registration for ComfyUI
NODE_CLASS_MAPPINGS = {"LLMUionNode": LLMUionNode}
NODE_DISPLAY_NAME_MAPPINGS = {"LLMUionNode": "聚合LLM节点(视频/图像/音频)"}