diff --git a/__init__.py b/__init__.py index 97322b7..3305108 100644 --- a/__init__.py +++ b/__init__.py @@ -4,7 +4,7 @@ from .nodes.image_face_nodes import FaceDetect, FaceExtract from .nodes.image_gesture_nodes import JMGestureCorrect, JMCustom from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate -from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL +from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL, S3UploadIMAGEURL from .nodes.text_nodes import StringEmptyJudgement, LoadText, RandomLineSelector from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \ PlugAndPlayWebhook @@ -18,6 +18,7 @@ NODE_CLASS_MAPPINGS = { "COSDownload": COSDownload, "S3Upload": S3Upload, "S3UploadURL": S3UploadURL, + "S3UploadIMAGEURL": S3UploadIMAGEURL, "S3Download": S3Download, "VideoCutCustom": VideoCut, "VideoCutByFramePoint": VideoCutByFramePoint, @@ -52,10 +53,11 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { "FaceOccDetect": "面部遮挡检测", "FaceExtract": "面部提取", - "COSUpload": "COS上传", + "COSUpload": "COS上传-返回key", "COSDownload": "COS下载", - "S3Upload": "S3上传", + "S3Upload": "S3上传-返回key", "S3UploadURL": "S3上传-返回URL", + "S3UploadIMAGEURL": "S3上传图片-返回URL", "S3Download": "S3下载", "VideoCutCustom": "视频剪裁", "VideoCutByFramePoint": "视频剪裁(精确帧位)", diff --git a/nodes/object_storage_nodes.py b/nodes/object_storage_nodes.py index 93af8fd..2c93235 100644 --- a/nodes/object_storage_nodes.py +++ b/nodes/object_storage_nodes.py @@ -2,9 +2,12 @@ import os import boto3 import loguru +import torch import yaml from qcloud_cos import CosConfig, CosS3Client, CosClientError, CosServiceError +from ..utils.image_utils import tensor_to_tempfile + class COSDownload: """腾讯云COS下载""" @@ -235,7 +238,7 @@ class S3UploadURL: } RETURN_TYPES = ("STRING",) - RETURN_NAMES = ("S3文件Key",) + RETURN_NAMES = ("URL",) FUNCTION = "upload" CATEGORY = "不忘科技-自定义节点🚩/对象存储/S3" @@ -267,3 +270,51 @@ class S3UploadURL: raise Exception(f"S3上传失败! bucket {s3_bucket}; local_path {path}; subfolder {subfolder}") url = f"https://cdn.roasmax.cn/{dest_key}" return (url,) + + +class S3UploadIMAGEURL: + """AWS S3上传 返回URL""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE", {"multiline": True}), + "subfolder": ("STRING", {"default": "test"}), + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("URL",) + + FUNCTION = "upload" + CATEGORY = "不忘科技-自定义节点🚩/对象存储/S3" + + def upload(self, image:torch.Tensor, subfolder): + s3_bucket = "modal-media-cache" + loguru.logger.info(f"S3 UPLOAD image to {s3_bucket}/{subfolder}") + path = tensor_to_tempfile(image).name + try: + if "aws_key_id" in list(os.environ.keys()): + yaml_config = os.environ + else: + with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"), + encoding="utf-8", mode="r+") as f: + yaml_config = yaml.load(f, Loader=yaml.FullLoader) + client = boto3.client("s3", aws_access_key_id=yaml_config["aws_key_id"], + aws_secret_access_key=yaml_config["aws_access_key"]) + dest_key = "/".join( + [ + subfolder, + ( + path.split("/")[-1] + if "/" in path + else path.split("\\")[-1] + ), + ] + ) + client.upload_file(path, s3_bucket, dest_key) + except Exception as e: + raise Exception(f"S3上传失败! bucket {s3_bucket}; local_path {path}; subfolder {subfolder}") + url = f"https://cdn.roasmax.cn/{dest_key}" + return (url,) \ No newline at end of file diff --git a/utils/image_utils.py b/utils/image_utils.py index 3bc29dd..a7be855 100644 --- a/utils/image_utils.py +++ b/utils/image_utils.py @@ -1,5 +1,6 @@ import base64 import io +import tempfile import torch from PIL import Image @@ -71,3 +72,63 @@ def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes: buffer.seek(0) # 重置指针到开始位置 return buffer.getvalue() + + +def tensor_to_tempfile(tensor: torch.Tensor, format: str = 'PNG', + normalize: bool = True, range=None) -> tempfile.NamedTemporaryFile: + """ + 将PyTorch张量转换为图像并保存到临时文件 + + 参数: + tensor: 输入的PyTorch张量,可以是4D(BCHW)、3D(CHW)或2D(HW) + format: 图像格式,如'PNG'、'JPEG'等 + normalize: 是否对张量进行归一化处理 + range: 归一化范围,元组(min, max),默认为张量的最小值和最大值 + + 返回: + 临时文件对象,关闭后会自动删除 + """ + # 处理4D张量 (BCHW),只取第一个样本 + + if tensor.dim() == 4: + if tensor.size(0) > 1: + print(f"警告: 输入张量包含多个样本,仅使用第一个样本 ({tensor.size(0)} -> 1)") + tensor = tensor[0] + + # 确保张量在CPU上 + tensor = tensor.cpu() + tensor = tensor.permute(2,0,1) + + # 归一化处理 + if normalize: + if range is None: + min_val, max_val = tensor.min(), tensor.max() + else: + min_val, max_val = range + + if max_val > min_val: + tensor = (tensor - min_val) / (max_val - min_val) + else: + tensor = torch.zeros_like(tensor) + + # 转换为PIL图像 + if tensor.dim() == 2: # HW格式 (灰度图) + pil_img = transforms.ToPILImage()(tensor.unsqueeze(0)) # 添加通道维度 + elif tensor.dim() == 3: # CHW格式 + pil_img = transforms.ToPILImage()(tensor) + else: + raise ValueError(f"不支持的张量维度: {tensor.dim()}") + + # 创建临时文件 + temp_file = tempfile.NamedTemporaryFile(suffix=f'.{format.lower()}', delete=False) + + try: + # 保存图像到临时文件 + pil_img.save(temp_file, format=format) + except Exception as e: + # 发生错误时删除临时文件 + temp_file.close() + raise e + + temp_file.close() # 关闭文件但不删除 + return temp_file \ No newline at end of file