ADD 添加jinja2模板渲染节点

This commit is contained in:
kyj@bowong.ai 2025-07-08 18:28:05 +08:00
parent 2753a6b8a0
commit fe30757baf
2 changed files with 44 additions and 4 deletions

View File

@ -1,4 +1,4 @@
from .nodes.llm_api import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor
from .nodes.llm_api import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
from .nodes.compute_video_point import VideoStartPointDurationCompute
from .nodes.cos import COSUpload, COSDownload
from .nodes.face_detect import FaceDetect
@ -67,7 +67,8 @@ NODE_CLASS_MAPPINGS = {
"SaveImageWithOutput": SaveImageWithOutput,
"LLMChat": LLMChat,
"LLMChatMultiModalImageUpload": LLMChatMultiModalImageUpload,
"LLMChatMultiModalImageTensor": LLMChatMultiModalImageTensor
"LLMChatMultiModalImageTensor": LLMChatMultiModalImageTensor,
"Jinja2RenderTemplate": Jinja2RenderTemplate
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
@ -108,5 +109,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImageWithOutput": "保存图片(带输出)",
"LLMChat": "LLM调用",
"LLMChatMultiModalImageUpload": "多模态LLM调用-图片Path",
"LLMChatMultiModalImageTensor": "多模态LLM调用-图片Tensor"
"LLMChatMultiModalImageTensor": "多模态LLM调用-图片Tensor",
"Jinja2RenderTemplate": "Jinja2格式Prompt模板渲染"
}

View File

@ -1,6 +1,7 @@
# LLM API 通过cloudflare gateway调用llm
import base64
import io
import json
import os
import re
from mimetypes import guess_type
@ -10,6 +11,7 @@ import httpx
import numpy as np
import torch
from PIL import Image
from jinja2 import Template, StrictUndefined
from retry import retry
import folder_paths
@ -230,4 +232,40 @@ class LLMChatMultiModalImageTensor:
# logger.exception("llm调用失败 {}".format(e))
raise Exception("llm调用失败 {}".format(e))
return (content,)
return _chat()
return _chat()
class Jinja2RenderTemplate:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"template": ("STRING", {"multiline": True}),
"kv_map": ("STRING", {"multiline": True}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("prompt",)
FUNCTION = "render_prompt"
CATEGORY = "不忘科技-自定义节点🚩/llm"
def render_prompt(self, template: str, kv_map: str) -> tuple:
"""
使用Jinja2渲染prompt模板
参数:
template: 包含Jinja2标记的模板字符串
kv_map: 键值映射字典用于提供模板渲染所需的变量
返回:
渲染后的字符串
异常:
如果模板中有未定义的变量抛出jinja2.exceptions.UndefinedError
"""
kv_map = json.loads(kv_map)
# 创建模板对象,设置为严格模式,未定义变量会抛出异常
template = Template(template, undefined=StrictUndefined)
# 渲染模板
return (template.render(kv_map),)