diff --git a/__init__.py b/__init__.py index f09aa23..d3f1b89 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from .nodes.llm_api import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor +from .nodes.llm_api import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate from .nodes.compute_video_point import VideoStartPointDurationCompute from .nodes.cos import COSUpload, COSDownload from .nodes.face_detect import FaceDetect @@ -67,7 +67,8 @@ NODE_CLASS_MAPPINGS = { "SaveImageWithOutput": SaveImageWithOutput, "LLMChat": LLMChat, "LLMChatMultiModalImageUpload": LLMChatMultiModalImageUpload, - "LLMChatMultiModalImageTensor": LLMChatMultiModalImageTensor + "LLMChatMultiModalImageTensor": LLMChatMultiModalImageTensor, + "Jinja2RenderTemplate": Jinja2RenderTemplate } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -108,5 +109,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveImageWithOutput": "保存图片(带输出)", "LLMChat": "LLM调用", "LLMChatMultiModalImageUpload": "多模态LLM调用-图片Path", - "LLMChatMultiModalImageTensor": "多模态LLM调用-图片Tensor" + "LLMChatMultiModalImageTensor": "多模态LLM调用-图片Tensor", + "Jinja2RenderTemplate": "Jinja2格式Prompt模板渲染" } diff --git a/nodes/llm_api.py b/nodes/llm_api.py index fb575d2..10ce680 100644 --- a/nodes/llm_api.py +++ b/nodes/llm_api.py @@ -1,6 +1,7 @@ # LLM API 通过cloudflare gateway调用llm import base64 import io +import json import os import re from mimetypes import guess_type @@ -10,6 +11,7 @@ import httpx import numpy as np import torch from PIL import Image +from jinja2 import Template, StrictUndefined from retry import retry import folder_paths @@ -230,4 +232,40 @@ class LLMChatMultiModalImageTensor: # logger.exception("llm调用失败 {}".format(e)) raise Exception("llm调用失败 {}".format(e)) return (content,) - return _chat() \ No newline at end of file + return _chat() + +class Jinja2RenderTemplate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "template": ("STRING", {"multiline": True}), + "kv_map": ("STRING", {"multiline": True}), + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("prompt",) + FUNCTION = "render_prompt" + CATEGORY = "不忘科技-自定义节点🚩/llm" + + def render_prompt(self, template: str, kv_map: str) -> tuple: + """ + 使用Jinja2渲染prompt模板 + + 参数: + template: 包含Jinja2标记的模板字符串 + kv_map: 键值映射字典,用于提供模板渲染所需的变量 + + 返回: + 渲染后的字符串 + + 异常: + 如果模板中有未定义的变量,抛出jinja2.exceptions.UndefinedError + """ + kv_map = json.loads(kv_map) + # 创建模板对象,设置为严格模式,未定义变量会抛出异常 + template = Template(template, undefined=StrictUndefined) + + # 渲染模板 + return (template.render(kv_map),) \ No newline at end of file