# ComfyUI-CustomNode/nodes/llm_api.py
# LLM API: call LLMs through a Cloudflare gateway
import base64
import io
import os
import re
from mimetypes import guess_type
from typing import Any, Union

import httpx
import numpy as np
import torch
from PIL import Image
from retry import retry

import folder_paths


def find_value_recursive(key: str, data: Union[dict, list]) -> Any:
    """Depth-first search for `key` in nested dicts/lists; returns the first match, else None."""
    if isinstance(data, dict):
        if key in data:
            return data[key]
        # Recursively check the values of all other keys
        for value in data.values():
            result = find_value_recursive(key, value)
            if result is not None:
                return result
    elif isinstance(data, list):
        for item in data:
            result = find_value_recursive(key, item)
            if result is not None:
                return result
    return None
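

# A minimal usage sketch (hypothetical sample payload, not a real gateway
# response): the gateway speaks the OpenAI chat-completions format, so the
# assistant text sits at choices[0].message.content, and the recursive search
# finds it without hard-coding that path.
#
#   sample = {"choices": [{"message": {"role": "assistant", "content": "hi"}}]}
#   find_value_recursive("content", sample)  # -> "hi"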


def image_tensor_to_base64(image):
    """Encode an image tensor as a PNG data URL."""
    pil_image = Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
    # Create a BytesIO buffer to hold the image data temporarily
    image_data = io.BytesIO()
    # Save the image into the buffer in PNG format
    pil_image.save(image_data, format='PNG')
    # Pull the buffer contents out as a byte string
    image_data_bytes = image_data.getvalue()
    # Encode the image bytes as a base64 data URL
    encoded_image = "data:image/png;base64," + base64.b64encode(image_data_bytes).decode('utf-8')
    return encoded_image
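

# Expected input shape (an assumption based on ComfyUI's IMAGE convention:
# a float tensor of [batch, height, width, channels] with values in [0, 1]):
#
#   img = torch.rand(1, 64, 64, 3)      # hypothetical tensor
#   url = image_tensor_to_base64(img)   # "data:image/png;base64,iVBOR..."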


class LLMChat:
    """Single-turn text chat through the LLM gateway."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "llm_provider": (["claude-3-5-sonnet-20241022-v2",
                                  "claude-3-5-sonnet-20241022-v3",
                                  "claude-3-7-sonnet-20250219-v1",
                                  "claude-4-sonnet-20250514-v1",
                                  "gpt-4o-1120",
                                  "gpt-4.1",
                                  "deepseek-v3",
                                  "deepseek-r1"],),
                "prompt": ("STRING", {"multiline": True}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/llm"

    def chat(self, llm_provider: str, prompt: str, temperature: float, max_tokens: int, timeout: int):
        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post("https://gateway.bowong.cc/chat/completions",
                                        headers={
                                            "Content-Type": "application/json",
                                            "Accept": "application/json",
                                            "Authorization": "Bearer auth-bowong7777"
                                        },
                                        json={
                                            "model": llm_provider,
                                            "messages": [
                                                {
                                                    "role": "user",
                                                    "content": prompt
                                                }
                                            ],
                                            "temperature": temperature,
                                            "max_tokens": max_tokens
                                        })
                    resp.raise_for_status()
                resp = resp.json()
                content = find_value_recursive("content", resp)
                if content is None:
                    raise ValueError("no 'content' field in the gateway response")
                # Collapse runs of blank lines into single newlines
                content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # logger.exception("LLM call failed: {}".format(e))
                raise Exception("LLM call failed: {}".format(e))
            return (content,)
        return _chat()
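

# A minimal smoke-test sketch for LLMChat outside a ComfyUI graph (an
# assumption that the gateway at gateway.bowong.cc is reachable and the
# bearer token above is still valid):
#
#   node = LLMChat()
#   (text,) = node.chat("gpt-4.1", "Say hello.", temperature=0.7,
#                       max_tokens=256, timeout=120)
#   print(text)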


class LLMChatMultiModalImageUpload:
    """Multimodal chat with an image picked from the ComfyUI input directory."""

    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        files = folder_paths.filter_files_content_types(files, ["image"])
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120",
                                  "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": (sorted(files), {"image_upload": True}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/llm"

    def chat(self, llm_provider: str, prompt: str, image, temperature: float, max_tokens: int, timeout: int):
        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                image_path = folder_paths.get_annotated_filepath(image)
                # Fall back to PNG when the file extension yields no MIME type
                mime_type, _ = guess_type(image_path)
                mime_type = mime_type or "image/png"
                with open(image_path, "rb") as image_file:
                    base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post("https://gateway.bowong.cc/chat/completions",
                                        headers={
                                            "Content-Type": "application/json",
                                            "Accept": "application/json",
                                            "Authorization": "Bearer auth-bowong7777"
                                        },
                                        json={
                                            "model": llm_provider,
                                            "messages": [
                                                {
                                                    "role": "user",
                                                    "content": [
                                                        {"type": "text", "text": prompt},
                                                        {
                                                            "type": "image_url",
                                                            "image_url": {"url": f"data:{mime_type};base64,{base64_encoded_data}"},
                                                        },
                                                    ]
                                                }
                                            ],
                                            "temperature": temperature,
                                            "max_tokens": max_tokens
                                        })
                    resp.raise_for_status()
                resp = resp.json()
                content = find_value_recursive("content", resp)
                if content is None:
                    raise ValueError("no 'content' field in the gateway response")
                # Collapse runs of blank lines into single newlines
                content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # logger.exception("LLM call failed: {}".format(e))
                raise Exception("LLM call failed: {}".format(e))
            return (content,)
        return _chat()


class LLMChatMultiModalImageTensor:
    """Multimodal chat with an IMAGE tensor taken from the graph."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120",
                                  "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": ("IMAGE",),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/llm"

    def chat(self, llm_provider: str, prompt: str, image: torch.Tensor, temperature: float, max_tokens: int, timeout: int):
        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post("https://gateway.bowong.cc/chat/completions",
                                        headers={
                                            "Content-Type": "application/json",
                                            "Accept": "application/json",
                                            "Authorization": "Bearer auth-bowong7777"
                                        },
                                        json={
                                            "model": llm_provider,
                                            "messages": [
                                                {
                                                    "role": "user",
                                                    "content": [
                                                        {"type": "text", "text": prompt},
                                                        {
                                                            "type": "image_url",
                                                            "image_url": {"url": image_tensor_to_base64(image)},
                                                        },
                                                    ]
                                                }
                                            ],
                                            "temperature": temperature,
                                            "max_tokens": max_tokens
                                        })
                    resp.raise_for_status()
                resp = resp.json()
                content = find_value_recursive("content", resp)
                if content is None:
                    raise ValueError("no 'content' field in the gateway response")
                # Collapse runs of blank lines into single newlines
                content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # logger.exception("LLM call failed: {}".format(e))
                raise Exception("LLM call failed: {}".format(e))
            return (content,)
        return _chat()