ComfyUI-CustomNode/nodes/llm_nodes.py

412 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# LLM API: call LLMs through the Cloudflare gateway
import base64
import io
import json
import os
import re
from mimetypes import guess_type
from time import sleep
from typing import Any, Union
import folder_paths
import httpx
import numpy as np
import requests
import torch
from PIL import Image
from jinja2 import Template, StrictUndefined
from loguru import logger
from retry import retry
from ..utils.http_utils import send_request
from ..utils.image_utils import tensor_to_image_bytes, base64_to_tensor
def find_value_recursive(key: str, data: Any) -> Any:
    """Depth-first search for the first value stored under *key*.

    Walks arbitrarily nested dicts and lists. Pre-order: a dict's own
    entry for *key* wins over matches inside its child values.

    Args:
        key: Dictionary key to look for.
        data: Nested structure of dicts/lists/scalars (scalars are
            ignored — only dicts and lists are descended into).

    Returns:
        The first value found for *key*, or None when absent.
        NOTE: a stored value of None is indistinguishable from "not found",
        which also stops the search short if an intermediate match is None.
    """
    if isinstance(data, dict):
        if key in data:
            return data[key]
        # Recurse into the values of all other keys.
        for value in data.values():
            result = find_value_recursive(key, value)
            if result is not None:
                return result
    elif isinstance(data, list):
        for item in data:
            result = find_value_recursive(key, item)
            if result is not None:
                return result
    return None  # explicit: key not present anywhere (was an implicit fall-through)
def image_tensor_to_base64(image):
    """Encode an image tensor as a PNG data-URL string.

    The tensor is assumed to hold float pixel values in [0, 1]
    (ComfyUI IMAGE convention) — TODO confirm with callers.

    Args:
        image: Image tensor; any leading singleton batch dim is squeezed.

    Returns:
        String of the form "data:image/png;base64,<payload>".
    """
    # Scale to 8-bit, drop singleton dims, and wrap in a PIL image.
    pixels = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
    pil_image = Image.fromarray(pixels)
    # Serialize to PNG bytes entirely in memory.
    buffer = io.BytesIO()
    pil_image.save(buffer, format='PNG')
    # Base64-encode the PNG payload and prepend the data-URL header.
    payload = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return "data:image/png;base64," + payload
