412 lines
17 KiB
Python
412 lines
17 KiB
Python
# LLM API 通过cloudflare gateway调用llm
|
||
import base64
|
||
import io
|
||
import json
|
||
import os
|
||
import re
|
||
from mimetypes import guess_type
|
||
from time import sleep
|
||
from typing import Any, Union
|
||
|
||
import folder_paths
|
||
import httpx
|
||
import numpy as np
|
||
import requests
|
||
import torch
|
||
from PIL import Image
|
||
from jinja2 import Template, StrictUndefined
|
||
from loguru import logger
|
||
from retry import retry
|
||
|
||
from ..utils.http_utils import send_request
|
||
from ..utils.image_utils import tensor_to_image_bytes, base64_to_tensor
|
||
|
||
|
||
def find_value_recursive(key: str, data: Union[dict, list]) -> str | None | Any:
|
||
if isinstance(data, dict):
|
||
if key in data:
|
||
return data[key]
|
||
# 递归检查所有其他键的值
|
||
for value in data.values():
|
||
result = find_value_recursive(key, value)
|
||
if result is not None:
|
||
return result
|
||
elif isinstance(data, list):
|
||
for item in data:
|
||
result = find_value_recursive(key, item)
|
||
if result is not None:
|
||
return result
|
||
|
||
|
||
def image_tensor_to_base64(image):
    """Encode an image tensor as a PNG ``data:`` URL.

    The tensor is moved to CPU, squeezed, scaled by 255, clipped to [0, 255]
    and cast to uint8 before being handed to PIL and saved as PNG in memory.

    NOTE(review): assumes pixel values in [0, 1] and that, after
    ``squeeze()``, the array is 2-D or (H, W, C) so ``Image.fromarray``
    accepts it (e.g. a single-image batch). Multi-image batches are not
    handled here; confirm against callers.

    Args:
        image: torch tensor holding a single image.

    Returns:
        ``"data:image/png;base64,<payload>"``.
    """
    pixels = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
    buffer = io.BytesIO()
    Image.fromarray(pixels).save(buffer, format='PNG')
    payload = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return "data:image/png;base64," + payload
