REFINE 优化图像处理逻辑,统一代码风格并增强可读性
This commit is contained in:
parent
667c546449
commit
bbd0c5325e
|
|
@ -1,6 +1,6 @@
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from time import sleep
|
from time import sleep, time
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import requests
|
import requests
|
||||||
|
|
@ -29,7 +29,8 @@ def url_to_tensor(image_url: str, max_retries: int = 3):
|
||||||
ValueError: 无效图片格式
|
ValueError: 无效图片格式
|
||||||
"""
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||||
|
}
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
|
|
@ -38,8 +39,8 @@ def url_to_tensor(image_url: str, max_retries: int = 3):
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
# 检查内容类型是否为图像
|
# 检查内容类型是否为图像
|
||||||
content_type = response.headers.get('Content-Type', '')
|
content_type = response.headers.get("Content-Type", "")
|
||||||
if not content_type.startswith('image/'):
|
if not content_type.startswith("image/"):
|
||||||
raise ValueError(f"URL返回非图像内容: {content_type}")
|
raise ValueError(f"URL返回非图像内容: {content_type}")
|
||||||
|
|
||||||
# 验证图像完整性
|
# 验证图像完整性
|
||||||
|
|
@ -48,12 +49,10 @@ def url_to_tensor(image_url: str, max_retries: int = 3):
|
||||||
raise ValueError("下载的内容过小,可能不是完整图像")
|
raise ValueError("下载的内容过小,可能不是完整图像")
|
||||||
|
|
||||||
# 尝试打开图像
|
# 尝试打开图像
|
||||||
img = Image.open(io.BytesIO(img_data)).convert('RGB')
|
img = Image.open(io.BytesIO(img_data)).convert("RGB")
|
||||||
|
|
||||||
# 转换为张量
|
# 转换为张量
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([transforms.ToTensor()])
|
||||||
transforms.ToTensor()
|
|
||||||
])
|
|
||||||
return transform(img).unsqueeze(0).permute(0, 2, 3, 1)
|
return transform(img).unsqueeze(0).permute(0, 2, 3, 1)
|
||||||
|
|
||||||
except (requests.exceptions.RequestException, ValueError) as e:
|
except (requests.exceptions.RequestException, ValueError) as e:
|
||||||
|
|
@ -70,7 +69,12 @@ class ModalClothesMask:
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
"mask_color": ("STRING", {"default": "绿色"}),
|
"mask_color": ("STRING", {"default": "绿色"}),
|
||||||
"clothes_type": ("STRING", {"default": "裤子"}),
|
"clothes_type": ("STRING", {"default": "裤子"}),
|
||||||
"endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
|
"endpoint": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
|
||||||
|
},
|
||||||
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -80,25 +84,37 @@ class ModalClothesMask:
|
||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"
|
CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"
|
||||||
|
|
||||||
def process(self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str):
|
def process(
|
||||||
|
self, image: torch.Tensor, mask_color: str, clothes_type: str, endpoint: str
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
timeout = 60
|
timeout = 60
|
||||||
logger.info("获取token")
|
logger.info("获取token")
|
||||||
api_key = send_request("get", f"https://{endpoint}/google/access-token",
|
api_key = send_request(
|
||||||
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[
|
"get",
|
||||||
"access_token"]
|
f"https://{endpoint}/google/access-token",
|
||||||
|
headers={"Authorization": "Bearer bowong7777"},
|
||||||
|
timeout=timeout,
|
||||||
|
).json()["access_token"]
|
||||||
format = "PNG"
|
format = "PNG"
|
||||||
logger.info("请求图像编辑")
|
logger.info("请求图像编辑")
|
||||||
job_resp = send_request("post", f"https://{endpoint}/google/image/clothes_mark",
|
job_resp = send_request(
|
||||||
headers={'x-google-api-key': api_key},
|
"post",
|
||||||
|
f"https://{endpoint}/google/image/clothes_mark",
|
||||||
|
headers={"x-google-api-key": api_key},
|
||||||
data={
|
data={
|
||||||
"mark_clothes_type": clothes_type,
|
"mark_clothes_type": clothes_type,
|
||||||
"mark_color": mask_color,
|
"mark_color": mask_color,
|
||||||
},
|
},
|
||||||
files={"origin_image": (
|
files={
|
||||||
'image.' + format.lower(), tensor_to_image_bytes(image, format),
|
"origin_image": (
|
||||||
f'image/{format.lower()}')},
|
"image." + format.lower(),
|
||||||
timeout=timeout)
|
tensor_to_image_bytes(image, format),
|
||||||
|
f"image/{format.lower()}",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
job_resp.raise_for_status()
|
job_resp.raise_for_status()
|
||||||
job_resp = job_resp.json()
|
job_resp = job_resp.json()
|
||||||
if not job_resp["success"]:
|
if not job_resp["success"]:
|
||||||
|
|
@ -111,8 +127,12 @@ class ModalClothesMask:
|
||||||
sleep(1)
|
sleep(1)
|
||||||
for _ in range(0, wait_time, interval):
|
for _ in range(0, wait_time, interval):
|
||||||
logger.info("查询任务状态")
|
logger.info("查询任务状态")
|
||||||
result = send_request("get", f"https://{endpoint}/google/{job_id}",
|
result = send_request(
|
||||||
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout)
|
"get",
|
||||||
|
f"https://{endpoint}/google/{job_id}",
|
||||||
|
headers={"Authorization": "Bearer bowong7777"},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
if result.status_code == 200:
|
if result.status_code == 200:
|
||||||
result = result.json()
|
result = result.json()
|
||||||
if result["status"] == "success":
|
if result["status"] == "success":
|
||||||
|
|
@ -133,14 +153,22 @@ class ModalEditCustom:
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"prompt": ("STRING", {"default": "将背景去除,输出原尺寸图片", "multiline": True}),
|
"prompt": (
|
||||||
|
"STRING",
|
||||||
|
{"default": "将背景去除,输出原尺寸图片", "multiline": True},
|
||||||
|
),
|
||||||
"temperature": ("FLOAT", {"default": 0.1, "min": 0, "max": 2}),
|
"temperature": ("FLOAT", {"default": 0.1, "min": 0, "max": 2}),
|
||||||
"topP": ("FLOAT", {"default": 0.7, "min": 0, "max": 1}),
|
"topP": ("FLOAT", {"default": 0.7, "min": 0, "max": 1}),
|
||||||
"endpoint": ("STRING", {"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"}),
|
"endpoint": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
|
||||||
|
},
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
@ -149,31 +177,39 @@ class ModalEditCustom:
|
||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"
|
CATEGORY = "不忘科技-自定义节点🚩/图片/Gemini"
|
||||||
|
|
||||||
def process(self, prompt: str, temperature: float, topP: float, endpoint: str, **kwargs):
|
def process(
|
||||||
|
self, prompt: str, temperature: float, topP: float, endpoint: str, **kwargs
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
timeout = 60
|
timeout = 60
|
||||||
logger.info("获取token")
|
logger.info("获取token")
|
||||||
api_key = send_request("get", f"https://{endpoint}/google/access-token",
|
api_key = send_request(
|
||||||
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[
|
"get",
|
||||||
"access_token"]
|
f"https://{endpoint}/google/access-token",
|
||||||
|
headers={"Authorization": "Bearer bowong7777"},
|
||||||
|
timeout=timeout,
|
||||||
|
).json()["access_token"]
|
||||||
format = "PNG"
|
format = "PNG"
|
||||||
if "image" in kwargs and kwargs["image"] is not None:
|
if "image" in kwargs and kwargs["image"] is not None:
|
||||||
image = kwargs["image"]
|
image = kwargs["image"]
|
||||||
files = {"origin_image": (
|
files = {
|
||||||
'image.' + format.lower(), tensor_to_image_bytes(image, format),
|
"origin_image": (
|
||||||
f'image/{format.lower()}')}
|
"image." + format.lower(),
|
||||||
|
tensor_to_image_bytes(image, format),
|
||||||
|
f"image/{format.lower()}",
|
||||||
|
)
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
files = None
|
files = None
|
||||||
logger.info("请求图像编辑")
|
logger.info("请求图像编辑")
|
||||||
job_resp = send_request("post", f"https://{endpoint}/google/image/edit_custom",
|
job_resp = send_request(
|
||||||
headers={'x-google-api-key': api_key},
|
"post",
|
||||||
data={
|
f"https://{endpoint}/google/image/edit_custom",
|
||||||
"prompt": prompt,
|
headers={"x-google-api-key": api_key},
|
||||||
"temperature": temperature,
|
data={"prompt": prompt, "temperature": temperature, "topP": topP},
|
||||||
"topP": topP
|
|
||||||
},
|
|
||||||
files=files,
|
files=files,
|
||||||
timeout=timeout)
|
timeout=timeout,
|
||||||
|
)
|
||||||
job_resp.raise_for_status()
|
job_resp.raise_for_status()
|
||||||
job_resp = job_resp.json()
|
job_resp = job_resp.json()
|
||||||
if not job_resp["success"]:
|
if not job_resp["success"]:
|
||||||
|
|
@ -186,8 +222,12 @@ class ModalEditCustom:
|
||||||
sleep(1)
|
sleep(1)
|
||||||
for _ in range(0, wait_time, interval):
|
for _ in range(0, wait_time, interval):
|
||||||
logger.info("查询任务状态")
|
logger.info("查询任务状态")
|
||||||
result = send_request("get", f"https://{endpoint}/google/{job_id}",
|
result = send_request(
|
||||||
headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout)
|
"get",
|
||||||
|
f"https://{endpoint}/google/{job_id}",
|
||||||
|
headers={"Authorization": "Bearer bowong7777"},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
if result.status_code == 200:
|
if result.status_code == 200:
|
||||||
result = result.json()
|
result = result.json()
|
||||||
if result["status"] == "success":
|
if result["status"] == "success":
|
||||||
|
|
@ -208,14 +248,22 @@ class ModalMidJourneyGenerateImage:
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"prompt": ("STRING", {"default": "一幅宏大壮美的山川画卷", "multiline": True}),
|
"prompt": (
|
||||||
"provider":(["ttapi","302ai"],),
|
"STRING",
|
||||||
"endpoint": ("STRING", {"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"}),
|
{"default": "一幅宏大壮美的山川画卷", "multiline": True},
|
||||||
|
),
|
||||||
|
"provider": (["ttapi", "302ai"],),
|
||||||
|
"endpoint": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"
|
||||||
|
},
|
||||||
|
),
|
||||||
"timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
|
"timeout": ("INT", {"default": 300, "min": 10, "max": 1200}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
|
@ -224,39 +272,52 @@ class ModalMidJourneyGenerateImage:
|
||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"
|
CATEGORY = "不忘科技-自定义节点🚩/图片/Midjourney"
|
||||||
|
|
||||||
def process(self, prompt: str, provider:str, endpoint: str, timeout: int, **kwargs):
|
def process(
|
||||||
|
self, prompt: str, provider: str, endpoint: str, timeout: int, **kwargs
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
logger.info("请求同步接口")
|
logger.info("请求同步接口")
|
||||||
format = "PNG"
|
format = "PNG"
|
||||||
if "image" in kwargs and kwargs["image"] is not None:
|
if "image" in kwargs and kwargs["image"] is not None:
|
||||||
image = kwargs["image"]
|
image = kwargs["image"]
|
||||||
files = {"img_file": (
|
files = {
|
||||||
'image.' + format.lower(), tensor_to_image_bytes(image, format),
|
"img_file": (
|
||||||
f'image/{format.lower()}')}
|
"image." + format.lower(),
|
||||||
|
tensor_to_image_bytes(image, format),
|
||||||
|
f"image/{format.lower()}",
|
||||||
|
)
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
files = None
|
files = None
|
||||||
|
|
||||||
if provider == "302ai":
|
|
||||||
interval = 3
|
interval = 3
|
||||||
logger.info("提交任务")
|
logger.info("提交任务")
|
||||||
logger.info(f"https://{endpoint}/api/custom/image/submit/task")
|
logger.info(f"https://{endpoint}/api/custom/image/submit/task")
|
||||||
job_resp = send_request("post", f"https://{endpoint}/api/custom/image/submit/task",
|
job_resp = send_request(
|
||||||
data={"model_name":"302/mj", "prompt": prompt, "mode": "turbo"},
|
"post",
|
||||||
|
f"https://{endpoint}/api/custom/image/submit/task",
|
||||||
|
data={"model_name": "302/mj", "prompt": prompt, "mode": "turbo"},
|
||||||
files=files,
|
files=files,
|
||||||
timeout=timeout)
|
timeout=timeout,
|
||||||
|
)
|
||||||
job_resp.raise_for_status()
|
job_resp.raise_for_status()
|
||||||
job_resp = job_resp.json()
|
job_resp = job_resp.json()
|
||||||
if not job_resp["status"]:
|
if not job_resp["status"]:
|
||||||
raise Exception("生成失败, 可能因为风控")
|
raise Exception("生成失败, 可能因为风控")
|
||||||
job_id = job_resp["data"]
|
job_id = job_resp["data"]
|
||||||
|
start_time = time()
|
||||||
for _ in range(0, timeout // interval, interval):
|
for _ in range(0, timeout // interval, interval):
|
||||||
logger.info("等待" + str(interval) + "秒")
|
logger.info(f"已等待 {time() - start_time} 秒,{interval} 秒后查询...")
|
||||||
sleep(interval)
|
sleep(interval)
|
||||||
logger.info("查询结果")
|
logger.info(f"查询结果 {job_id}")
|
||||||
resp = send_request("get", f"https://{endpoint}/api/custom/task/status?task_id={job_id}", timeout=30)
|
resp = send_request(
|
||||||
|
"get",
|
||||||
|
f"https://{endpoint}/api/custom/task/status?task_id={job_id}",
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
if resp.json()["status"] == "running":
|
if resp.json()["status"] == "running":
|
||||||
logger.info("任务正在运行")
|
logger.info(f"任务正在运行 {job_id}")
|
||||||
continue
|
continue
|
||||||
if resp.json()["status"] == "failed":
|
if resp.json()["status"] == "failed":
|
||||||
raise Exception(f"生成失败: {resp.json()['msg']}")
|
raise Exception(f"生成失败: {resp.json()['msg']}")
|
||||||
|
|
@ -269,6 +330,7 @@ class ModalMidJourneyGenerateImage:
|
||||||
logger.success("img_url: " + url)
|
logger.success("img_url: " + url)
|
||||||
result_list.append(url_to_tensor(url).squeeze(0))
|
result_list.append(url_to_tensor(url).squeeze(0))
|
||||||
result_list = torch.stack(result_list, dim=0)
|
result_list = torch.stack(result_list, dim=0)
|
||||||
|
logger.success(f"生成成功 {job_id}")
|
||||||
return (result_list,)
|
return (result_list,)
|
||||||
raise Exception("等待超时")
|
raise Exception("等待超时")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -281,7 +343,12 @@ class ModalMidJourneyDescribeImage:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
"endpoint": ("STRING", {"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"}),
|
"endpoint": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "bowongai-test--text-video-agent-fastapi-app.modal.run"
|
||||||
|
},
|
||||||
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -295,12 +362,19 @@ class ModalMidJourneyDescribeImage:
|
||||||
try:
|
try:
|
||||||
logger.info("请求同步接口")
|
logger.info("请求同步接口")
|
||||||
format = "PNG"
|
format = "PNG"
|
||||||
job_resp = send_request("post", f"https://{endpoint}/api/302/mj/sync/file/img/describe",
|
job_resp = send_request(
|
||||||
headers={'Authorization': 'Bearer bowong7777'},
|
"post",
|
||||||
files={"img_file": (
|
f"https://{endpoint}/api/302/mj/sync/file/img/describe",
|
||||||
'image.' + format.lower(), tensor_to_image_bytes(image, format),
|
headers={"Authorization": "Bearer bowong7777"},
|
||||||
f'image/{format.lower()}')},
|
files={
|
||||||
timeout=300)
|
"img_file": (
|
||||||
|
"image." + format.lower(),
|
||||||
|
tensor_to_image_bytes(image, format),
|
||||||
|
f"image/{format.lower()}",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
job_resp.raise_for_status()
|
job_resp.raise_for_status()
|
||||||
job_resp = job_resp.json()
|
job_resp = job_resp.json()
|
||||||
if not job_resp["status"]:
|
if not job_resp["status"]:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue