REFINE 优化图像处理逻辑,统一代码风格并增强可读性

This commit is contained in:
iHeyTang 2025-08-12 15:46:46 +08:00
parent 667c546449
commit bbd0c5325e
1 changed files with 170 additions and 96 deletions

View File

@ -1,6 +1,6 @@
import io import io
import json import json
from time import sleep from time import sleep, time
import folder_paths import folder_paths
import requests import requests
@ -29,7 +29,8 @@ def url_to_tensor(image_url: str, max_retries: int = 3):
ValueError: 无效图片格式 ValueError: 无效图片格式
""" """
headers = { headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'} "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
}
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
@ -38,8 +39,8 @@ def url_to_tensor(image_url: str, max_retries: int = 3):
response.raise_for_status() response.raise_for_status()
# 检查内容类型是否为图像 # 检查内容类型是否为图像
content_type = response.headers.get('Content-Type', '') content_type = response.headers.get("Content-Type", "")
if not content_type.startswith('image/'): if not content_type.startswith("image/"):
raise ValueError(f"URL返回非图像内容: {content_type}") raise ValueError(f"URL返回非图像内容: {content_type}")
# 验证图像完整性 # 验证图像完整性
@ -48,12 +49,10 @@ def url_to_tensor(image_url: str, max_retries: int = 3):
raise ValueError("下载的内容过小,可能不是完整图像") raise ValueError("下载的内容过小,可能不是完整图像")
# 尝试打开图像 # 尝试打开图像
img = Image.open(io.BytesIO(img_data)).convert('RGB') img = Image.open(io.BytesIO(img_data)).convert("RGB")
# 转换为张量 # 转换为张量
transform = transforms.Compose([ transform = transforms.Compose([transforms.ToTensor()])
transforms.ToTensor()
])
return transform(img).unsqueeze(0).permute(0, 2, 3, 1) return transform(img).unsqueeze(0).permute(0, 2, 3, 1)
except (requests.exceptions.RequestException, ValueError) as e: except (requests.exceptions.RequestException, ValueError) as e:
@ -70,7 +69,12 @@ class ModalClothesMask:
"image": ("IMAGE",), "image": ("IMAGE",),
"mask_color": ("STRING", {"default": "绿色"}), "mask_color": ("STRING", {"default": "绿色"}),
"clothes_type": ("STRING", {"default": "裤子"}), "clothes_type": ("STRING", {"default": "裤子"}),
"endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}), "endpoint": (
"STRING",
{
"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
},
),
}, },
} }
@ -80,25 +84,37 @@ class ModalClothesMask:
OUTPUT_NODE = False OUTPUT_NODE = False
CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini" CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"
def process(self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str): def process(
self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str
):
try: try:
timeout = 60 timeout = 60
logger.info("获取token") logger.info("获取token")
api_key = send_request("get", f"https://{endpoint}/google/access-token", api_key = send_request(
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[ "get",
"access_token"] f"https://{endpoint}/google/access-token",
headers={"Authorization": "Bearer bowong7777"},
timeout=timeout,
).json()["access_token"]
format = "PNG" format = "PNG"
logger.info("请求图像编辑") logger.info("请求图像编辑")
job_resp = send_request("post", f"https://{endpoint}/google/image/clothes_mark", job_resp = send_request(
headers={'x-google-api-key': api_key}, "post",
f"https://{endpoint}/google/image/clothes_mark",
headers={"x-google-api-key": api_key},
data={ data={
"mark_clothes_type": clothes_type, "mark_clothes_type": clothes_type,
"mark_color": mask_color, "mark_color": mask_color,
}, },
files={"origin_image": ( files={
'image.' + format.lower(), tensor_to_image_bytes(image, format), "origin_image": (
f'image/{format.lower()}')}, "image." + format.lower(),
timeout=timeout) tensor_to_image_bytes(image, format),
f"image/{format.lower()}",
)
},
timeout=timeout,
)
job_resp.raise_for_status() job_resp.raise_for_status()
job_resp = job_resp.json() job_resp = job_resp.json()
if not job_resp["success"]: if not job_resp["success"]:
@ -111,8 +127,12 @@ class ModalClothesMask:
sleep(1) sleep(1)
for _ in range(0, wait_time, interval): for _ in range(0, wait_time, interval):
logger.info("查询任务状态") logger.info("查询任务状态")
result = send_request("get", f"https://{endpoint}/google/{job_id}", result = send_request(
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout) "get",
f"https://{endpoint}/google/{job_id}",
headers={"Authorization": "Bearer bowong7777"},
timeout=timeout,
)
if result.status_code == 200: if result.status_code == 200:
result = result.json() result = result.json()
if result["status"] == "success": if result["status"] == "success":
@ -133,14 +153,22 @@ class ModalEditCustom:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"prompt": ("STRING", {"default": "将背景去除,输出原尺寸图片", "multiline": True}), "prompt": (
"STRING",
{"default": "将背景去除,输出原尺寸图片", "multiline": True},
),
"temperature": ("FLOAT", {"default": 0.1, "min": 0, "max": 2}), "temperature": ("FLOAT", {"default": 0.1, "min": 0, "max": 2}),
"topP": ("FLOAT", {"default": 0.7, "min": 0, "max": 1}), "topP": ("FLOAT", {"default": 0.7, "min": 0, "max": 1}),
"endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}), "endpoint": (
"STRING",
{
"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
},
),
}, },
"optional": { "optional": {
"image": ("IMAGE",), "image": ("IMAGE",),
} },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
@ -149,31 +177,39 @@ class ModalEditCustom:
OUTPUT_NODE = False OUTPUT_NODE = False
CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini" CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"
def process(self, prompt: str, temperature: float, topP: float, endpoint: str, **kwargs): def process(
self, prompt: str, temperature: float, topP: float, endpoint: str, **kwargs
):
try: try:
timeout = 60 timeout = 60
logger.info("获取token") logger.info("获取token")
api_key = send_request("get", f"https://{endpoint}/google/access-token", api_key = send_request(
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[ "get",
"access_token"] f"https://{endpoint}/google/access-token",
headers={"Authorization": "Bearer bowong7777"},
timeout=timeout,
).json()["access_token"]
format = "PNG" format = "PNG"
if "image" in kwargs and kwargs["image"] is not None: if "image" in kwargs and kwargs["image"] is not None:
image = kwargs["image"] image = kwargs["image"]
files = {"origin_image": ( files = {
'image.' + format.lower(), tensor_to_image_bytes(image, format), "origin_image": (
f'image/{format.lower()}')} "image." + format.lower(),
tensor_to_image_bytes(image, format),
f"image/{format.lower()}",
)
}
else: else:
files = None files = None
logger.info("请求图像编辑") logger.info("请求图像编辑")
job_resp = send_request("post", f"https://{endpoint}/google/image/edit_custom", job_resp = send_request(
headers={'x-google-api-key': api_key}, "post",
data={ f"https://{endpoint}/google/image/edit_custom",
"prompt": prompt, headers={"x-google-api-key": api_key},
"temperature": temperature, data={"prompt": prompt, "temperature": temperature, "topP": topP},
"topP": topP
},
files=files, files=files,
timeout=timeout) timeout=timeout,
)
job_resp.raise_for_status() job_resp.raise_for_status()
job_resp = job_resp.json() job_resp = job_resp.json()
if not job_resp["success"]: if not job_resp["success"]:
@ -186,8 +222,12 @@ class ModalEditCustom:
sleep(1) sleep(1)
for _ in range(0, wait_time, interval): for _ in range(0, wait_time, interval):
logger.info("查询任务状态") logger.info("查询任务状态")
result = send_request("get", f"https://{endpoint}/google/{job_id}", result = send_request(
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout) "get",
f"https://{endpoint}/google/{job_id}",
headers={"Authorization": "Bearer bowong7777"},
timeout=timeout,
)
if result.status_code == 200: if result.status_code == 200:
result = result.json() result = result.json()
if result["status"] == "success": if result["status"] == "success":
@ -208,14 +248,22 @@ class ModalMidJourneyGenerateImage:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"prompt": ("STRING", {"default": "一幅宏大壮美的山川画卷", "multiline": True}), "prompt": (
"provider":(["ttapi","302ai"],), "STRING",
"endpoint": ("STRING", {"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"}), {"default": "一幅宏大壮美的山川画卷", "multiline": True},
),
"provider": (["ttapi", "302ai"],),
"endpoint": (
"STRING",
{
"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"
},
),
"timeout": ("INT", {"default": 300, "min": 10, "max": 1200}), "timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
}, },
"optional": { "optional": {
"image": ("IMAGE",), "image": ("IMAGE",),
} },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
@ -224,39 +272,52 @@ class ModalMidJourneyGenerateImage:
OUTPUT_NODE = False OUTPUT_NODE = False
CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney" CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"
def process(self, prompt: str, provider:str, endpoint: str, timeout: int, **kwargs): def process(
self, prompt: str, provider: str, endpoint: str, timeout: int, **kwargs
):
try: try:
logger.info("请求同步接口") logger.info("请求同步接口")
format = "PNG" format = "PNG"
if "image" in kwargs and kwargs["image"] is not None: if "image" in kwargs and kwargs["image"] is not None:
image = kwargs["image"] image = kwargs["image"]
files = {"img_file": ( files = {
'image.' + format.lower(), tensor_to_image_bytes(image, format), "img_file": (
f'image/{format.lower()}')} "image." + format.lower(),
tensor_to_image_bytes(image, format),
f"image/{format.lower()}",
)
}
else: else:
files = None files = None
if provider == "302ai":
interval = 3 interval = 3
logger.info("提交任务") logger.info("提交任务")
logger.info(f"https://{endpoint}/api/custom/image/submit/task") logger.info(f"https://{endpoint}/api/custom/image/submit/task")
job_resp = send_request("post", f"https://{endpoint}/api/custom/image/submit/task", job_resp = send_request(
data={"model_name":"302/mj", "prompt": prompt, "mode": "turbo"}, "post",
f"https://{endpoint}/api/custom/image/submit/task",
data={"model_name": "302/mj", "prompt": prompt, "mode": "turbo"},
files=files, files=files,
timeout=timeout) timeout=timeout,
)
job_resp.raise_for_status() job_resp.raise_for_status()
job_resp = job_resp.json() job_resp = job_resp.json()
if not job_resp["status"]: if not job_resp["status"]:
raise Exception("生成失败, 可能因为风控") raise Exception("生成失败, 可能因为风控")
job_id = job_resp["data"] job_id = job_resp["data"]
start_time = time()
for _ in range(0, timeout // interval, interval): for _ in range(0, timeout // interval, interval):
logger.info("等待" + str(interval) + "") logger.info(f"已等待 {time() - start_time} 秒,{interval} 秒后查询...")
sleep(interval) sleep(interval)
logger.info("查询结果") logger.info(f"查询结果 {job_id}")
resp = send_request("get", f"https://{endpoint}/api/custom/task/status?task_id={job_id}", timeout=30) resp = send_request(
"get",
f"https://{endpoint}/api/custom/task/status?task_id={job_id}",
timeout=30,
)
resp.raise_for_status() resp.raise_for_status()
if resp.json()["status"] == "running": if resp.json()["status"] == "running":
logger.info("任务正在运行") logger.info(f"任务正在运行 {job_id}")
continue continue
if resp.json()["status"] == "failed": if resp.json()["status"] == "failed":
raise Exception(f"生成失败: {resp.json()['msg']}") raise Exception(f"生成失败: {resp.json()['msg']}")
@ -269,6 +330,7 @@ class ModalMidJourneyGenerateImage:
logger.success("img_url: " + url) logger.success("img_url: " + url)
result_list.append(url_to_tensor(url).squeeze(0)) result_list.append(url_to_tensor(url).squeeze(0))
result_list = torch.stack(result_list, dim=0) result_list = torch.stack(result_list, dim=0)
logger.success(f"生成成功 {job_id}")
return (result_list,) return (result_list,)
raise Exception("等待超时") raise Exception("等待超时")
except Exception as e: except Exception as e:
@ -281,7 +343,12 @@ class ModalMidJourneyDescribeImage:
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
"endpoint": ("STRING", {"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"}), "endpoint": (
"STRING",
{
"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"
},
),
}, },
} }
@ -295,12 +362,19 @@ class ModalMidJourneyDescribeImage:
try: try:
logger.info("请求同步接口") logger.info("请求同步接口")
format = "PNG" format = "PNG"
job_resp = send_request("post", f"https://{endpoint}/api/302/mj/sync/file/img/describe", job_resp = send_request(
headers={'Authorization': 'Bearer bowong7777'}, "post",
files={"img_file": ( f"https://{endpoint}/api/302/mj/sync/file/img/describe",
'image.' + format.lower(), tensor_to_image_bytes(image, format), headers={"Authorization": "Bearer bowong7777"},
f'image/{format.lower()}')}, files={
timeout=300) "img_file": (
"image." + format.lower(),
tensor_to_image_bytes(image, format),
f"image/{format.lower()}",
)
},
timeout=300,
)
job_resp.raise_for_status() job_resp.raise_for_status()
job_resp = job_resp.json() job_resp = job_resp.json()
if not job_resp["status"]: if not job_resp["status"]: