# Source listing metadata: 386 lines, 14 KiB, Python.
import io
import json
from time import sleep, time

import folder_paths
import requests
import torch
from PIL import Image
from loguru import logger
from torchvision import transforms

from ..utils.http_utils import send_request
from ..utils.image_utils import tensor_to_image_bytes, base64_to_tensor


def url_to_tensor(image_url: str, max_retries: int = 3, retry_delay: float = 1.0):
    """Download an image from a URL and convert it to a PyTorch tensor.

    The download is retried on network/HTTP errors, non-image responses,
    suspiciously small payloads and payloads PIL cannot decode.

    Args:
        image_url: URL of the image to download.
        max_retries: Total number of attempts; must be >= 1.
        retry_delay: Base pause in seconds between attempts. After failed
            attempt ``k`` (0-based) the function sleeps
            ``retry_delay * (k + 1)`` seconds. Use 0 to retry immediately.

    Returns:
        torch.Tensor: float tensor of shape ``[1, H, W, 3]`` with values in
        ``[0, 1]`` (RGB, channels last, batch of one).

    Raises:
        ValueError: if ``max_retries`` < 1, or if the last attempt returned a
            non-image or too-small payload.
        requests.exceptions.RequestException: if the last attempt failed at
            the network/HTTP level.
        OSError: if PIL could not decode the payload on the last attempt.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
    }
    to_tensor = transforms.ToTensor()

    for attempt in range(max_retries):
        try:
            # The context manager releases the pooled connection even when a
            # check below raises; with stream=True it stays open until the body
            # is fully read or the response is closed.
            with requests.get(
                image_url, headers=headers, stream=True, timeout=15
            ) as response:
                response.raise_for_status()

                # Media types are case-insensitive, so "Image/PNG" is accepted too.
                content_type = response.headers.get("Content-Type", "")
                if not content_type.lower().startswith("image/"):
                    raise ValueError(f"URL返回非图像内容: {content_type}")

                img_data = response.content

            # Very small payloads are almost never complete images.
            if len(img_data) < 100:
                raise ValueError("下载的内容过小,可能不是完整图像")

            # convert() forces decoding, so corrupt data raises here
            # (UnidentifiedImageError / OSError) and is retried below.
            img = Image.open(io.BytesIO(img_data)).convert("RGB")

            # ToTensor yields [3, H, W]; add a batch dim and move channels last.
            return to_tensor(img).unsqueeze(0).permute(0, 2, 3, 1)

        except (requests.exceptions.RequestException, OSError, ValueError) as e:
            logger.warning(f"尝试 {attempt + 1}/{max_retries} 失败: {e}")
            if attempt == max_retries - 1:
                raise
            if retry_delay > 0:
                sleep(retry_delay * (attempt + 1))


class ModalClothesMask:
    """ComfyUI node: ask the Gemini service at ``endpoint`` to paint a colored
    mask over one clothing item in the input image.

    Flow: fetch a Google access token from the service, submit the image as an
    asynchronous job, then poll the job until it succeeds, fails, or the
    polling budget runs out.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask_color": ("STRING", {"default": "绿色"}),  # "green"
                "clothes_type": ("STRING", {"default": "裤子"}),  # "trousers"
                "endpoint": (
                    "STRING",
                    {
                        "default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"

    @staticmethod
    def _parse_job_status(payload: dict):
        """Interpret one job-status payload returned by the service.

        Returns:
            The base64-encoded result image when ``payload["status"]`` equals
            ``"success"``; ``None`` while the job is still pending.

        Raises:
            Exception: if the status contains "fail" (case-insensitive).
        """
        status = payload["status"]
        if status == "success":
            # "result" is itself a JSON-encoded list; its first entry holds the image.
            return json.loads(payload["result"])[0]["image_b64"]
        if "fail" in status.lower():
            raise Exception("任务失败")
        return None

    def process(
        self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str
    ):
        """Submit a clothes-mask job and wait for its result image.

        Args:
            image: ComfyUI IMAGE tensor; it is PNG-encoded via
                ``tensor_to_image_bytes`` and uploaded as ``origin_image``.
            mask_color: Color name sent as the ``mark_color`` form field.
            clothes_type: Clothing item sent as ``mark_clothes_type``.
            endpoint: Host name of the service (no scheme; https is prefixed).

        Returns:
            A 1-tuple holding the tensor that ``base64_to_tensor`` builds from
            the job's first result image.

        Raises:
            Exception: if the submission is rejected, the job fails, or no
                result arrives within the polling budget. HTTP errors raised
                by ``raise_for_status`` propagate unchanged (assumes
                ``send_request`` returns a ``requests.Response``).
        """
        timeout = 60  # per-request HTTP timeout, seconds
        wait_time = 240  # total polling budget, seconds of wall-clock time
        interval = 2  # pause between status queries, seconds
        image_format = "PNG"

        logger.info("获取token")
        # NOTE(review): static bearer token hard-coded in source; consider
        # loading it from configuration instead.
        api_key = send_request(
            "get",
            f"https://{endpoint}/google/access-token",
            headers={"Authorization": "Bearer bowong7777"},
            timeout=timeout,
        ).json()["access_token"]

        logger.info("请求图像编辑")
        job_resp = send_request(
            "post",
            f"https://{endpoint}/google/image/clothes_mark",
            headers={"x-google-api-key": api_key},
            data={
                "mark_clothes_type": clothes_type,
                "mark_color": mask_color,
            },
            files={
                "origin_image": (
                    "image." + image_format.lower(),
                    tensor_to_image_bytes(image, image_format),
                    f"image/{image_format.lower()}",
                )
            },
            timeout=timeout,
        )
        job_resp.raise_for_status()
        job = job_resp.json()
        if not job.get("success"):
            raise Exception("请求Modal API失败")
        job_id = job["taskId"]

        logger.info("开始轮询任务状态")
        # Bound polling by wall-clock time: counting iterations would ignore
        # the time spent inside each (up to `timeout`-second) status request.
        deadline = time() + wait_time
        sleep(1)
        while time() < deadline:
            logger.info("查询任务状态")
            result = send_request(
                "get",
                f"https://{endpoint}/google/{job_id}",
                headers={"Authorization": "Bearer bowong7777"},
                timeout=timeout,
            )
            # Non-200 answers are treated as transient and simply polled again.
            if result.status_code == 200:
                image_b64 = self._parse_job_status(result.json())
                if image_b64 is not None:
                    logger.success("任务成功")
                    return (base64_to_tensor(image_b64),)
            sleep(interval)
        raise Exception("查询任务状态超时")


class ModalEditCustom:
    """ComfyUI node: run a free-form Gemini image-edit prompt through the
    service at ``endpoint``, optionally on an input image.

    Flow: fetch a Google access token from the service, submit the prompt (and
    image, if given) as an asynchronous job, then poll the job until it
    succeeds, fails, or the polling budget runs out.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": (
                    "STRING",
                    {"default": "将背景去除,输出原尺寸图片", "multiline": True},
                ),
                "temperature": ("FLOAT", {"default": 0.1, "min": 0, "max": 2}),
                "topP": ("FLOAT", {"default": 0.7, "min": 0, "max": 1}),
                "endpoint": (
                    "STRING",
                    {
                        "default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
                    },
                ),
            },
            "optional": {
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"

    @staticmethod
    def _parse_job_status(payload: dict):
        """Interpret one job-status payload returned by the service.

        Returns:
            The base64-encoded result image when ``payload["status"]`` equals
            ``"success"``; ``None`` while the job is still pending.

        Raises:
            Exception: if the status contains "fail" (case-insensitive).
        """
        status = payload["status"]
        if status == "success":
            # "result" is itself a JSON-encoded list; its first entry holds the image.
            return json.loads(payload["result"])[0]["image_b64"]
        if "fail" in status.lower():
            raise Exception("任务失败")
        return None

    def process(
        self, prompt: str, temperature: float, topP: float, endpoint: str, **kwargs
    ):
        """Submit a custom image-edit job and wait for its result image.

        Args:
            prompt: Edit instruction sent as the ``prompt`` form field.
            temperature: Forwarded as the ``temperature`` form field.
            topP: Forwarded as the ``topP`` form field.
            endpoint: Host name of the service (no scheme; https is prefixed).
            **kwargs: May contain the optional ``image`` input. When present
                and not None, it is PNG-encoded via ``tensor_to_image_bytes``
                and uploaded as ``origin_image``. Otherwise no file is sent.

        Returns:
            A 1-tuple holding the tensor that ``base64_to_tensor`` builds from
            the job's first result image.

        Raises:
            Exception: if the submission is rejected, the job fails, or no
                result arrives within the polling budget. HTTP errors raised
                by ``raise_for_status`` propagate unchanged (assumes
                ``send_request`` returns a ``requests.Response``).
        """
        timeout = 60  # per-request HTTP timeout, seconds
        wait_time = 240  # total polling budget, seconds of wall-clock time
        interval = 2  # pause between status queries, seconds
        image_format = "PNG"

        logger.info("获取token")
        # NOTE(review): static bearer token hard-coded in source; consider
        # loading it from configuration instead.
        api_key = send_request(
            "get",
            f"https://{endpoint}/google/access-token",
            headers={"Authorization": "Bearer bowong7777"},
            timeout=timeout,
        ).json()["access_token"]

        image = kwargs.get("image")
        files = None
        if image is not None:
            files = {
                "origin_image": (
                    "image." + image_format.lower(),
                    tensor_to_image_bytes(image, image_format),
                    f"image/{image_format.lower()}",
                )
            }

        logger.info("请求图像编辑")
        job_resp = send_request(
            "post",
            f"https://{endpoint}/google/image/edit_custom",
            headers={"x-google-api-key": api_key},
            data={"prompt": prompt, "temperature": temperature, "topP": topP},
            files=files,
            timeout=timeout,
        )
        job_resp.raise_for_status()
        job = job_resp.json()
        if not job.get("success"):
            raise Exception("请求Modal API失败")
        job_id = job["taskId"]

        logger.info("开始轮询任务状态")
        # Bound polling by wall-clock time: counting iterations would ignore
        # the time spent inside each (up to `timeout`-second) status request.
        deadline = time() + wait_time
        sleep(1)
        while time() < deadline:
            logger.info("查询任务状态")
            result = send_request(
                "get",
                f"https://{endpoint}/google/{job_id}",
                headers={"Authorization": "Bearer bowong7777"},
                timeout=timeout,
            )
            # Non-200 answers are treated as transient and simply polled again.
            if result.status_code == 200:
                image_b64 = self._parse_job_status(result.json())
                if image_b64 is not None:
                    logger.success("任务成功")
                    return (base64_to_tensor(image_b64),)
            sleep(interval)
        raise Exception("查询任务状态超时")


class ModalMidJourneyGenerateImage:
    """ComfyUI node: generate images with Midjourney through the task API at
    ``endpoint`` and return every resulting image as one batch.

    Flow: submit a task (optionally with a reference image), poll its status
    every few seconds until it succeeds, fails, or ``timeout`` seconds elapse,
    then download each result URL with ``url_to_tensor``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": (
                    "STRING",
                    {"default": "一幅宏大壮美的山川画卷", "multiline": True},
                ),
                "provider": (["ttapi", "302ai"],),
                "endpoint": (
                    "STRING",
                    {
                        "default": "bowongai-test--text-video-agent-fastapi-app.modal.run"
                    },
                ),
                "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
            },
            "optional": {
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"

    @staticmethod
    def _result_urls(data):
        """Return ``data`` if it is a non-empty list of result URLs.

        Raises:
            Exception: if ``data`` is not a list or is empty. An empty list
                would otherwise reach ``torch.stack`` and fail there with an
                opaque error.
        """
        if not isinstance(data, list) or not data:
            raise Exception("生成失败,返回结果为空")
        return data

    def process(
        self, prompt: str, provider: str, endpoint: str, timeout: int, **kwargs
    ):
        """Submit a Midjourney generation task and return its images.

        Args:
            prompt: Text prompt sent as the ``prompt`` form field.
            provider: NOTE(review): currently unused. The request always sends
                ``model_name="302/mj"``; confirm whether "ttapi" should route
                elsewhere.
            endpoint: Host name of the task API (no scheme; https is prefixed).
            timeout: HTTP timeout of the submit request. It is also the total
                wall-clock budget, in seconds, for polling the task status.
            **kwargs: May contain the optional ``image`` input. When present
                and not None, it is PNG-encoded and uploaded as ``img_file``.

        Returns:
            A 1-tuple holding the result images stacked along a new first
            (batch) dimension.

        Raises:
            Exception: if submission is refused, the task fails, the result
                list is empty/invalid, or the polling budget is exhausted.
        """
        logger.info("请求同步接口")
        image_format = "PNG"
        image = kwargs.get("image")
        files = None
        if image is not None:
            files = {
                "img_file": (
                    "image." + image_format.lower(),
                    tensor_to_image_bytes(image, image_format),
                    f"image/{image_format.lower()}",
                )
            }

        interval = 3  # seconds between status queries
        submit_url = f"https://{endpoint}/api/custom/image/submit/task"
        logger.info("提交任务")
        logger.info(submit_url)
        job_resp = send_request(
            "post",
            submit_url,
            data={"model_name": "302/mj", "prompt": prompt, "mode": "turbo"},
            files=files,
            timeout=timeout,
        )
        job_resp.raise_for_status()
        job = job_resp.json()
        if not job.get("status"):
            raise Exception("生成失败, 可能因为风控")
        job_id = job["data"]

        start_time = time()
        # Poll against a wall-clock deadline so the whole `timeout` budget is
        # actually used, independent of the interval and request latency.
        deadline = start_time + timeout
        while time() < deadline:
            logger.info(f"已等待 {time() - start_time} 秒,{interval} 秒后查询...")
            sleep(interval)
            logger.info(f"查询结果 {job_id}")
            resp = send_request(
                "get",
                f"https://{endpoint}/api/custom/task/status?task_id={job_id}",
                timeout=30,
            )
            resp.raise_for_status()
            payload = resp.json()  # parse once per poll
            status = payload.get("status")
            if status == "running":
                logger.info(f"任务正在运行 {job_id}")
                continue
            if status == "failed":
                raise Exception(f"生成失败: {payload.get('msg')}")
            if status == "success":
                tensors = []
                for url in self._result_urls(payload.get("data")):
                    logger.success("img_url: " + url)
                    tensors.append(url_to_tensor(url).squeeze(0))
                # torch.stack requires every result to share the same shape.
                batch = torch.stack(tensors, dim=0)
                logger.success(f"生成成功 {job_id}")
                return (batch,)
        raise Exception("等待超时")


class ModalMidJourneyDescribeImage:
    """ComfyUI node: describe an image through the Midjourney ``describe``
    proxy at ``endpoint``.

    The image is PNG-encoded and uploaded in one synchronous request. The
    ``data`` field of the reply is returned as the single STRING output.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("描述内容",)
    FUNCTION = "process"
    OUTPUT_NODE = False
    CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"

    @classmethod
    def INPUT_TYPES(cls):
        default_endpoint = "bowongai-test--text-video-agent-fastapi-app.modal.run"
        return {
            "required": {
                "image": ("IMAGE",),
                "endpoint": ("STRING", {"default": default_endpoint}),
            },
        }

    def process(self, image: torch.Tensor, endpoint: str):
        """Upload ``image`` and return the service's description of it.

        Args:
            image: ComfyUI IMAGE tensor, PNG-encoded via ``tensor_to_image_bytes``.
            endpoint: Host name of the service (no scheme; https is prefixed).

        Returns:
            A 1-tuple with the ``data`` field of the JSON reply.

        Raises:
            Exception: if the reply's ``status`` field is falsy. HTTP errors
                from ``raise_for_status`` propagate unchanged.
        """
        logger.info("请求同步接口")
        ext = "PNG"
        upload = (
            "image." + ext.lower(),
            tensor_to_image_bytes(image, ext),
            f"image/{ext.lower()}",
        )
        # NOTE(review): static bearer token hard-coded in source.
        response = send_request(
            "post",
            f"https://{endpoint}/api/302/mj/sync/file/img/describe",
            headers={"Authorization": "Bearer bowong7777"},
            files={"img_file": upload},
            timeout=300,
        )
        response.raise_for_status()
        body = response.json()
        if not body["status"]:
            raise Exception("描述失败")
        return (body["data"],)