ComfyUI-CustomNode/nodes/image_modal_nodes.py

312 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import json
from time import sleep
import folder_paths
import requests
import torch
from PIL import Image
from loguru import logger
from torchvision import transforms
from ..utils.http_utils import send_request
from ..utils.image_utils import tensor_to_image_bytes, base64_to_tensor
def url_to_tensor(image_url: str, max_retries: int = 3, retry_delay: float = 1.0) -> torch.Tensor:
    """
    Download an image from a URL and convert it to a batched image tensor.

    The download is retried on network/HTTP errors and on responses that fail
    the content checks below (non-image ``Content-Type``, body < 100 bytes).

    Args:
        image_url: URL of the image to download.
        max_retries: Maximum number of download attempts; must be >= 1.
        retry_delay: Base back-off in seconds. After failed attempt ``k``
            (0-based) the function waits ``retry_delay * (k + 1)`` seconds
            before retrying. ``0`` disables waiting.

    Returns:
        torch.Tensor: float tensor of shape ``[1, H, W, 3]`` (batch, height,
        width, RGB channels), values in ``[0, 1]`` as produced by
        ``torchvision.transforms.ToTensor`` followed by a channel-last permute.

    Raises:
        ValueError: if ``max_retries < 1``, or if the last attempt returned
            non-image content or a body shorter than 100 bytes.
        requests.exceptions.RequestException: if the last attempt failed at
            the network/HTTP level (including non-2xx status codes).
        Pillow decoding errors (e.g. ``PIL.UnidentifiedImageError``) are not
        retried and propagate immediately.
    """
    if max_retries < 1:
        # Without this guard the loop below never runs and the function
        # would silently return None.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
    to_tensor = transforms.ToTensor()
    for attempt in range(max_retries):
        try:
            # The context manager closes the streamed response (releasing the
            # pooled connection) even when one of the checks below raises.
            with requests.get(image_url, headers=headers, stream=True, timeout=15) as response:
                response.raise_for_status()
                # Reject non-image payloads (e.g. HTML error pages).
                content_type = response.headers.get('Content-Type', '')
                if not content_type.startswith('image/'):
                    raise ValueError(f"URL返回非图像内容: {content_type}")
                img_data = response.content
            # Very small bodies are almost never a complete image.
            if len(img_data) < 100:
                raise ValueError("下载的内容过小,可能不是完整图像")
            img = Image.open(io.BytesIO(img_data)).convert('RGB')
            # ToTensor gives [C, H, W]; add a batch axis and move channels
            # last to get [1, H, W, C].
            return to_tensor(img).unsqueeze(0).permute(0, 2, 3, 1)
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.warning(f"尝试 {attempt + 1}/{max_retries} 失败: {e}")
            if attempt == max_retries - 1:
                raise
            if retry_delay > 0:
                # Linear back-off so transient failures get time to recover.
                sleep(retry_delay * (attempt + 1))
