ADD 添加jinja2模板渲染节点
This commit is contained in:
parent
2753a6b8a0
commit
fe30757baf
|
|
@ -1,4 +1,4 @@
|
|||
from .nodes.llm_api import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor
|
||||
from .nodes.llm_api import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
|
||||
from .nodes.compute_video_point import VideoStartPointDurationCompute
|
||||
from .nodes.cos import COSUpload, COSDownload
|
||||
from .nodes.face_detect import FaceDetect
|
||||
|
|
@ -67,7 +67,8 @@ NODE_CLASS_MAPPINGS = {
|
|||
"SaveImageWithOutput": SaveImageWithOutput,
|
||||
"LLMChat": LLMChat,
|
||||
"LLMChatMultiModalImageUpload": LLMChatMultiModalImageUpload,
|
||||
"LLMChatMultiModalImageTensor": LLMChatMultiModalImageTensor
|
||||
"LLMChatMultiModalImageTensor": LLMChatMultiModalImageTensor,
|
||||
"Jinja2RenderTemplate": Jinja2RenderTemplate
|
||||
}
|
||||
|
||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||
|
|
@ -108,5 +109,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||
"SaveImageWithOutput": "保存图片(带输出)",
|
||||
"LLMChat": "LLM调用",
|
||||
"LLMChatMultiModalImageUpload": "多模态LLM调用-图片Path",
|
||||
"LLMChatMultiModalImageTensor": "多模态LLM调用-图片Tensor"
|
||||
"LLMChatMultiModalImageTensor": "多模态LLM调用-图片Tensor",
|
||||
"Jinja2RenderTemplate": "Jinja2格式Prompt模板渲染"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# LLM API 通过cloudflare gateway调用llm
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from mimetypes import guess_type
|
||||
|
|
@ -10,6 +11,7 @@ import httpx
|
|||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from jinja2 import Template, StrictUndefined
|
||||
from retry import retry
|
||||
|
||||
import folder_paths
|
||||
|
|
@ -230,4 +232,40 @@ class LLMChatMultiModalImageTensor:
|
|||
# logger.exception("llm调用失败 {}".format(e))
|
||||
raise Exception("llm调用失败 {}".format(e))
|
||||
return (content,)
|
||||
return _chat()
|
||||
return _chat()
|
||||
|
||||
class Jinja2RenderTemplate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"template": ("STRING", {"multiline": True}),
|
||||
"kv_map": ("STRING", {"multiline": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("prompt",)
|
||||
FUNCTION = "render_prompt"
|
||||
CATEGORY = "不忘科技-自定义节点🚩/llm"
|
||||
|
||||
def render_prompt(self, template: str, kv_map: str) -> tuple:
|
||||
"""
|
||||
使用Jinja2渲染prompt模板
|
||||
|
||||
参数:
|
||||
template: 包含Jinja2标记的模板字符串
|
||||
kv_map: 键值映射字典,用于提供模板渲染所需的变量
|
||||
|
||||
返回:
|
||||
渲染后的字符串
|
||||
|
||||
异常:
|
||||
如果模板中有未定义的变量,抛出jinja2.exceptions.UndefinedError
|
||||
"""
|
||||
kv_map = json.loads(kv_map)
|
||||
# 创建模板对象,设置为严格模式,未定义变量会抛出异常
|
||||
template = Template(template, undefined=StrictUndefined)
|
||||
|
||||
# 渲染模板
|
||||
return (template.render(kv_map),)
|
||||
Loading…
Reference in New Issue