def _image_tensor_to_uint8_array(image):
    """Convert a float image tensor into a ``uint8`` NumPy array.

    Values are multiplied by 255 and clipped to ``[0, 255]`` before the cast.
    Instead of a blanket ``squeeze()``, only the axes that are known to be
    redundant are removed, so an image whose height or width is 1 keeps its
    shape.

    Args:
        image: Tensor of rank 2, 3 or 4. Presumably laid out as
            (batch, height, width, channels) with values in [0, 1] --
            TODO confirm against the nodes that feed this one.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` holding one image.

    Raises:
        ValueError: If a rank-4 input has an empty batch axis.
    """
    # detach() so tensors that still track gradients can be exported.
    array = image.detach().cpu().numpy()

    if array.ndim == 4:
        if array.shape[0] == 0:
            raise ValueError("image batch is empty")
        # Only one image can be sent per request; use the first frame.
        array = array[0]

    if array.ndim == 3:
        if array.shape[-1] == 1:
            # Single-channel image: drop the channel axis (grayscale).
            array = array[..., 0]
        elif array.shape[0] == 1 and array.shape[-1] > 4:
            # Rank-3 input with a leading singleton axis and no plausible
            # channel axis: treat it as one batched grayscale image, which is
            # what the previous squeeze()-based code produced.
            array = array[0]

    return np.clip(255.0 * array, 0, 255).astype(np.uint8)


def image_tensor_to_base64(image):
    """Encode an image tensor as a PNG ``data:`` URL.

    Args:
        image: Image tensor; see ``_image_tensor_to_uint8_array`` for the
            accepted ranks. When a batch is given only its first image is
            encoded.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        ValueError: If a rank-4 input has an empty batch axis.
    """
    pixels = _image_tensor_to_uint8_array(image)
    # Keep the PNG bytes in memory only; the buffer is released on exit.
    with io.BytesIO() as buffer:
        Image.fromarray(pixels).save(buffer, format="PNG")
        payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return "data:image/png;base64," + payload


class LLMChatMultiModalImageTensor:
    """Node that sends a text prompt and one image tensor to a chat endpoint.

    The image is embedded in the request as a base64 PNG data URL and the
    text found under the response's ``content`` key is returned.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (model, prompt, image, sampling limits)."""
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120", "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": ("IMAGE",),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/llm"

    def chat(self, llm_provider: str, prompt: str, image: torch.Tensor,
             temperature: float, max_tokens: int, timeout: int):
        """Call the chat-completions gateway with ``prompt`` and ``image``.

        Args:
            llm_provider: Model name sent as the request's ``model`` field.
            prompt: User text placed in the message content.
            image: Image tensor, encoded via ``image_tensor_to_base64``.
            temperature: Sampling temperature forwarded to the gateway.
            max_tokens: Completion token limit forwarded to the gateway.
            timeout: Overall request timeout in seconds (connect is fixed
                at 15 seconds).

        Returns:
            A one-element tuple containing the response text with runs of
            two or more newlines collapsed to a single newline.

        Raises:
            Exception: If encoding the image fails, or if the request still
                fails after three attempts. The original error is attached
                as ``__cause__``.
        """
        # Encode once, outside the retry loop: the result is identical on
        # every attempt and an encoding failure is not worth retrying.
        try:
            image_url = image_tensor_to_base64(image)
        except Exception as e:
            raise Exception("llm调用失败 {}".format(e)) from e

        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(
                        "https://gateway.bowong.cc/chat/completions",
                        headers={
                            "Content-Type": "application/json",
                            "Accept": "application/json",
                            # NOTE(review): credential is hard-coded in source;
                            # consider loading it from configuration instead.
                            "Authorization": "Bearer auth-bowong7777",
                        },
                        json={
                            "model": llm_provider,
                            "messages": [
                                {
                                    "role": "user",
                                    "content": [
                                        {"type": "text", "text": prompt},
                                        {
                                            "type": "image_url",
                                            "image_url": {"url": image_url},
                                        },
                                    ],
                                }
                            ],
                            "temperature": temperature,
                            "max_tokens": max_tokens,
                        },
                    )
                    resp.raise_for_status()
                    body = resp.json()
                content = find_value_recursive("content", body)
                if not isinstance(content, str):
                    # Fail with a readable message instead of the TypeError
                    # re.sub would raise on a missing/non-text value.
                    raise ValueError("response has no text 'content' field")
                content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # logger.exception("llm调用失败 {}".format(e))
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()