class LLMChat:
    """ComfyUI node: single-turn text chat against an LLM served behind
    the Cloudflare AI gateway (gateway.bowong.cc)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "llm_provider": (["claude-3-5-sonnet-20241022-v2",
                                  "claude-3-5-sonnet-20241022-v3",
                                  "claude-3-7-sonnet-20250219-v1",
                                  "claude-4-sonnet-20250514-v1",
                                  "gpt-4o-1120",
                                  "gpt-4.1",
                                  "deepseek-v3",
                                  "deepseek-r1"],),
                "prompt": ("STRING", {"multiline": True}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def chat(self, llm_provider: str, prompt: str, temperature: float, max_tokens: int, timeout: int):
        """Send *prompt* as a single user message and return the reply.

        Args:
            llm_provider: Model identifier forwarded as the "model" field.
            prompt: User message text.
            temperature: Sampling temperature (0.0–1.0).
            max_tokens: Cap on response tokens.
            timeout: Overall request timeout in seconds (connect is 15s).

        Returns:
            One-element tuple with the reply text; runs of blank lines
            are collapsed to single newlines.

        Raises:
            Exception: After 3 attempts fail (network error, non-2xx
                status, or a response without a "content" field).
        """
        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(
                        "https://gateway.bowong.cc/chat/completions",
                        headers={
                            "Content-Type": "application/json",
                            "Accept": "application/json",
                            "Authorization": "Bearer auth-bowong7777"
                        },
                        json={
                            "model": llm_provider,
                            "messages": [
                                {
                                    "role": "user",
                                    "content": prompt
                                }
                            ],
                            "temperature": temperature,
                            "max_tokens": max_tokens
                        })
                    resp.raise_for_status()
                    payload = resp.json()
                    content = find_value_recursive("content", payload)
                    if content is None:
                        # Fix: the original passed None straight into re.sub,
                        # raising an opaque TypeError when the gateway returned
                        # a 2xx body without a "content" field.
                        raise ValueError("no 'content' field in response: {}".format(payload))
                    # Normalize: collapse consecutive blank lines.
                    content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()
class LLMChatMultiModalImageUpload:
    """ComfyUI node: multimodal chat — sends a text prompt plus an image
    file selected from the ComfyUI input directory to a vision-capable
    model behind the Cloudflare AI gateway."""

    @classmethod
    def INPUT_TYPES(s):
        # Offer every image file currently present in the input directory.
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        files = folder_paths.filter_files_content_types(files, ["image"])
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120",
                                  "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": (sorted(files), {"image_upload": True}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def chat(self, llm_provider: str, prompt: str, image, temperature: float, max_tokens: int, timeout: int):
        """Send *prompt* plus the selected image and return the reply.

        Args:
            llm_provider: Vision-capable model identifier.
            prompt: User message text.
            image: Annotated filename of an image in the input directory.
            temperature: Sampling temperature (0.0–1.0).
            max_tokens: Cap on response tokens.
            timeout: Overall request timeout in seconds (connect is 15s).

        Returns:
            One-element tuple with the reply text; runs of blank lines
            are collapsed to single newlines.

        Raises:
            Exception: After 3 attempts fail (file/network error, non-2xx
                status, or a response without a "content" field).
        """
        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                image_path = folder_paths.get_annotated_filepath(image)
                # MIME type (from the file extension) goes into the data URL.
                mime_type, _ = guess_type(image_path)
                with open(image_path, "rb") as image_file:
                    base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(
                        "https://gateway.bowong.cc/chat/completions",
                        headers={
                            "Content-Type": "application/json",
                            "Accept": "application/json",
                            "Authorization": "Bearer auth-bowong7777"
                        },
                        json={
                            "model": llm_provider,
                            "messages": [
                                {
                                    "role": "user",
                                    "content": [
                                        {"type": "text", "text": prompt},
                                        {
                                            "type": "image_url",
                                            "image_url": {
                                                "url": f"data:{mime_type};base64,{base64_encoded_data}"},
                                        },
                                    ]
                                }
                            ],
                            "temperature": temperature,
                            "max_tokens": max_tokens
                        })
                    resp.raise_for_status()
                    payload = resp.json()
                    content = find_value_recursive("content", payload)
                    if content is None:
                        # Fix: the original passed None straight into re.sub,
                        # raising an opaque TypeError when the gateway returned
                        # a 2xx body without a "content" field.
                        raise ValueError("no 'content' field in response: {}".format(payload))
                    # Normalize: collapse consecutive blank lines.
                    content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()
class LLMChatMultiModalImageTensor:
    """ComfyUI node: multimodal chat — sends a text prompt plus an
    in-memory IMAGE tensor (encoded as a PNG data URL) to a
    vision-capable model behind the Cloudflare AI gateway."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120",
                                  "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": ("IMAGE",),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def chat(self, llm_provider: str, prompt: str, image: torch.Tensor, temperature: float, max_tokens: int,
             timeout: int):
        """Send *prompt* plus the image tensor and return the reply.

        Args:
            llm_provider: Vision-capable model identifier.
            prompt: User message text.
            image: Image tensor, converted via image_tensor_to_base64.
            temperature: Sampling temperature (0.0–1.0).
            max_tokens: Cap on response tokens.
            timeout: Overall request timeout in seconds (connect is 15s).

        Returns:
            One-element tuple with the reply text; runs of blank lines
            are collapsed to single newlines.

        Raises:
            Exception: After 3 attempts fail (network error, non-2xx
                status, or a response without a "content" field).
        """
        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(
                        "https://gateway.bowong.cc/chat/completions",
                        headers={
                            "Content-Type": "application/json",
                            "Accept": "application/json",
                            "Authorization": "Bearer auth-bowong7777"
                        },
                        json={
                            "model": llm_provider,
                            "messages": [
                                {
                                    "role": "user",
                                    "content": [
                                        {"type": "text", "text": prompt},
                                        {
                                            "type": "image_url",
                                            "image_url": {"url": image_tensor_to_base64(image)},
                                        },
                                    ]
                                }
                            ],
                            "temperature": temperature,
                            "max_tokens": max_tokens
                        })
                    resp.raise_for_status()
                    payload = resp.json()
                    content = find_value_recursive("content", payload)
                    if content is None:
                        # Fix: the original passed None straight into re.sub,
                        # raising an opaque TypeError when the gateway returned
                        # a 2xx body without a "content" field.
                        raise ValueError("no 'content' field in response: {}".format(payload))
                    # Normalize: collapse consecutive blank lines.
                    content = re.sub(r'\n{2,}', '\n', content)
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()
class Jinja2RenderTemplate:
    """ComfyUI node: render a Jinja2 template string using a
    JSON-encoded map of variables."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "template": ("STRING", {"multiline": True}),
                "kv_map": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prompt",)
    FUNCTION = "render_prompt"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def render_prompt(self, template: str, kv_map: str) -> tuple:
        """Render *template* with variables parsed from *kv_map*.

        Args:
            template: Template text containing Jinja2 markup.
            kv_map: JSON object string supplying the template variables.

        Returns:
            One-element tuple with the rendered string.

        Raises:
            jinja2.exceptions.UndefinedError: If the template references a
                variable missing from *kv_map* (StrictUndefined mode).
            json.JSONDecodeError: If *kv_map* is not valid JSON.
        """
        variables = json.loads(kv_map)
        # StrictUndefined makes missing variables fail loudly instead of
        # silently rendering as empty strings.
        jinja_template = Template(template, undefined=StrictUndefined)
        return (jinja_template.render(variables),)
class ModalClothesMask:
    """ComfyUI node: ask the Modal-hosted Gemini service to paint a solid
    color mask over a given clothes type, polling until the edited image
    is ready."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask_color": ("STRING", {"default": "绿色"}),
                "clothes_type": ("STRING", {"default": "裤子"}),
                "endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini图像编辑"

    def process(self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str):
        """Submit the mask job, poll for completion, return the result.

        Args:
            image: Input image tensor (ComfyUI IMAGE).
            mask_color: Color name for the mask (interpreted service-side).
            clothes_type: Clothes category to mask (interpreted service-side).
            endpoint: Modal service hostname, without scheme.

        Returns:
            One-element tuple with the edited image tensor.

        Raises:
            Exception: Submission rejected, task failed, or polling timed
                out. Network/HTTP errors now propagate unchanged — the
                original wrapped every error in Exception(e), destroying
                the exception type and traceback.
        """
        timeout = 60
        logger.info("获取token")
        api_key = send_request("get", f"https://{endpoint}/google/access-token",
                               headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[
            "access_token"]
        img_format = "PNG"  # renamed from `format` to stop shadowing the builtin
        logger.info("请求图像编辑")
        job_resp = send_request("post", f"https://{endpoint}/google/image/clothes_mark",
                                headers={'x-google-api-key': api_key},
                                data={
                                    "mark_clothes_type": clothes_type,
                                    "mark_color": mask_color,
                                },
                                files={"origin_image": (
                                    'image.' + img_format.lower(), tensor_to_image_bytes(image, img_format),
                                    f'image/{img_format.lower()}')},
                                timeout=timeout)
        job_resp.raise_for_status()
        job_resp = job_resp.json()
        if not job_resp["success"]:
            raise Exception("请求Modal API失败")
        job_id = job_resp["taskId"]
        wait_time = 240  # total polling budget, seconds
        interval = 3     # seconds between status checks
        logger.info("开始轮询任务状态")
        for _ in range(0, wait_time, interval):
            logger.info("查询任务状态")
            result = send_request("get", f"https://{endpoint}/google/{job_id}", timeout=timeout)
            if result.status_code == 200:
                result = result.json()
                if result["status"] == "success":
                    logger.success("任务成功")
                    # result["result"] is itself a JSON string holding a
                    # list of images; take the first one.
                    image_b64 = json.loads(result["result"])[0]["image_b64"]
                    image_tensor = base64_to_tensor(image_b64)
                    return (image_tensor,)
                elif "fail" in result["status"].lower():
                    raise Exception("任务失败")
            sleep(interval)
        raise Exception("查询任务状态超时")
class ModalEditCustom:
    """ComfyUI node: send an image plus a free-form edit prompt to the
    Modal-hosted Gemini service, polling until the edited image is ready."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "将背景去除,输出原尺寸图片","multiline": True}),
                "endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini图像编辑"

    def process(self, image: torch.Tensor, prompt: str, endpoint: str):
        """Submit the custom-edit job, poll for completion, return result.

        Args:
            image: Input image tensor (ComfyUI IMAGE).
            prompt: Free-form edit instruction (interpreted service-side).
            endpoint: Modal service hostname, without scheme.

        Returns:
            One-element tuple with the edited image tensor.

        Raises:
            Exception: Submission rejected, task failed, or polling timed
                out. Network/HTTP errors now propagate unchanged — the
                original wrapped every error in Exception(e), destroying
                the exception type and traceback.
        """
        timeout = 60
        logger.info("获取token")
        api_key = send_request("get", f"https://{endpoint}/google/access-token",
                               headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[
            "access_token"]
        img_format = "PNG"  # renamed from `format` to stop shadowing the builtin
        logger.info("请求图像编辑")
        job_resp = send_request("post", f"https://{endpoint}/google/image/edit_custom",
                                headers={'x-google-api-key': api_key},
                                data={
                                    "prompt": prompt
                                },
                                files={"origin_image": (
                                    'image.' + img_format.lower(), tensor_to_image_bytes(image, img_format),
                                    f'image/{img_format.lower()}')},
                                timeout=timeout)
        job_resp.raise_for_status()
        job_resp = job_resp.json()
        if not job_resp["success"]:
            raise Exception("请求Modal API失败")
        job_id = job_resp["taskId"]
        wait_time = 240  # total polling budget, seconds
        interval = 3     # seconds between status checks
        logger.info("开始轮询任务状态")
        for _ in range(0, wait_time, interval):
            logger.info("查询任务状态")
            result = send_request("get", f"https://{endpoint}/google/{job_id}", timeout=timeout)
            if result.status_code == 200:
                result = result.json()
                if result["status"] == "success":
                    logger.success("任务成功")
                    # result["result"] is itself a JSON string holding a
                    # list of images; take the first one.
                    image_b64 = json.loads(result["result"])[0]["image_b64"]
                    image_tensor = base64_to_tensor(image_b64)
                    return (image_tensor,)
                elif "fail" in result["status"].lower():
                    raise Exception("任务失败")
            sleep(interval)
        raise Exception("查询任务状态超时")