ComfyUI-CustomNode/nodes/img_agent.py

263 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
File img_agent.py
Author silence
Date 2025/9/6
"""
import json
import requests
import os
import folder_paths
import mimetypes
from PIL import Image
import numpy as np
import torch
import io
import re
try:
    # Preferred logger: loguru, when the package is installed.
    from loguru import logger
except ImportError:
    # Fallback: stdlib logging at INFO level with a timestamped format, a
    # logger named after the node, and a one-off install hint on stdout.
    import logging

    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    logger = logging.getLogger("ImgSubmitNode")
    print("提示: loguru 未安装使用内置logging。建议安装以获得更好的日志体验: pip install loguru")
def fetch_and_process_image_models():
    """
    Fetch the image-generation model list from the API at node-load time and
    keep each model's config for later parameter validation.

    The environments are tried in order (prod, dev, test); the first one
    whose request succeeds and passes ``raise_for_status`` is used.

    Returns:
        dict with keys:
            "configs": model_name -> raw model dict from the API.
            "full_display_list": dropdown entries "<description> (<model_name>)",
                sorted by mode (t2i, i2i, both, anything else) and then by
                name. Falls back to a single error entry when nothing loaded.
            "display_to_tech_name": dropdown entry -> model_name.
            "temp_list_for_sorting": the (sort_key, display_name) pairs used
                to build "full_display_list" (kept so the dict keeps its shape).

    Never raises: any failure is logged and whatever was collected so far is
    returned.
    """
    image_model_urls = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image"
    }
    model_data = {
        "configs": {},
        "full_display_list": [],
        "display_to_tech_name": {},
        "temp_list_for_sorting": []
    }
    # Dropdown ordering: text-to-image, image-to-image, both; unknown modes last.
    mode_sort_keys = {"t2i": 0, "i2i": 1, "both": 2}
    try:
        for env, url in image_model_urls.items():
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
            except requests.exceptions.RequestException:
                logger.warning(f"无法从 [{env}] 环境获取模型列表,尝试下一个...")
                continue
            logger.info(f"成功从 [{env}] 环境获取生图模型列表。")
            break
        else:
            # Reached only when no environment succeeded. Tracked explicitly
            # rather than via Response truthiness (False for 4xx/5xx).
            raise ConnectionError("所有环境的模型列表API都无法访问。")

        data = response.json()
        if not isinstance(data, dict):
            raise ValueError(f"API响应格式错误: {'未知错误'}")
        if not data.get("status") or "data" not in data:
            raise ValueError(f"API响应格式错误: {data.get('msg', '未知错误')}")

        for model in data["data"]:
            if not isinstance(model, dict):
                # One malformed entry must not abort loading all the others.
                logger.warning(f"跳过格式错误的模型条目: {model!r}")
                continue
            tech_name = model.get("model_name")
            if not tech_name:
                continue
            # `or` instead of a .get() default so an explicit null or blank
            # description also falls back to the tech name.
            description_from_api = str(model.get("description") or "").strip() or str(tech_name)
            display_name = f"{description_from_api} ({tech_name})"
            mode = model.get("mode")
            sort_key = mode_sort_keys.get(mode, 99) if isinstance(mode, str) else 99

            model_data["configs"][tech_name] = model
            model_data["display_to_tech_name"][display_name] = tech_name
            model_data["temp_list_for_sorting"].append((sort_key, display_name))

        # (sort_key, display_name) tuples: sort by mode rank, then by name.
        model_data["temp_list_for_sorting"].sort()
        model_data["full_display_list"] = [name for _, name in model_data["temp_list_for_sorting"]]
    except Exception as e:
        # Import-time boundary: a failure here must not stop the node module
        # from loading, so log it and fall through to the placeholder entry.
        logger.error(f"加载生图模型数据失败: {e}")
    if not model_data["full_display_list"]:
        model_data["full_display_list"] = ["错误:无法加载模型"]
    return model_data


IMAGE_MODEL_DATA = fetch_and_process_image_models()
class ImgSubmitNode:
    """
    ComfyUI node that submits an image-generation task to the
    text-video-agent API.

    The node returns a 1-tuple holding a string. On success it is the API's
    ``data`` field. When ``status`` is not true it is the pretty-printed JSON
    response. When anything fails it is an "错误: ..." message; errors are
    never raised.

    There are two submission paths:
      * Models whose tech name contains "doubao-seedream-4-0": the size and
        image count go into an ``extra`` JSON field, and reference images are
        passed as URLs (``img_list``). The IMAGE and file inputs are not used
        on this path.
      * All other models: the aspect ratio is checked against the model's
        ``supported_ar`` config, and one reference image is uploaded as
        ``img_file``. An IMAGE tensor takes priority; otherwise a file from
        the input folder is used.
    """

    # Model configs and dropdown entries, fetched once at module import.
    MODEL_DATA = IMAGE_MODEL_DATA

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's widgets and sockets."""
        # Size presets (taken from a user-provided screenshot). Only the doubao
        # path uses them; the text before the first space is sent as "size".
        size_options = [
            '2048x2048 (1:1)', '2394x1728 (4:3)', '1728x2394 (3:4)', '2560x1440 (16:9)',
            '1440x2560 (9:16)', '2496x1664 (3:2)', '1664x2496 (2:3)', '3024x1296 (21:9)',
            '1K', '2K', '4K'
        ]
        return {
            "required": {
                "model_name_display": (cls.MODEL_DATA["full_display_list"],),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "aspect_ratio": ("STRING", {"multiline": False, "default": "1:1"}),
                "environment": (["prod", "dev", "test"], {"default": "prod"}),
            },
            "optional": {
                "image": ("IMAGE",),
                "img_urls": ("STRING", {"multiline": False, "default": "", "description": '输入图片的链接'}),
                # STRING widget, so its default is a string too.
                "output_count": ("STRING", {"default": "1", "multiline": False}),
                "image_filename": ("STRING", {"multiline": False, "default": ""}),
                "image_urls": (
                    "STRING", {"multiline": True, "default": "", "placeholder": "单个或多个图片URL用英文逗号隔开..."}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 9, "step": 1}),
                "image_size": (size_options, {"default": "2K"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("data",)
    FUNCTION = "submit_task"
    CATEGORY = "不忘科技-自定义节点🚩/api/图片生成"

    def _get_base_url_and_tech_name(self, environment, model_name_display):
        """
        Resolve the API base URL and the technical model name.

        Args:
            environment: "prod", "dev" or "test"; any other value uses prod.
            model_name_display: dropdown entry, normally "<description> (<tech_name>)".

        Returns:
            (base_url, tech_name). If the entry is not in the loaded mapping
            (e.g. a workflow saved against an older model list), the tech name
            is taken from the trailing "(...)" group. If there is no such
            group, the entry itself is used unchanged.
        """
        env_map = {
            "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
            "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
            "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run"
        }
        base_url = env_map.get(environment, env_map["prod"])
        tech_name = self.MODEL_DATA["display_to_tech_name"].get(model_name_display)
        if not tech_name:
            # Anchor on the LAST parenthesised group: the description part may
            # contain parentheses itself, e.g. "豆包 (新版) (doubao-x)".
            match = re.search(r'\(([^()]*)\)\s*$', model_name_display)
            tech_name = match.group(1) if match else model_name_display
        logger.info(f"环境: [{environment}], 模型: '{model_name_display}' -> '{tech_name}'")
        return base_url, tech_name

    @staticmethod
    def _normalize_url_list(raw_urls):
        """
        Turn a string of URLs separated by commas and/or newlines into the
        comma-joined form sent as ``img_list``.

        Entries are stripped and blank entries are dropped. Returns "" when
        no URL is present.
        """
        if not raw_urls:
            return ""
        parts = (part.strip() for part in re.split(r'[,\r\n]+', str(raw_urls)))
        return ",".join(part for part in parts if part)

    @staticmethod
    def _validate_and_correct_parameter(tech_name, param_name, user_value, supported_values):
        """
        Return ``user_value`` if the model supports it, or if the model
        declares no restriction (empty ``supported_values``). Otherwise log a
        warning and return the first supported value.
        """
        if not supported_values:
            return user_value
        if user_value in supported_values:
            return user_value
        default_value = supported_values[0]
        logger.warning(
            f"参数警告!模型 '{tech_name}' 不支持 '{param_name}': '{user_value}'"
            f"已自动替换为支持的默认值: '{default_value}'。支持的选项: {supported_values}"
        )
        return default_value

    @staticmethod
    def _tensor_to_png_buffer(image):
        """
        Encode ``image[0]`` (the first item of the IMAGE batch) as a PNG in
        an in-memory buffer that is rewound to the start.

        Assumes float pixel values in [0, 1] (they are scaled by 255 and
        clipped). NOTE(review): the remaining batch items are not sent.
        """
        img_np = np.clip(255. * image[0].cpu().numpy(), 0, 255).astype(np.uint8)
        buffer = io.BytesIO()
        Image.fromarray(img_np).save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    @staticmethod
    def _resolve_input_file(image_filename):
        """
        Return the path of ``image_filename`` for upload, or None when it
        does not resolve to an existing regular file.

        ``folder_paths.get_full_path("input", ...)`` is tried first, for
        compatibility with the original lookup. ``folder_paths.get_annotated_filepath``
        is the fallback. NOTE(review): "input" is usually not a registered
        folder name in ComfyUI, and get_annotated_filepath resolves plain
        names against the input directory. Confirm this on the deployed
        ComfyUI version.
        """
        name = image_filename.strip()
        full_path = folder_paths.get_full_path("input", name) or folder_paths.get_annotated_filepath(name)
        return full_path if full_path and os.path.isfile(full_path) else None

    def submit_task(self, model_name_display, prompt, aspect_ratio, environment,
                    image_filename=None, image=None, image_urls="", num_images=1, image_size="2K",
                    img_urls="", output_count=None):
        """
        Submit the generation task and return the API result as a 1-tuple.

        Args:
            model_name_display: dropdown entry selecting the model.
            prompt: text prompt.
            aspect_ratio: requested ratio. On the standard path it is
                replaced by the model's first ``supported_ar`` value when
                unsupported.
            environment: "prod" / "dev" / "test" backend.
            image_filename: file to upload (standard path; used only when
                ``image`` is None).
            image: ComfyUI IMAGE batch; its first image is uploaded
                (standard path).
            image_urls: reference URLs separated by commas or newlines
                (doubao path).
            num_images: sent as ``max_images`` (doubao path).
            image_size: size preset (doubao path).
            img_urls: accepted because INPUT_TYPES declares it, and ComfyUI
                passes widget values as keyword arguments. Without it every
                call would fail with a TypeError. Used as the reference URLs
                when ``image_urls`` is empty.
            output_count: declared in INPUT_TYPES for the same reason.
                NOTE(review): neither submission path uses it.

        Returns:
            (str,) containing the API ``data`` field, the full JSON response
            when ``status`` is not true, or an "错误: ..." message.
        """
        file_obj = None
        try:
            base_url, tech_name = self._get_base_url_and_tech_name(environment, model_name_display)
            api_endpoint = f'{base_url}/api/custom/image/submit/task'
            headers = {'accept': 'application/json'}

            if "doubao-seedream-4-0" in tech_name:
                logger.info(f"检测到豆包模型 '{tech_name}',启用特定提交逻辑。")
                # Keep the part before the first space: '2048x2048 (1:1)' -> '2048x2048'.
                actual_size = image_size.split(' ')[0]
                extra_params = {
                    "sequential_image_generation": "auto",
                    "size": actual_size,
                    "max_images": num_images
                }
                payload = {
                    'prompt': prompt,
                    'model_name': tech_name,
                    'aspect_ratio': aspect_ratio,
                    'mode': 'turbo',
                    'webhook_flag': 'false',
                    'watermark': 'false',
                    'extra': json.dumps(extra_params)
                }
                url_list = self._normalize_url_list(image_urls) or self._normalize_url_list(img_urls)
                if url_list:
                    logger.info(f"使用提供的图片URL: {url_list}")
                    payload['img_list'] = url_list
                else:
                    logger.info("未提供图片URL将以文生图模式运行。")
                logger.info(f"向豆包模型端点 {api_endpoint} 发送请求...")
                response = requests.post(
                    api_endpoint, headers=headers, data=payload, timeout=60
                )
            else:
                # Standard path, used for every other model.
                logger.info(f"使用标准逻辑提交任务到模型 '{tech_name}'")
                model_config = self.MODEL_DATA["configs"].get(tech_name)
                if not model_config:
                    raise ValueError(f"无法找到模型 '{tech_name}' 的配置。")
                final_ar = self._validate_and_correct_parameter(
                    tech_name, "宽高比", aspect_ratio, model_config.get("supported_ar", [])
                )
                payload = {'prompt': prompt, 'model_name': tech_name, 'aspect_ratio': final_ar,
                           'mode': 'turbo', 'webhook_flag': 'false'}
                files_to_send = {}
                if image is not None:
                    logger.info("检测到 IMAGE (Tensor) 输入,优先处理。")
                    buffer = self._tensor_to_png_buffer(image)
                    files_to_send['img_file'] = ('image_from_workflow.png', buffer, 'image/png')
                elif image_filename and image_filename.strip():
                    logger.info(f"处理文件名: {image_filename}")
                    full_path = self._resolve_input_file(image_filename)
                    if not full_path:
                        return (f"错误: 在ComfyUI的input文件夹中未找到文件 '{image_filename}'",)
                    filename = os.path.basename(full_path)
                    # guess_type returns (None, None), a truthy tuple, for unknown
                    # extensions, so the fallback must apply to the type itself.
                    mime_type = mimetypes.guess_type(full_path)[0] or 'application/octet-stream'
                    file_obj = open(full_path, 'rb')  # closed in `finally`
                    files_to_send['img_file'] = (filename, file_obj, mime_type)
                else:
                    logger.info("未提供任何图像输入,以纯文本模式运行。")
                logger.info(f"向标准端点 {api_endpoint} 发送请求...")
                response = requests.post(
                    api_endpoint, headers=headers, data=payload, files=files_to_send, timeout=60
                )

            response.raise_for_status()
            response_json = response.json()
            logger.info(f"任务提交成功,完整响应: {json.dumps(response_json, indent=2, ensure_ascii=False)}")
            if response_json.get('status') is True:
                return (str(response_json.get('data', "错误: 状态为true但缺少data字段")),)
            return (json.dumps(response_json, indent=4, ensure_ascii=False),)
        except Exception as e:
            # Node boundary: report failures through the output string
            # instead of raising.
            logger.error(f"任务处理失败: {e}")
            return (f"错误: {str(e)}",)
        finally:
            if file_obj:
                file_obj.close()
# Node registration: node id -> implementing class, and node id -> the
# label shown for it.
NODE_CLASS_MAPPINGS = dict(ImgSubmitNode=ImgSubmitNode)
NODE_DISPLAY_NAME_MAPPINGS = dict(ImgSubmitNode="提交图片生成")