62 lines
2.0 KiB
Python
62 lines
2.0 KiB
Python
from io import BytesIO
|
|
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
from PIL import Image
|
|
|
|
|
|
# 定义节点类
|
|
class LoadNetImg:
|
|
# 定义节点输入类型
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"image_url": ("STRING", {
|
|
"default": "https://example.com/sample.jpg",
|
|
"multiline": False
|
|
}),
|
|
}
|
|
}
|
|
|
|
# 定义节点输出类型
|
|
RETURN_TYPES = ("IMAGE",) # 返回图像数据
|
|
RETURN_NAMES = ("image",) # 命名返回值
|
|
FUNCTION = "load_image_task" # 函数标识符,方便注册多个节点功能
|
|
OUTPUT_NODE = False # 不允许该节点直接作为最终输出节点
|
|
CATEGORY = "image" # 节点所属类别(在 ComfyUI 界面中分类)
|
|
|
|
def load_image_task(self, image_url):
|
|
try:
|
|
if not image_url or not image_url.strip():
|
|
raise ValueError("需要提供图片URL")
|
|
|
|
# 下载网络图片
|
|
response = requests.get(image_url)
|
|
response.raise_for_status() # 请求失败时抛出异常
|
|
image = Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
# 按照官方格式转换图像数据
|
|
# Convert to numpy array and normalize to 0-1
|
|
image_array = np.array(image).astype(np.float32) / 255.0
|
|
# Convert to torch tensor and add batch dimension
|
|
image_tensor = torch.from_numpy(image_array)[None,]
|
|
|
|
return (image_tensor,) # 返回torch张量
|
|
except Exception as e:
|
|
print(f"Error loading image: {e}")
|
|
# 返回一个空的黑色图片作为错误处理
|
|
empty_image = torch.zeros((1, 512, 512, 3), dtype=torch.float32)
|
|
return (empty_image,)
|
|
|
|
|
|
|
|
|
|
# 映射节点类和名称
|
|
NODE_CLASS_MAPPINGS = {
|
|
"LoadNetImg": LoadNetImg, # 将类映射到节点名称
|
|
}
|
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
"LoadNetImg": "load_net_image", # 节点显示名称
|
|
} |