class ModalClothesMask:
    """ComfyUI node: ask the Modal-hosted Gemini service to mark a clothing item with a colour.

    The input image is uploaded as PNG to
    ``https://{endpoint}/google/image/clothes_mark``. The node then polls
    ``https://{endpoint}/google/{taskId}`` until the task reports success or
    failure, and decodes the first ``image_b64`` entry of the task result.
    """

    # NOTE(review): hard-coded service credential; consider moving it to
    # configuration or an environment variable.
    _AUTH_HEADER = 'Bearer bowong7777'
    # Total polling budget and poll spacing, in seconds.
    _WAIT_TIME = 240
    _POLL_INTERVAL = 2
    # Per-request HTTP timeout, in seconds.
    _REQUEST_TIMEOUT = 60

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask_color": ("STRING", {"default": "绿色"}),
                "clothes_type": ("STRING", {"default": "裤子"}),
                "endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"

    def process(self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str):
        """Submit a clothes-marking job and wait for its result.

        Args:
            image: ComfyUI image tensor; serialized to PNG via ``tensor_to_image_bytes``.
            mask_color: Colour name forwarded as ``mark_color``.
            clothes_type: Clothing type forwarded as ``mark_clothes_type``.
            endpoint: Host name of the Modal web app (no scheme).

        Returns:
            A 1-tuple holding the tensor built by ``base64_to_tensor`` from the
            first result image (its exact shape is defined by that helper).

        Raises:
            RuntimeError: if the service rejects the job or reports a failed task.
            TimeoutError: if the task has not finished within ``_WAIT_TIME`` seconds.
            Any error from the HTTP layer (e.g. ``raise_for_status``) propagates unchanged.
        """
        api_key = self._fetch_access_token(endpoint)
        image_format = "PNG"
        logger.info("请求图像编辑")
        job_resp = send_request(
            "post", f"https://{endpoint}/google/image/clothes_mark",
            headers={'x-google-api-key': api_key},
            data={
                "mark_clothes_type": clothes_type,
                "mark_color": mask_color,
            },
            files={"origin_image": (
                'image.' + image_format.lower(), tensor_to_image_bytes(image, image_format),
                f'image/{image_format.lower()}')},
            timeout=self._REQUEST_TIMEOUT)
        job_resp.raise_for_status()
        job = job_resp.json()
        if not job["success"]:
            raise RuntimeError("请求Modal API失败")
        return (self._poll_result(endpoint, job["taskId"]),)

    def _fetch_access_token(self, endpoint: str) -> str:
        """Return the Google access token issued by the Modal service (raises on non-2xx)."""
        logger.info("获取token")
        resp = send_request("get", f"https://{endpoint}/google/access-token",
                            headers={'Authorization': self._AUTH_HEADER},
                            timeout=self._REQUEST_TIMEOUT)
        # Fail with the HTTP error instead of an opaque KeyError/JSON error.
        resp.raise_for_status()
        return resp.json()["access_token"]

    def _poll_result(self, endpoint: str, job_id):
        """Poll the task endpoint until success (returns the decoded image) or failure/timeout."""
        logger.info("开始轮询任务状态")
        sleep(1)
        for _ in range(self._WAIT_TIME // self._POLL_INTERVAL):
            logger.info("查询任务状态")
            result = send_request("get", f"https://{endpoint}/google/{job_id}",
                                  headers={'Authorization': self._AUTH_HEADER},
                                  timeout=self._REQUEST_TIMEOUT)
            # Non-200 answers are treated as transient: keep polling.
            if result.status_code == 200:
                payload = result.json()
                status = payload["status"]
                if status == "success":
                    logger.success("任务成功")
                    # "result" is itself a JSON-encoded string holding a list of images.
                    image_b64 = json.loads(payload["result"])[0]["image_b64"]
                    return base64_to_tensor(image_b64)
                if "fail" in status.lower():
                    raise RuntimeError("任务失败")
            sleep(self._POLL_INTERVAL)
        raise TimeoutError("查询任务状态超时")
class ModalEditCustom:
    """ComfyUI node: free-form Gemini image edit/generation through the Modal-hosted service.

    Sends ``prompt``/``temperature``/``topP`` (plus the optional input image
    as PNG) to ``https://{endpoint}/google/image/edit_custom``. The node then
    polls ``https://{endpoint}/google/{taskId}`` until the task reports
    success or failure, and decodes the first ``image_b64`` entry of the result.
    """

    # NOTE(review): hard-coded service credential; consider moving it to
    # configuration or an environment variable.
    _AUTH_HEADER = 'Bearer bowong7777'
    # Total polling budget and poll spacing, in seconds.
    _WAIT_TIME = 240
    _POLL_INTERVAL = 2
    # Per-request HTTP timeout, in seconds.
    _REQUEST_TIMEOUT = 60

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"default": "将背景去除,输出原尺寸图片", "multiline": True}),
                "temperature": ("FLOAT", {"default": 0.1, "min": 0, "max": 2}),
                "topP": ("FLOAT", {"default": 0.7, "min": 0, "max": 1}),
                "endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
            },
            "optional": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"

    def process(self, prompt: str, temperature: float, topP: float, endpoint: str, **kwargs):
        """Submit a custom edit job and wait for its result.

        Args:
            prompt: Instruction text forwarded to the service.
            temperature: Sampling temperature forwarded as-is.
            topP: Nucleus-sampling value forwarded as-is.
            endpoint: Host name of the Modal web app (no scheme).
            **kwargs: May contain ``image`` (ComfyUI image tensor). When present
                and not None, it is uploaded as ``origin_image`` (PNG).

        Returns:
            A 1-tuple holding the tensor built by ``base64_to_tensor`` from the
            first result image (its exact shape is defined by that helper).

        Raises:
            RuntimeError: if the service rejects the job or reports a failed task.
            TimeoutError: if the task has not finished within ``_WAIT_TIME`` seconds.
            Any error from the HTTP layer (e.g. ``raise_for_status``) propagates unchanged.
        """
        api_key = self._fetch_access_token(endpoint)
        image_format = "PNG"
        image = kwargs.get("image")
        files = None
        if image is not None:
            files = {"origin_image": (
                'image.' + image_format.lower(), tensor_to_image_bytes(image, image_format),
                f'image/{image_format.lower()}')}
        logger.info("请求图像编辑")
        job_resp = send_request(
            "post", f"https://{endpoint}/google/image/edit_custom",
            headers={'x-google-api-key': api_key},
            data={
                "prompt": prompt,
                "temperature": temperature,
                "topP": topP
            },
            files=files,
            timeout=self._REQUEST_TIMEOUT)
        job_resp.raise_for_status()
        job = job_resp.json()
        if not job["success"]:
            raise RuntimeError("请求Modal API失败")
        return (self._poll_result(endpoint, job["taskId"]),)

    def _fetch_access_token(self, endpoint: str) -> str:
        """Return the Google access token issued by the Modal service (raises on non-2xx)."""
        logger.info("获取token")
        resp = send_request("get", f"https://{endpoint}/google/access-token",
                            headers={'Authorization': self._AUTH_HEADER},
                            timeout=self._REQUEST_TIMEOUT)
        # Fail with the HTTP error instead of an opaque KeyError/JSON error.
        resp.raise_for_status()
        return resp.json()["access_token"]

    def _poll_result(self, endpoint: str, job_id):
        """Poll the task endpoint until success (returns the decoded image) or failure/timeout."""
        logger.info("开始轮询任务状态")
        sleep(1)
        for _ in range(self._WAIT_TIME // self._POLL_INTERVAL):
            logger.info("查询任务状态")
            result = send_request("get", f"https://{endpoint}/google/{job_id}",
                                  headers={'Authorization': self._AUTH_HEADER},
                                  timeout=self._REQUEST_TIMEOUT)
            # Non-200 answers are treated as transient: keep polling.
            if result.status_code == 200:
                payload = result.json()
                status = payload["status"]
                if status == "success":
                    logger.success("任务成功")
                    # "result" is itself a JSON-encoded string holding a list of images.
                    image_b64 = json.loads(payload["result"])[0]["image_b64"]
                    return base64_to_tensor(image_b64)
                if "fail" in status.lower():
                    raise RuntimeError("任务失败")
            sleep(self._POLL_INTERVAL)
        raise TimeoutError("查询任务状态超时")
class ModalMidJourneyGenerateImage:
    """ComfyUI node: generate images with MidJourney through the Modal-hosted agent service.

    Only the ``302ai`` provider is implemented. It submits a task to
    ``https://{endpoint}/api/custom/image/submit/task``, then polls
    ``/api/custom/task/status`` every ``_POLL_INTERVAL`` seconds for up to
    ``timeout`` seconds, downloading every returned image URL.
    """

    # Seconds between task-status polls.
    _POLL_INTERVAL = 3
    # HTTP timeout for each individual status request, in seconds.
    _STATUS_TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"default": "一幅宏大壮美的山川画卷", "multiline": True}),
                "provider": (["ttapi", "302ai"],),
                "endpoint": ("STRING", {"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"}),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            },
            "optional": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"

    def process(self, prompt: str, provider: str, endpoint: str, timeout: int, **kwargs):
        """Submit a MidJourney generation task and wait for the images.

        Args:
            prompt: Text prompt forwarded to the service.
            provider: Backend selector. Only ``"302ai"`` is implemented.
            endpoint: Host name of the Modal web app (no scheme).
            timeout: Seconds used both as the submit-request timeout and as the
                total polling budget.
            **kwargs: May contain ``image`` (ComfyUI image tensor). When present
                and not None, it is uploaded as ``img_file`` (PNG).

        Returns:
            A 1-tuple holding a tensor of shape ``[N, H, W, 3]``, one entry per
            returned URL (all images must share the same size to be stacked).

        Raises:
            NotImplementedError: for providers other than ``"302ai"``.
            RuntimeError: if submission is refused or the task fails/returns nothing.
            TimeoutError: if the task does not finish within ``timeout`` seconds.
        """
        # Fail loudly instead of falling through and returning None, which
        # ComfyUI cannot unpack.
        if provider != "302ai":
            raise NotImplementedError(f"暂不支持的provider: {provider}")
        logger.info("请求同步接口")
        image_format = "PNG"
        image = kwargs.get("image")
        files = None
        if image is not None:
            files = {"img_file": (
                'image.' + image_format.lower(), tensor_to_image_bytes(image, image_format),
                f'image/{image_format.lower()}')}
        submit_url = f"https://{endpoint}/api/custom/image/submit/task"
        logger.info("提交任务")
        logger.info(submit_url)
        job_resp = send_request("post", submit_url,
                                data={"model_name": "302ai/mj", "prompt": prompt, "mode": "turbo"},
                                files=files,
                                timeout=timeout)
        job_resp.raise_for_status()
        job = job_resp.json()
        if not job["status"]:
            raise RuntimeError("生成失败, 可能因为风控")
        return (self._wait_for_images(endpoint, job["data"], timeout),)

    def _wait_for_images(self, endpoint: str, job_id, timeout: int) -> torch.Tensor:
        """Poll the task status once per interval for ``timeout`` seconds and return the stacked images."""
        interval = self._POLL_INTERVAL
        # One poll per interval across the whole budget (at least one poll).
        for _ in range(max(1, timeout // interval)):
            logger.info(f"等待{interval}")
            sleep(interval)
            logger.info("查询结果")
            resp = send_request("get", f"https://{endpoint}/api/custom/task/status?task_id={job_id}",
                                timeout=self._STATUS_TIMEOUT)
            resp.raise_for_status()
            payload = resp.json()
            status = payload["status"]
            if status == "running":
                logger.info("任务正在运行")
                continue
            if status == "failed":
                raise RuntimeError(f"生成失败: {payload['msg']}")
            if status == "success":
                result_urls = payload["data"]
                # An empty list would otherwise crash inside torch.stack.
                if not isinstance(result_urls, list) or not result_urls:
                    raise RuntimeError("生成失败,返回结果为空")
                frames = []
                for url in result_urls:
                    logger.success("img_url: " + url)
                    # url_to_tensor yields [1, H, W, C]; drop the batch axis before stacking.
                    frames.append(url_to_tensor(url).squeeze(0))
                return torch.stack(frames, dim=0)
        raise TimeoutError("等待超时")
class ModalMidJourneyDescribeImage:
    """ComfyUI node: get a text description of an image from the MidJourney describe API.

    The image is uploaded as PNG to
    ``https://{endpoint}/api/302/mj/sync/file/img/describe``. The ``data``
    field of the JSON reply is returned unchanged as the node's output.
    """

    @classmethod
    def INPUT_TYPES(cls):
        endpoint_spec = ("STRING", {"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"})
        required = {"image": ("IMAGE",), "endpoint": endpoint_spec}
        return {"required": required}

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("描述内容",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"

    def process(self, image: torch.Tensor, endpoint: str):
        """Upload ``image`` and return the service's description.

        Args:
            image: ComfyUI image tensor; serialized via ``tensor_to_image_bytes``.
            endpoint: Host name of the Modal web app (no scheme).

        Returns:
            A 1-tuple with the ``data`` field of the JSON response.

        Raises:
            Exception: if the response's ``status`` field is falsy.
            Errors from the HTTP layer (e.g. ``raise_for_status``) propagate unchanged.
        """
        logger.info("请求同步接口")
        fmt = "PNG"
        suffix = fmt.lower()
        upload = (f"image.{suffix}", tensor_to_image_bytes(image, fmt), f"image/{suffix}")
        response = send_request(
            "post",
            f"https://{endpoint}/api/302/mj/sync/file/img/describe",
            headers={'Authorization': 'Bearer bowong7777'},
            files={"img_file": upload},
            timeout=300,
        )
        response.raise_for_status()
        body = response.json()
        if body["status"]:
            return (body["data"],)
        raise Exception("描述失败")