ADD 增加上传图片返回url接口
This commit is contained in:
parent
5fee351bbd
commit
bbea095da9
|
|
@ -4,7 +4,7 @@ from .nodes.image_face_nodes import FaceDetect, FaceExtract
|
||||||
from .nodes.image_gesture_nodes import JMGestureCorrect, JMCustom
|
from .nodes.image_gesture_nodes import JMGestureCorrect, JMCustom
|
||||||
from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional
|
from .nodes.image_nodes import SaveImagePath, SaveImageWithOutput, LoadImgOptional
|
||||||
from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
|
from .nodes.llm_nodes import LLMChat, LLMChatMultiModalImageUpload, LLMChatMultiModalImageTensor, Jinja2RenderTemplate
|
||||||
from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL
|
from .nodes.object_storage_nodes import COSUpload, COSDownload, S3Download, S3Upload, S3UploadURL, S3UploadIMAGEURL
|
||||||
from .nodes.text_nodes import StringEmptyJudgement, LoadText, RandomLineSelector
|
from .nodes.text_nodes import StringEmptyJudgement, LoadText, RandomLineSelector
|
||||||
from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \
|
from .nodes.util_nodes import LogToDB, TaskIdGenerate, TraverseFolder, UnloadAllModels, VodToLocalNode, \
|
||||||
PlugAndPlayWebhook
|
PlugAndPlayWebhook
|
||||||
|
|
@ -18,6 +18,7 @@ NODE_CLASS_MAPPINGS = {
|
||||||
"COSDownload": COSDownload,
|
"COSDownload": COSDownload,
|
||||||
"S3Upload": S3Upload,
|
"S3Upload": S3Upload,
|
||||||
"S3UploadURL": S3UploadURL,
|
"S3UploadURL": S3UploadURL,
|
||||||
|
"S3UploadIMAGEURL": S3UploadIMAGEURL,
|
||||||
"S3Download": S3Download,
|
"S3Download": S3Download,
|
||||||
"VideoCutCustom": VideoCut,
|
"VideoCutCustom": VideoCut,
|
||||||
"VideoCutByFramePoint": VideoCutByFramePoint,
|
"VideoCutByFramePoint": VideoCutByFramePoint,
|
||||||
|
|
@ -52,10 +53,11 @@ NODE_CLASS_MAPPINGS = {
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"FaceOccDetect": "面部遮挡检测",
|
"FaceOccDetect": "面部遮挡检测",
|
||||||
"FaceExtract": "面部提取",
|
"FaceExtract": "面部提取",
|
||||||
"COSUpload": "COS上传",
|
"COSUpload": "COS上传-返回key",
|
||||||
"COSDownload": "COS下载",
|
"COSDownload": "COS下载",
|
||||||
"S3Upload": "S3上传",
|
"S3Upload": "S3上传-返回key",
|
||||||
"S3UploadURL": "S3上传-返回URL",
|
"S3UploadURL": "S3上传-返回URL",
|
||||||
|
"S3UploadIMAGEURL": "S3上传图片-返回URL",
|
||||||
"S3Download": "S3下载",
|
"S3Download": "S3下载",
|
||||||
"VideoCutCustom": "视频剪裁",
|
"VideoCutCustom": "视频剪裁",
|
||||||
"VideoCutByFramePoint": "视频剪裁(精确帧位)",
|
"VideoCutByFramePoint": "视频剪裁(精确帧位)",
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,12 @@ import os
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import loguru
|
import loguru
|
||||||
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from qcloud_cos import CosConfig, CosS3Client, CosClientError, CosServiceError
|
from qcloud_cos import CosConfig, CosS3Client, CosClientError, CosServiceError
|
||||||
|
|
||||||
|
from ..utils.image_utils import tensor_to_tempfile
|
||||||
|
|
||||||
|
|
||||||
class COSDownload:
|
class COSDownload:
|
||||||
"""腾讯云COS下载"""
|
"""腾讯云COS下载"""
|
||||||
|
|
@ -235,7 +238,7 @@ class S3UploadURL:
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING",)
|
RETURN_TYPES = ("STRING",)
|
||||||
RETURN_NAMES = ("S3文件Key",)
|
RETURN_NAMES = ("URL",)
|
||||||
|
|
||||||
FUNCTION = "upload"
|
FUNCTION = "upload"
|
||||||
CATEGORY = "不忘科技-自定义节点🚩/对象存储/S3"
|
CATEGORY = "不忘科技-自定义节点🚩/对象存储/S3"
|
||||||
|
|
@ -267,3 +270,51 @@ class S3UploadURL:
|
||||||
raise Exception(f"S3上传失败! bucket {s3_bucket}; local_path {path}; subfolder {subfolder}")
|
raise Exception(f"S3上传失败! bucket {s3_bucket}; local_path {path}; subfolder {subfolder}")
|
||||||
url = f"https://cdn.roasmax.cn/{dest_key}"
|
url = f"https://cdn.roasmax.cn/{dest_key}"
|
||||||
return (url,)
|
return (url,)
|
||||||
|
|
||||||
|
|
||||||
|
class S3UploadIMAGEURL:
|
||||||
|
"""AWS S3上传 返回URL"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE", {"multiline": True}),
|
||||||
|
"subfolder": ("STRING", {"default": "test"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("URL",)
|
||||||
|
|
||||||
|
FUNCTION = "upload"
|
||||||
|
CATEGORY = "不忘科技-自定义节点🚩/对象存储/S3"
|
||||||
|
|
||||||
|
def upload(self, image:torch.Tensor, subfolder):
|
||||||
|
s3_bucket = "modal-media-cache"
|
||||||
|
loguru.logger.info(f"S3 UPLOAD image to {s3_bucket}/{subfolder}")
|
||||||
|
path = tensor_to_tempfile(image).name
|
||||||
|
try:
|
||||||
|
if "aws_key_id" in list(os.environ.keys()):
|
||||||
|
yaml_config = os.environ
|
||||||
|
else:
|
||||||
|
with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config.yaml"),
|
||||||
|
encoding="utf-8", mode="r+") as f:
|
||||||
|
yaml_config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
client = boto3.client("s3", aws_access_key_id=yaml_config["aws_key_id"],
|
||||||
|
aws_secret_access_key=yaml_config["aws_access_key"])
|
||||||
|
dest_key = "/".join(
|
||||||
|
[
|
||||||
|
subfolder,
|
||||||
|
(
|
||||||
|
path.split("/")[-1]
|
||||||
|
if "/" in path
|
||||||
|
else path.split("\\")[-1]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
client.upload_file(path, s3_bucket, dest_key)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"S3上传失败! bucket {s3_bucket}; local_path {path}; subfolder {subfolder}")
|
||||||
|
url = f"https://cdn.roasmax.cn/{dest_key}"
|
||||||
|
return (url,)
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
@ -71,3 +72,63 @@ def tensor_to_image_bytes(tensor: torch.Tensor, format: str = 'PNG') -> bytes:
|
||||||
buffer.seek(0) # 重置指针到开始位置
|
buffer.seek(0) # 重置指针到开始位置
|
||||||
|
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_tempfile(tensor: torch.Tensor, format: str = 'PNG',
|
||||||
|
normalize: bool = True, range=None) -> tempfile.NamedTemporaryFile:
|
||||||
|
"""
|
||||||
|
将PyTorch张量转换为图像并保存到临时文件
|
||||||
|
|
||||||
|
参数:
|
||||||
|
tensor: 输入的PyTorch张量,可以是4D(BCHW)、3D(CHW)或2D(HW)
|
||||||
|
format: 图像格式,如'PNG'、'JPEG'等
|
||||||
|
normalize: 是否对张量进行归一化处理
|
||||||
|
range: 归一化范围,元组(min, max),默认为张量的最小值和最大值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
临时文件对象,关闭后会自动删除
|
||||||
|
"""
|
||||||
|
# 处理4D张量 (BCHW),只取第一个样本
|
||||||
|
|
||||||
|
if tensor.dim() == 4:
|
||||||
|
if tensor.size(0) > 1:
|
||||||
|
print(f"警告: 输入张量包含多个样本,仅使用第一个样本 ({tensor.size(0)} -> 1)")
|
||||||
|
tensor = tensor[0]
|
||||||
|
|
||||||
|
# 确保张量在CPU上
|
||||||
|
tensor = tensor.cpu()
|
||||||
|
tensor = tensor.permute(2,0,1)
|
||||||
|
|
||||||
|
# 归一化处理
|
||||||
|
if normalize:
|
||||||
|
if range is None:
|
||||||
|
min_val, max_val = tensor.min(), tensor.max()
|
||||||
|
else:
|
||||||
|
min_val, max_val = range
|
||||||
|
|
||||||
|
if max_val > min_val:
|
||||||
|
tensor = (tensor - min_val) / (max_val - min_val)
|
||||||
|
else:
|
||||||
|
tensor = torch.zeros_like(tensor)
|
||||||
|
|
||||||
|
# 转换为PIL图像
|
||||||
|
if tensor.dim() == 2: # HW格式 (灰度图)
|
||||||
|
pil_img = transforms.ToPILImage()(tensor.unsqueeze(0)) # 添加通道维度
|
||||||
|
elif tensor.dim() == 3: # CHW格式
|
||||||
|
pil_img = transforms.ToPILImage()(tensor)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的张量维度: {tensor.dim()}")
|
||||||
|
|
||||||
|
# 创建临时文件
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(suffix=f'.{format.lower()}', delete=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 保存图像到临时文件
|
||||||
|
pil_img.save(temp_file, format=format)
|
||||||
|
except Exception as e:
|
||||||
|
# 发生错误时删除临时文件
|
||||||
|
temp_file.close()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
temp_file.close() # 关闭文件但不删除
|
||||||
|
return temp_file
|
||||||
Loading…
Reference in New Issue