263 lines
11 KiB
Python
263 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
File img_agent.py
|
||
Author silence
|
||
Date 2025/9/6
|
||
"""
|
||
|
||
import json
|
||
import requests
|
||
import os
|
||
import folder_paths
|
||
import mimetypes
|
||
from PIL import Image
|
||
import numpy as np
|
||
import torch
|
||
import io
|
||
import re
|
||
|
||
try:
|
||
from loguru import logger
|
||
except ImportError:
|
||
import logging
|
||
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger("ImgSubmitNode")
|
||
print("提示: loguru 未安装,使用内置logging。建议安装以获得更好的日志体验: pip install loguru")
|
||
|
||
|
||
def _fetch_image_model_payload(urls, timeout):
    """Return the decoded JSON body from the first endpoint in *urls* that answers.

    Args:
        urls: mapping of environment name -> model-list URL, tried in insertion order.
        timeout: per-request timeout in seconds, passed to ``requests.get``.

    Raises:
        ConnectionError: if no endpoint returned a successful, JSON-decodable response.
    """
    for env, url in urls.items():
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            payload = response.json()
        except (requests.exceptions.RequestException, ValueError):
            # Network/HTTP failure or a non-JSON body: try the next environment.
            logger.warning(f"无法从 [{env}] 环境获取模型列表,尝试下一个...")
            continue
        logger.info(f"成功从 [{env}] 环境获取生图模型列表。")
        return payload
    raise ConnectionError("所有环境的模型列表API都无法访问。")


def _collect_image_models(models):
    """Build the model lookup tables from the raw model entries of the API.

    Args:
        models: iterable of model descriptors. Entries are expected to be dicts
            with a ``model_name``; ``description`` and ``mode`` are optional.
            Entries that are not dicts or lack a ``model_name`` are skipped.

    Returns:
        dict with keys:
          - ``configs``: model name -> raw descriptor dict,
          - ``display_to_tech_name``: dropdown label -> model name,
          - ``temp_list_for_sorting``: sorted ``(sort_key, label)`` pairs,
          - ``full_display_list``: dropdown labels in sorted order (may be empty).
    """
    # Text-to-image first, then image-to-image, then dual-mode; anything else last.
    mode_order = {"t2i": 0, "i2i": 1, "both": 2}
    model_data = {
        "configs": {},
        "full_display_list": [],
        "display_to_tech_name": {},
        "temp_list_for_sorting": [],
    }

    for model in models:
        if not isinstance(model, dict):
            continue
        tech_name = model.get("model_name")
        if not tech_name:
            continue
        tech_name = str(tech_name)

        # A null/blank description would otherwise render as "None (x)" or " (x)".
        description = str(model.get("description") or tech_name).strip() or tech_name
        display_name = f"{description} ({tech_name})"

        mode = model.get("mode")
        # Guard the dict lookup: a non-string (possibly unhashable) mode sorts last.
        sort_key = mode_order.get(mode, 99) if isinstance(mode, str) else 99

        model_data["configs"][tech_name] = model
        model_data["display_to_tech_name"][display_name] = tech_name
        model_data["temp_list_for_sorting"].append((sort_key, display_name))

    # (sort_key, label) tuples order by mode rank first, then alphabetically by label.
    model_data["temp_list_for_sorting"].sort()
    model_data["full_display_list"] = [label for _, label in model_data["temp_list_for_sorting"]]
    return model_data


def fetch_and_process_image_models(timeout=10):
    """Fetch the image-generation model list when the node module loads.

    The prod, dev and test endpoints are tried in that order, and the first one
    that answers is used. Each model's config is stored so that ``submit_task``
    can validate parameters against it later.

    Args:
        timeout: per-request timeout in seconds (default 10).

    Returns:
        The table dict built by ``_collect_image_models``. When loading fails or
        yields no models, ``full_display_list`` holds a single error placeholder
        so the dropdown is never empty.
    """
    image_model_urls = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
    }

    model_data = {
        "configs": {},
        "full_display_list": [],
        "display_to_tech_name": {},
        "temp_list_for_sorting": [],
    }

    try:
        data = _fetch_image_model_payload(image_model_urls, timeout)
        if not isinstance(data, dict):
            raise ValueError("API响应格式错误: 未知错误")
        if not data.get("status") or not isinstance(data.get("data"), list):
            raise ValueError(f"API响应格式错误: {data.get('msg', '未知错误')}")
        model_data = _collect_image_models(data["data"])
    except Exception as e:
        # Best effort: this runs at module import, so never let an API problem raise;
        # fall back to the placeholder entry below instead.
        logger.error(f"加载生图模型数据失败: {e}")

    if not model_data["full_display_list"]:
        model_data["full_display_list"] = ["错误:无法加载模型"]

    return model_data
|
||
|
||
|
||
# Fetched once when this module is imported (this performs HTTP requests at import
# time); ImgSubmitNode.MODEL_DATA refers to this same dict.
IMAGE_MODEL_DATA = fetch_and_process_image_models()
|
||
|
||
|
||
class ImgSubmitNode:
    """Node that submits an image-generation task to the text-video-agent API.

    The task is posted as a form to ``<base_url>/api/custom/image/submit/task``.
    Models whose API name contains ``doubao-seedream-4-0`` use a dedicated payload:
    size and image count go in ``extra``, and reference images are sent as URLs.
    Every other model has its aspect ratio validated against the fetched config
    and may receive one uploaded image. The node returns a 1-tuple holding a
    string: the ``data`` field of a successful response, the full JSON response
    otherwise, or an error message.
    """

    # Model tables loaded once at import time; INPUT_TYPES builds the dropdown from them.
    MODEL_DATA = IMAGE_MODEL_DATA

    # Backend base URL per selectable environment; unknown values fall back to "prod".
    ENV_BASE_URLS = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run",
    }

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required widgets and optional inputs."""
        # Size presets (taken from a user-provided screenshot); only the doubao path sends them.
        size_options = [
            '2048x2048 (1:1)', '2394x1728 (4:3)', '1728x2394 (3:4)', '2560x1440 (16:9)',
            '1440x2560 (9:16)', '2496x1664 (3:2)', '1664x2496 (2:3)', '3024x1296 (21:9)',
            '1K', '2K', '4K'
        ]
        return {
            "required": {
                "model_name_display": (cls.MODEL_DATA["full_display_list"],),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "aspect_ratio": ("STRING", {"multiline": False, "default": "1:1"}),
                "environment": (["prod", "dev", "test"], {"default": "prod"}),
            },
            "optional": {
                "image": ("IMAGE",),
                "img_urls": ("STRING", {"multiline": False, "default": "", "description": '输入图片的链接'}),
                "output_count": ("STRING", {"default": 1, "multiline": False}),
                "image_filename": ("STRING", {"multiline": False, "default": ""}),
                "image_urls": (
                    "STRING", {"multiline": True, "default": "", "placeholder": "单个或多个图片URL,用英文逗号隔开..."}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 9, "step": 1}),
                "image_size": (size_options, {"default": "2K"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("data",)
    FUNCTION = "submit_task"
    CATEGORY = "不忘科技-自定义节点🚩/api/图片生成"

    def _get_base_url_and_tech_name(self, environment, model_name_display):
        """Resolve the backend base URL and the API model name for a dropdown label.

        Labels are built as ``"<description> (<model_name>)"``. A label missing
        from the lookup table (for example one saved in an older workflow) falls
        back to its last parenthesized group, or to the label itself.

        Returns:
            ``(base_url, tech_name)``.
        """
        base_url = self.ENV_BASE_URLS.get(environment, self.ENV_BASE_URLS["prod"])
        tech_name = self.MODEL_DATA["display_to_tech_name"].get(model_name_display)
        if not tech_name:
            # The model name is appended last, so prefer the trailing group in case
            # the description itself contains parentheses.
            groups = re.findall(r'\(([^()]*)\)', model_name_display)
            tech_name = groups[-1] if groups else model_name_display
        logger.info(f"环境: [{environment}], 模型: '{model_name_display}' -> '{tech_name}'")
        return base_url, tech_name

    @staticmethod
    def _split_urls(text):
        """Split a comma/newline separated URL string into stripped, non-empty URLs."""
        if not text:
            return []
        return [part.strip() for part in re.split(r'[,\r\n]+', str(text)) if part.strip()]

    @staticmethod
    def _tensor_to_png_buffer(image):
        """Encode the first image of an IMAGE batch as an in-memory PNG (rewound to 0).

        The code indexes ``image[0]`` and scales by 255. It presumably expects a
        (batch, height, width, channels) tensor with floats in [0, 1]; TODO confirm
        for non-standard producers.
        """
        img_np = np.clip(255. * image[0].cpu().numpy(), 0, 255).astype(np.uint8)
        if img_np.ndim == 3 and img_np.shape[-1] == 1:
            # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
            img_np = img_np[..., 0]
        buffer = io.BytesIO()
        Image.fromarray(img_np).save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    @staticmethod
    def _validate_and_correct_parameter(tech_name, param_name, user_value, supported_values):
        """Return *user_value* if the model supports it, else the first supported value.

        An empty or missing *supported_values* means "no restriction". Any
        replacement is logged as a warning.
        """
        if not supported_values or user_value in supported_values:
            return user_value
        default_value = supported_values[0]
        logger.warning(
            f"参数警告!模型 '{tech_name}' 不支持 '{param_name}': '{user_value}'。"
            f"已自动替换为支持的默认值: '{default_value}'。支持的选项: {supported_values}"
        )
        return default_value

    def submit_task(self, model_name_display, prompt, aspect_ratio, environment,
                    image_filename=None, image=None, image_urls="", num_images=1, image_size="2K",
                    img_urls="", output_count=None):
        """Build and send the submit-task request and return ``(result_string,)``.

        Args:
            model_name_display: dropdown label, resolved by ``_get_base_url_and_tech_name``.
            prompt: generation prompt.
            aspect_ratio: requested ratio. For non-doubao models it is replaced by the
                model's first supported ratio when the model does not support it.
            environment: "prod", "dev" or "test".
            image_filename: file in ComfyUI's input folder to upload. Used by
                non-doubao models only, and ignored when ``image`` is given.
            image: IMAGE tensor; the first image of the batch is uploaded as PNG
                (non-doubao models only).
            image_urls: comma- or newline-separated reference image URLs (doubao only).
            num_images: sent as ``max_images`` (doubao only).
            image_size: size preset; the text before the first space is sent (doubao only).
            img_urls: extra reference URLs, appended after ``image_urls`` (doubao only).
                It is declared as an optional input, so it must be accepted here.
            output_count: declared optional input; accepted for compatibility but unused.

        Exceptions are caught and logged, and returned as an ``"错误: ..."`` string.
        """
        file_obj = None
        try:
            base_url, tech_name = self._get_base_url_and_tech_name(environment, model_name_display)
            api_endpoint = f'{base_url}/api/custom/image/submit/task'
            headers = {'accept': 'application/json'}

            if "doubao-seedream-4-0" in tech_name:
                logger.info(f"检测到豆包模型 '{tech_name}',启用特定提交逻辑。")
                if image is not None or (image_filename and image_filename.strip()):
                    logger.warning("豆包模型仅通过图片URL接收参考图,已忽略 image / image_filename 输入。")

                # '2048x2048 (1:1)' -> '2048x2048'; presets such as '2K' pass through unchanged.
                actual_size = image_size.split(' ')[0]
                extra_params = {
                    "sequential_image_generation": "auto",
                    "size": actual_size,
                    "max_images": num_images,
                }
                payload = {
                    'prompt': prompt,
                    'model_name': tech_name,
                    'aspect_ratio': aspect_ratio,
                    'mode': 'turbo',
                    'webhook_flag': 'false',
                    'watermark': 'false',
                    'extra': json.dumps(extra_params),
                }

                url_list = self._split_urls(image_urls) + self._split_urls(img_urls)
                if url_list:
                    img_list = ','.join(url_list)
                    logger.info(f"使用提供的图片URL: {img_list}")
                    payload['img_list'] = img_list
                else:
                    logger.info("未提供图片URL,将以文生图模式运行。")

                logger.info(f"向豆包模型端点 {api_endpoint} 发送请求...")
                response = requests.post(
                    api_endpoint, headers=headers, data=payload, timeout=60
                )
            else:
                logger.info(f"使用标准逻辑提交任务到模型 '{tech_name}'。")
                model_config = self.MODEL_DATA["configs"].get(tech_name)
                if not model_config:
                    raise ValueError(f"无法找到模型 '{tech_name}' 的配置。")

                final_ar = self._validate_and_correct_parameter(
                    tech_name, "宽高比", aspect_ratio, model_config.get("supported_ar", [])
                )
                payload = {'prompt': prompt, 'model_name': tech_name, 'aspect_ratio': final_ar,
                           'mode': 'turbo', 'webhook_flag': 'false'}
                files_to_send = {}

                if image is not None:
                    logger.info("检测到 IMAGE (Tensor) 输入,优先处理。")
                    buffer = self._tensor_to_png_buffer(image)
                    files_to_send['img_file'] = ('image_from_workflow.png', buffer, 'image/png')
                elif image_filename and image_filename.strip():
                    logger.info(f"处理文件名: {image_filename}")
                    full_path = folder_paths.get_full_path("input", image_filename.strip())
                    if not (full_path and os.path.exists(full_path)):
                        return (f"错误: 在ComfyUI的input文件夹中未找到文件 '{image_filename}'",)
                    filename = os.path.basename(full_path)
                    # guess_type() returns (None, None) for unknown extensions, so fall
                    # back on the type element itself rather than on the tuple.
                    mime_type = mimetypes.guess_type(full_path)[0] or 'application/octet-stream'
                    file_obj = open(full_path, 'rb')
                    files_to_send['img_file'] = (filename, file_obj, mime_type)
                else:
                    logger.info("未提供任何图像输入,以纯文本模式运行。")

                logger.info(f"向标准端点 {api_endpoint} 发送请求...")
                response = requests.post(
                    api_endpoint, headers=headers, data=payload, files=files_to_send, timeout=60
                )

            response.raise_for_status()
            response_json = response.json()
            logger.info(f"任务提交成功,完整响应: {json.dumps(response_json, indent=2, ensure_ascii=False)}")

            if response_json.get('status') is True:
                return (str(response_json.get('data', "错误: 状态为true但缺少data字段")),)
            return (json.dumps(response_json, indent=4, ensure_ascii=False),)

        except Exception as e:
            logger.error(f"任务处理失败: {e}")
            return (f"错误: {str(e)}",)
        finally:
            # The uploaded file handle must be released whichever path was taken.
            if file_obj:
                file_obj.close()
|
||
|
||
|
||
# Registration tables for the node host (presumably ComfyUI's custom-node loader):
# node type id -> implementing class.
NODE_CLASS_MAPPINGS = dict(ImgSubmitNode=ImgSubmitNode)

# Node type id -> human-readable title shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(ImgSubmitNode="提交图片生成")