# ComfyUI-CustomNode/nodes/llm_nodes.py

# LLM API nodes: call LLMs through a Cloudflare gateway
import base64
import io
import json
import os
import re
from mimetypes import guess_type
from typing import Any, Union
import folder_paths
import httpx
import numpy as np
import torch
from PIL import Image
from jinja2 import Template, StrictUndefined
from retry import retry


def find_value_recursive(key: str, data: Union[dict, list]) -> Any:
    """Depth-first search of a nested dict/list structure for the first value stored under `key`."""
if isinstance(data, dict):
if key in data:
return data[key]
        # Recursively check the values of all other keys
for value in data.values():
result = find_value_recursive(key, value)
if result is not None:
return result
    elif isinstance(data, list):
        for item in data:
            result = find_value_recursive(key, item)
            if result is not None:
                return result
    return None
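
# A hedged usage sketch (commented out so nothing runs at import time): the
# depth-first search finds the assistant reply in an OpenAI-style
# chat-completions response regardless of the exact nesting. The sample
# response shape below is illustrative, not a guaranteed gateway format.
#
#   resp = {"choices": [{"message": {"role": "assistant", "content": "hi"}}]}
#   find_value_recursive("content", resp)  # -> "hi"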


def image_tensor_to_base64(image: torch.Tensor) -> str:
    """Convert a ComfyUI IMAGE tensor into a base64 PNG data URL."""
    pil_image = Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
    # Buffer the encoded image in memory
    image_data = io.BytesIO()
    # Save the image into the buffer as PNG
    pil_image.save(image_data, format='PNG')
    # Pull the raw PNG bytes out of the buffer
    image_data_bytes = image_data.getvalue()
    # Encode the bytes as a base64 data URL
    encoded_image = "data:image/png;base64," + base64.b64encode(image_data_bytes).decode('utf-8')
    return encoded_image
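
# A hedged sketch of the expected input (commented out; illustrative only):
# ComfyUI IMAGE tensors are float32 values in [0, 1] with shape
# [batch, height, width, channels], so a random tensor stands in for a real image.
#
#   dummy = torch.rand(1, 64, 64, 3)
#   url = image_tensor_to_base64(dummy)
#   assert url.startswith("data:image/png;base64,")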


class LLMChat:
    """Text-only LLM chat: sends a single user prompt through the gateway."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"llm_provider": (["claude-3-5-sonnet-20241022-v2",
"claude-3-5-sonnet-20241022-v3",
"claude-3-7-sonnet-20250219-v1",
"claude-4-sonnet-20250514-v1",
"gpt-4o-1120",
"gpt-4.1",
"deepseek-v3",
"deepseek-r1"],),
"prompt": ("STRING", {"multiline": True}),
"temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
"max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
"timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("llm输出",)
FUNCTION = "chat"
CATEGORY = "不忘科技-自定义节点🚩/LLM"
def chat(self, llm_provider: str, prompt: str, temperature: float, max_tokens: int, timeout: int):
        # Retry the whole request up to 3 times, waiting 1 s between attempts
        @retry(Exception, tries=3, delay=1)
def _chat():
try:
with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
resp = session.post("https://gateway.bowong.cc/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer auth-bowong7777"
},
json={
"model": llm_provider,
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": temperature,
"max_tokens": max_tokens
})
resp.raise_for_status()
resp = resp.json()
                    # Providers nest the reply differently, so search the whole response for "content"
                    content = find_value_recursive("content", resp)
                    # Collapse runs of blank lines in the reply
                    content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                raise Exception("LLM call failed: {}".format(e)) from e
return (content,)
return _chat()
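
# A hedged sketch of calling the node directly (commented out: it would hit the
# real gateway and needs network access; prompt and parameters are illustrative):
#
#   node = LLMChat()
#   (reply,) = node.chat("gpt-4.1", "Summarize httpx in one sentence.", 0.7, 512, 120)
#   print(reply)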


class LLMChatMultiModalImageUpload:
    """Multimodal LLM chat: sends a prompt plus an image file uploaded to the input directory."""
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["image"])
return {
"required": {
"llm_provider": (["gpt-4o-1120",
"gpt-4.1"],),
"prompt": ("STRING", {"multiline": True}),
"image": (sorted(files), {"image_upload": True}),
"temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
"max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
"timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("llm输出",)
FUNCTION = "chat"
CATEGORY = "不忘科技-自定义节点🚩/LLM"
def chat(self, llm_provider: str, prompt: str, image, temperature: float, max_tokens: int, timeout: int):
@retry(Exception, tries=3, delay=1)
def _chat():
try:
image_path = folder_paths.get_annotated_filepath(image)
                mime_type, _ = guess_type(image_path)
                # Fall back to PNG when the type cannot be guessed from the filename
                mime_type = mime_type or "image/png"
with open(image_path, "rb") as image_file:
base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
resp = session.post("https://gateway.bowong.cc/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer auth-bowong7777"
},
json={
"model": llm_provider,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_encoded_data}"},
},
]
}
],
"temperature": temperature,
"max_tokens": max_tokens
})
resp.raise_for_status()
resp = resp.json()
content = find_value_recursive("content", resp)
content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                raise Exception("LLM call failed: {}".format(e)) from e
return (content,)
return _chat()


class LLMChatMultiModalImageTensor:
    """Multimodal LLM chat: sends a prompt plus an IMAGE tensor from the graph."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"llm_provider": (["gpt-4o-1120",
"gpt-4.1"],),
"prompt": ("STRING", {"multiline": True}),
"image": ("IMAGE",),
"temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
"max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
"timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("llm输出",)
FUNCTION = "chat"
CATEGORY = "不忘科技-自定义节点🚩/LLM"
def chat(self, llm_provider: str, prompt: str, image: torch.Tensor, temperature: float, max_tokens: int,
timeout: int):
@retry(Exception, tries=3, delay=1)
def _chat():
try:
with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
resp = session.post("https://gateway.bowong.cc/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": "Bearer auth-bowong7777"
},
json={
"model": llm_provider,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {"url": image_tensor_to_base64(image)},
},
]
}
],
"temperature": temperature,
"max_tokens": max_tokens
})
resp.raise_for_status()
resp = resp.json()
content = find_value_recursive("content", resp)
content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                raise Exception("LLM call failed: {}".format(e)) from e
return (content,)
return _chat()


class Jinja2RenderTemplate:
    """Render a Jinja2 prompt template using variables supplied as a JSON string."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": ("STRING", {"multiline": True}),
"kv_map": ("STRING", {"multiline": True}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("prompt",)
FUNCTION = "render_prompt"
CATEGORY = "不忘科技-自定义节点🚩/LLM"
    def render_prompt(self, template: str, kv_map: str) -> tuple:
        """
        Render a prompt template with Jinja2.

        Args:
            template: template string containing Jinja2 markup
            kv_map: JSON object string providing the variables the template needs

        Returns:
            A one-tuple containing the rendered string.

        Raises:
            jinja2.exceptions.UndefinedError: if the template references an undefined variable
        """
        variables = json.loads(kv_map)
        # StrictUndefined makes undefined variables raise instead of silently rendering empty
        jinja_template = Template(template, undefined=StrictUndefined)
        return (jinja_template.render(variables),)
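
# A hedged usage sketch (commented out; the template and variables are illustrative):
#
#   node = Jinja2RenderTemplate()
#   (prompt,) = node.render_prompt(
#       "Describe {{ subject }} in {{ style }} style.",
#       '{"subject": "a lighthouse", "style": "watercolor"}',
#   )
#   # -> "Describe a lighthouse in watercolor style."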