|
||
|
||
|
||
class LLMChat:
    """ComfyUI node: send a single text prompt to the LLM gateway and return the reply."""

    # Chat-completions endpoint behind the gateway.
    _GATEWAY_URL = "https://gateway.bowong.cc/chat/completions"
    # NOTE(review): the bearer token is hard-coded in source; consider loading it
    # from configuration or an environment variable instead of committing it.
    _HEADERS = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": "Bearer auth-bowong7777",
    }

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "llm_provider": (["claude-3-5-sonnet-20241022-v2",
                                  "claude-3-5-sonnet-20241022-v3",
                                  "claude-3-7-sonnet-20250219-v1",
                                  "claude-4-sonnet-20250514-v1",
                                  "gpt-4o-1120",
                                  "gpt-4.1",
                                  "deepseek-v3",
                                  "deepseek-r1"],),
                "prompt": ("STRING", {"multiline": True}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    @staticmethod
    def _build_payload(llm_provider: str, prompt: str, temperature: float, max_tokens: int) -> dict:
        """Build the chat-completions request body for a single user message."""
        return {
            "model": llm_provider,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

    @staticmethod
    def _normalize_content(content) -> str:
        """Validate the extracted reply and collapse runs of newlines into a single one.

        Raises:
            ValueError: if ``content`` is not a string (e.g. the response had no
                ``content`` field). Without this check ``re.sub`` would raise an
                opaque TypeError.
        """
        if not isinstance(content, str):
            raise ValueError("response contains no text content: {!r}".format(content))
        return re.sub(r'\n{2,}', '\n', content)

    def chat(self, llm_provider: str, prompt: str, temperature: float, max_tokens: int, timeout: int):
        """Send ``prompt`` to the gateway and return the reply text.

        The whole request is attempted up to 3 times, 1 second apart.

        Args:
            llm_provider: Model name forwarded as ``model``.
            prompt: User message text.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            timeout: Overall HTTP timeout in seconds (connect timeout is 15 s).

        Returns:
            ``(reply_text,)`` with consecutive newlines collapsed.

        Raises:
            Exception: "llm调用失败 ..." wrapping (and chained to) the last failure.
        """
        payload = self._build_payload(llm_provider, prompt, temperature, max_tokens)

        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(self._GATEWAY_URL, headers=self._HEADERS, json=payload)
                    resp.raise_for_status()
                    content = self._normalize_content(find_value_recursive("content", resp.json()))
            except Exception as e:
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()
|
||
|
||
|
||
class LLMChatMultiModalImageUpload:
    """ComfyUI node: send a prompt plus an uploaded input-folder image to the LLM gateway."""

    # Chat-completions endpoint behind the gateway.
    _GATEWAY_URL = "https://gateway.bowong.cc/chat/completions"
    # NOTE(review): the bearer token is hard-coded in source; consider loading it
    # from configuration or an environment variable instead of committing it.
    _HEADERS = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": "Bearer auth-bowong7777",
    }

    @classmethod
    def INPUT_TYPES(s):
        # Offer every image file currently in ComfyUI's input directory.
        input_dir = folder_paths.get_input_directory()
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        files = folder_paths.filter_files_content_types(files, ["image"])
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120",
                                  "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": (sorted(files), {"image_upload": True}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    @staticmethod
    def _image_data_url(image_path: str) -> str:
        """Read an image file and return it as a base64 ``data:`` URL.

        The MIME type is guessed from the file name. When it cannot be guessed,
        ``image/png`` is used rather than producing an invalid
        ``data:None;base64,...`` URL.
        """
        mime_type = guess_type(image_path)[0] or "image/png"
        with open(image_path, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode('utf-8')
        return f"data:{mime_type};base64,{encoded}"

    @staticmethod
    def _build_payload(llm_provider: str, prompt: str, image_url: str, temperature: float,
                       max_tokens: int) -> dict:
        """Build a chat-completions body with one user message made of text + one image."""
        return {
            "model": llm_provider,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                }
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

    @staticmethod
    def _normalize_content(content) -> str:
        """Validate the extracted reply and collapse runs of newlines into a single one.

        Raises:
            ValueError: if ``content`` is not a string (e.g. the response had no
                ``content`` field).
        """
        if not isinstance(content, str):
            raise ValueError("response contains no text content: {!r}".format(content))
        return re.sub(r'\n{2,}', '\n', content)

    def chat(self, llm_provider: str, prompt: str, image, temperature: float, max_tokens: int, timeout: int):
        """Send ``prompt`` and the selected input image to the gateway and return the reply.

        The whole request, including reading the image file, is attempted up to
        3 times, 1 second apart.

        Args:
            llm_provider: Model name forwarded as ``model``.
            prompt: User message text.
            image: Input-folder file name (possibly annotated), resolved via
                ``folder_paths.get_annotated_filepath``.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            timeout: Overall HTTP timeout in seconds (connect timeout is 15 s).

        Returns:
            ``(reply_text,)`` with consecutive newlines collapsed.

        Raises:
            Exception: "llm调用失败 ..." wrapping (and chained to) the last failure.
        """

        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                image_path = folder_paths.get_annotated_filepath(image)
                payload = self._build_payload(llm_provider, prompt, self._image_data_url(image_path),
                                              temperature, max_tokens)
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(self._GATEWAY_URL, headers=self._HEADERS, json=payload)
                    resp.raise_for_status()
                    content = self._normalize_content(find_value_recursive("content", resp.json()))
            except Exception as e:
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()
|
||
|
||
|
||
class LLMChatMultiModalImageTensor:
    """ComfyUI node: send a prompt plus an IMAGE tensor (one or more frames) to the LLM gateway."""

    # Chat-completions endpoint behind the gateway.
    _GATEWAY_URL = "https://gateway.bowong.cc/chat/completions"
    # NOTE(review): the bearer token is hard-coded in source; consider loading it
    # from configuration or an environment variable instead of committing it.
    _HEADERS = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": "Bearer auth-bowong7777",
    }

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "llm_provider": (["gpt-4o-1120",
                                  "gpt-4.1"],),
                "prompt": ("STRING", {"multiline": True}),
                "image": ("IMAGE",),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),
                "max_tokens": ("INT", {"default": 4096, "min": 1, "max": 65535}),
                "timeout": ("INT", {"default": 120, "min": 30, "max": 900}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("llm输出",)
    FUNCTION = "chat"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    @staticmethod
    def _split_frames(image: torch.Tensor) -> list:
        """Split a 4-D image batch into per-image tensors.

        NOTE(review): assumes 4-D tensors carry a leading batch axis, i.e.
        (B, H, W, C) as ComfyUI IMAGE batches usually do. Confirm this.
        A 4-D tensor yields one entry per batch item; any other rank is
        returned as a single entry, unchanged.
        """
        if image.dim() == 4:
            return list(image.unbind(0))
        return [image]

    @staticmethod
    def _build_payload(llm_provider: str, prompt: str, image_urls: list, temperature: float,
                       max_tokens: int) -> dict:
        """Build a chat-completions body: one user message with text followed by each image."""
        content = [{"type": "text", "text": prompt}]
        content.extend({"type": "image_url", "image_url": {"url": url}} for url in image_urls)
        return {
            "model": llm_provider,
            "messages": [{"role": "user", "content": content}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

    @staticmethod
    def _normalize_content(content) -> str:
        """Validate the extracted reply and collapse runs of newlines into a single one.

        Raises:
            ValueError: if ``content`` is not a string (e.g. the response had no
                ``content`` field).
        """
        if not isinstance(content, str):
            raise ValueError("response contains no text content: {!r}".format(content))
        return re.sub(r'\n{2,}', '\n', content)

    def chat(self, llm_provider: str, prompt: str, image: torch.Tensor, temperature: float, max_tokens: int,
             timeout: int):
        """Send ``prompt`` and the image(s) in ``image`` to the gateway and return the reply.

        A 4-D batch is sent as one ``image_url`` part per batch item. Previously
        any batch with more than one image failed while being converted to PNG.
        The whole request is attempted up to 3 times, 1 second apart.

        Args:
            llm_provider: Model name forwarded as ``model``.
            prompt: User message text.
            image: Image tensor, or a batch of images (see ``_split_frames``).
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            timeout: Overall HTTP timeout in seconds (connect timeout is 15 s).

        Returns:
            ``(reply_text,)`` with consecutive newlines collapsed.

        Raises:
            Exception: "llm调用失败 ..." wrapping (and chained to) the last failure.
        """

        @retry(Exception, tries=3, delay=1)
        def _chat():
            try:
                image_urls = [image_tensor_to_base64(frame) for frame in self._split_frames(image)]
                payload = self._build_payload(llm_provider, prompt, image_urls, temperature, max_tokens)
                with httpx.Client(timeout=httpx.Timeout(timeout, connect=15)) as session:
                    resp = session.post(self._GATEWAY_URL, headers=self._HEADERS, json=payload)
                    resp.raise_for_status()
                    content = self._normalize_content(find_value_recursive("content", resp.json()))
            except Exception as e:
                raise Exception("llm调用失败 {}".format(e)) from e
            return (content,)

        return _chat()
|
||
|
||
|
||
class Jinja2RenderTemplate:
    """ComfyUI node: render a Jinja2 prompt template with variables supplied as JSON."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "template": ("STRING", {"multiline": True}),
                "kv_map": ("STRING", {"multiline": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prompt",)
    FUNCTION = "render_prompt"
    CATEGORY = "不忘科技-自定义节点🚩/LLM"

    def render_prompt(self, template: str, kv_map: str) -> tuple:
        """Render a Jinja2 prompt template.

        Args:
            template: Template string containing Jinja2 markup.
            kv_map: JSON text whose decoded value supplies the template variables.
                An empty or whitespace-only string is treated as ``{}``, so a
                template without variables renders without the user typing "{}".

        Returns:
            ``(rendered_text,)``.

        Raises:
            json.JSONDecodeError: if ``kv_map`` is non-empty and not valid JSON.
            jinja2.exceptions.UndefinedError: if the template uses a variable that
                ``kv_map`` does not define (strict mode).
        """
        variables = json.loads(kv_map) if kv_map.strip() else {}
        # Strict mode: referencing an undefined variable raises instead of rendering "".
        compiled = Template(template, undefined=StrictUndefined)
        return (compiled.render(variables),)
|
||
|
||
|
||
class ModalClothesMask:
    """ComfyUI node: submit a clothes-marking job to the Modal/Gemini service and return the edited image."""

    # Per-request HTTP timeout (s), total polling budget (s) and delay between polls (s).
    _REQUEST_TIMEOUT = 60
    _WAIT_TIME = 240
    _POLL_INTERVAL = 3

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask_color": ("STRING", {"default": "绿色"}),
                "clothes_type": ("STRING", {"default": "裤子"}),
                "endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini图像编辑"

    def _poll_task(self, endpoint: str, job_id) -> tuple:
        """Poll ``/google/{job_id}`` until the task succeeds, fails, or the polling budget runs out.

        Returns:
            ``(image_tensor,)`` decoded from the first result's ``image_b64``.

        Raises:
            Exception: "任务失败" if the reported status contains "fail", or
                "查询任务状态超时" after ``_WAIT_TIME // _POLL_INTERVAL`` polls.
        """
        logger.info("开始轮询任务状态")
        # The iteration count bounds the wait; wall time also includes request latency.
        for _ in range(0, self._WAIT_TIME, self._POLL_INTERVAL):
            logger.info("查询任务状态")
            result = send_request("get", f"https://{endpoint}/google/{job_id}", timeout=self._REQUEST_TIMEOUT)
            if result.status_code == 200:
                status_payload = result.json()
                status = status_payload["status"]
                if status == "success":
                    logger.success("任务成功")
                    # "result" is itself a JSON-encoded list of outputs.
                    image_b64 = json.loads(status_payload["result"])[0]["image_b64"]
                    return (base64_to_tensor(image_b64),)
                if "fail" in status.lower():
                    raise Exception("任务失败")
            else:
                # Non-200 polls are retried, but no longer silently.
                logger.warning("查询任务状态返回 HTTP {}", result.status_code)
            sleep(self._POLL_INTERVAL)
        raise Exception("查询任务状态超时")

    def process(self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str):
        """Mark ``clothes_type`` in ``image`` with ``mask_color`` via the remote service.

        Steps: fetch an access token, upload the image as PNG to
        ``/google/image/clothes_mark``, then poll for the result.

        Args:
            image: Input image tensor, serialized with ``tensor_to_image_bytes``.
            mask_color: Colour name sent as ``mark_color``.
            clothes_type: Garment name sent as ``mark_clothes_type``.
            endpoint: Service host name (no scheme).

        Returns:
            ``(image_tensor,)`` with the service's output image.

        Raises:
            Exception: errors from the HTTP helper or ``raise_for_status`` now
                propagate with their original type. The old code re-wrapped them
                in a bare ``Exception``; they are still ``Exception`` subclasses.
                Also "请求Modal API失败" when the job is rejected, and the polling
                errors from ``_poll_task``.
        """
        logger.info("获取token")
        token_resp = send_request("get", f"https://{endpoint}/google/access-token",
                                  headers={'Authorization': 'Bearer bowong7777'}, timeout=self._REQUEST_TIMEOUT)
        token_resp.raise_for_status()
        api_key = token_resp.json()["access_token"]

        image_format = "PNG"
        logger.info("请求图像编辑")
        job_resp = send_request("post", f"https://{endpoint}/google/image/clothes_mark",
                                headers={'x-google-api-key': api_key},
                                data={
                                    "mark_clothes_type": clothes_type,
                                    "mark_color": mask_color,
                                },
                                files={"origin_image": (
                                    'image.' + image_format.lower(), tensor_to_image_bytes(image, image_format),
                                    f'image/{image_format.lower()}')},
                                timeout=self._REQUEST_TIMEOUT)
        job_resp.raise_for_status()
        job_info = job_resp.json()
        if not job_info["success"]:
            raise Exception("请求Modal API失败")

        return self._poll_task(endpoint, job_info["taskId"])
|
||
|
||
|
||
class ModalEditCustom:
    """ComfyUI node: submit a free-form image-edit prompt to the Modal/Gemini service and return the result."""

    # Per-request HTTP timeout (s), total polling budget (s) and delay between polls (s).
    _REQUEST_TIMEOUT = 60
    _WAIT_TIME = 240
    _POLL_INTERVAL = 3

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "将背景去除,输出原尺寸图片"}),
                "endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini图像编辑"

    def _poll_task(self, endpoint: str, job_id) -> tuple:
        """Poll ``/google/{job_id}`` until the task succeeds, fails, or the polling budget runs out.

        Returns:
            ``(image_tensor,)`` decoded from the first result's ``image_b64``.

        Raises:
            Exception: "任务失败" if the reported status contains "fail", or
                "查询任务状态超时" after ``_WAIT_TIME // _POLL_INTERVAL`` polls.
        """
        logger.info("开始轮询任务状态")
        # The iteration count bounds the wait; wall time also includes request latency.
        for _ in range(0, self._WAIT_TIME, self._POLL_INTERVAL):
            logger.info("查询任务状态")
            result = send_request("get", f"https://{endpoint}/google/{job_id}", timeout=self._REQUEST_TIMEOUT)
            if result.status_code == 200:
                status_payload = result.json()
                status = status_payload["status"]
                if status == "success":
                    logger.success("任务成功")
                    # "result" is itself a JSON-encoded list of outputs.
                    image_b64 = json.loads(status_payload["result"])[0]["image_b64"]
                    return (base64_to_tensor(image_b64),)
                if "fail" in status.lower():
                    raise Exception("任务失败")
            else:
                # Non-200 polls are retried, but no longer silently.
                logger.warning("查询任务状态返回 HTTP {}", result.status_code)
            sleep(self._POLL_INTERVAL)
        raise Exception("查询任务状态超时")

    def process(self, image: torch.Tensor, prompt: str, endpoint: str):
        """Apply the free-form edit ``prompt`` to ``image`` via the remote service.

        Steps: fetch an access token, upload the image as PNG to
        ``/google/image/edit_custom``, then poll for the result.

        Args:
            image: Input image tensor, serialized with ``tensor_to_image_bytes``.
            prompt: Edit instruction sent as ``prompt``.
            endpoint: Service host name (no scheme).

        Returns:
            ``(image_tensor,)`` with the service's output image.

        Raises:
            Exception: errors from the HTTP helper or ``raise_for_status`` now
                propagate with their original type. The old code re-wrapped them
                in a bare ``Exception``; they are still ``Exception`` subclasses.
                Also "请求Modal API失败" when the job is rejected, and the polling
                errors from ``_poll_task``.
        """
        logger.info("获取token")
        token_resp = send_request("get", f"https://{endpoint}/google/access-token",
                                  headers={'Authorization': 'Bearer bowong7777'}, timeout=self._REQUEST_TIMEOUT)
        token_resp.raise_for_status()
        api_key = token_resp.json()["access_token"]

        image_format = "PNG"
        logger.info("请求图像编辑")
        job_resp = send_request("post", f"https://{endpoint}/google/image/edit_custom",
                                headers={'x-google-api-key': api_key},
                                data={
                                    "prompt": prompt
                                },
                                files={"origin_image": (
                                    'image.' + image_format.lower(), tensor_to_image_bytes(image, image_format),
                                    f'image/{image_format.lower()}')},
                                timeout=self._REQUEST_TIMEOUT)
        job_resp.raise_for_status()
        job_info = job_resp.json()
        if not job_info["success"]:
            raise Exception("请求Modal API失败")

        return self._poll_task(endpoint, job_info["taskId"])
|