ComfyUI-CustomNode/nodes/img_agent.py

213 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding:utf-8 -*-
"""
File img_agent.py
Author silence
Date 2025/9/6
"""
import json
import requests
import os
import folder_paths
import mimetypes
from PIL import Image
import numpy as np
import torch
import io
import re
# Use loguru when it is available; otherwise fall back to the standard
# library logger so that the module can still be imported.
try:
    from loguru import logger
except ImportError:
    import logging

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    logger = logging.getLogger("ImgSubmitNode_Final")
    # Tell the user once, on stdout, that the fallback logger is in use.
    print("提示: loguru 未安装使用内置logging。建议安装以获得更好的日志体验: pip install loguru")
def fetch_and_process_image_models():
    """
    Fetch the image-generation model list from the API and index it.

    The prod, dev and test endpoints are tried in that order; the first one
    that answers with a successful HTTP status is used. Any failure (network,
    malformed payload, ...) is logged and swallowed so that the caller always
    gets a usable structure back.

    Returns:
        dict with the keys:
            "configs": technical model name -> raw model entry from the API.
            "full_display_list": display names sorted by mode, then by name;
                a single error placeholder entry when nothing could be loaded.
            "display_to_tech_name": display name -> technical model name.
            "temp_list_for_sorting": the (sort_key, display_name) tuples the
                display list was built from.
    """
    image_model_urls = {
        "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
        "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image",
        "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run/api/custom/model/list?category=image"
    }
    # Ordering of the dropdown: text-to-image first, then image-to-image,
    # then models supporting both; unknown modes go last.
    mode_rank = {"t2i": 0, "i2i": 1, "both": 2}
    model_data = {
        "configs": {},
        "full_display_list": [],
        "display_to_tech_name": {},
        "temp_list_for_sorting": []
    }
    try:
        # Only a response that passed raise_for_status() is kept, so the
        # "nothing worked" check below does not depend on Response truthiness.
        response = None
        for env, url in image_model_urls.items():
            try:
                candidate = requests.get(url, timeout=10)
                candidate.raise_for_status()
            except requests.exceptions.RequestException:
                logger.warning(f"无法从 [{env}] 环境获取模型列表,尝试下一个...")
                continue
            response = candidate
            logger.info(f"成功从 [{env}] 环境获取生图模型列表。")
            break
        if response is None:
            raise ConnectionError("所有环境的模型列表API都无法访问。")

        data = response.json()
        if not data.get("status") or "data" not in data:
            raise ValueError(f"API响应格式错误: {data.get('msg', '未知错误')}")

        for model in data["data"]:
            # Skip malformed entries instead of failing the whole list.
            if not isinstance(model, dict):
                continue
            tech_name = model.get("model_name")
            if not tech_name:
                continue

            # A missing, null or blank description falls back to the
            # technical name so the label never reads "None (...)" or " (...)".
            raw_description = model.get("description")
            description = "" if raw_description is None else str(raw_description).strip()
            if not description:
                description = str(tech_name).strip()
            display_name = f"{description} ({tech_name})"

            mode = model.get("mode")
            sort_key = mode_rank.get(mode, 99) if isinstance(mode, str) else 99

            # Store everything gathered for this model.
            model_data["configs"][tech_name] = model
            model_data["display_to_tech_name"][display_name] = tech_name
            model_data["temp_list_for_sorting"].append((sort_key, display_name))

        model_data["temp_list_for_sorting"].sort(key=lambda x: (x[0], x[1]))
        model_data["full_display_list"] = [item[1] for item in model_data["temp_list_for_sorting"]]
    except Exception as e:
        # Top-level boundary: this runs at import time, so never propagate.
        logger.error(f"加载生图模型数据失败: {e}")

    if not model_data["full_display_list"]:
        # Keep the dropdown non-empty so the node can still be instantiated.
        model_data["full_display_list"] = ["错误:无法加载模型"]
    return model_data
# Loaded once when this module is imported; ImgSubmitNode exposes it as MODEL_DATA.
IMAGE_MODEL_DATA = fetch_and_process_image_models()
class ImgSubmitNode:
    """
    ComfyUI node that submits an image-generation task to the remote API.

    The node posts the prompt, the selected model and the aspect ratio (plus
    an optional reference image) to ``/api/custom/image/submit/task`` and
    returns a single STRING output: the ``data`` field of a response whose
    ``status`` is true, the full JSON response otherwise, or an error message
    starting with "错误: " when anything fails.
    """

    # Model metadata produced by fetch_and_process_image_models().
    MODEL_DATA = IMAGE_MODEL_DATA

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs; the model dropdown is filled from MODEL_DATA."""
        return {
            "required": {
                "model_name_display": (cls.MODEL_DATA["full_display_list"],),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "aspect_ratio": ("STRING", {"multiline": False, "default": "1:1"}),
                "environment": (["prod", "dev", "test"], {"default": "prod"}),
            },
            "optional": {
                "image": ("IMAGE",),
                "image_filename": ("STRING", {"multiline": False, "default": ""}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("data",)
    FUNCTION = "submit_task"
    CATEGORY = "不忘科技-自定义节点🚩/api/图片生成"

    @staticmethod
    def _tech_name_from_label(label):
        """Recover the technical model name from a label missing in the lookup table."""
        # Labels are built as "<description> (<tech_name>)". The description
        # may contain parentheses itself, so the trailing group is the one
        # that holds the technical name.
        _, separator, tail = label.rpartition(" (")
        if separator and tail.endswith(")"):
            return tail[:-1]
        # Fallback for labels that do not follow that layout: first
        # parenthesised group, or the raw label when there is none.
        match = re.search(r'\((.*?)\)', label)
        return match.group(1) if match else label

    def _get_base_url_and_tech_name(self, environment, model_name_display):
        """
        Resolve the API base URL and the technical model name.

        Args:
            environment: "prod", "dev" or "test"; any other value uses prod.
            model_name_display: Label selected in the model dropdown.

        Returns:
            Tuple ``(base_url, tech_name)``.
        """
        env_map = {
            "prod": "https://bowongai-prod--text-video-agent-fastapi-app.modal.run",
            "dev": "https://bowongai-dev--text-video-agent-fastapi-app.modal.run",
            "test": "https://bowongai-test--text-video-agent-fastapi-app.modal.run"
        }
        base_url = env_map.get(environment, env_map["prod"])
        tech_name = self.MODEL_DATA["display_to_tech_name"].get(model_name_display)
        if not tech_name:
            tech_name = self._tech_name_from_label(model_name_display)
        logger.info(f"环境: [{environment}], 模型: '{model_name_display}' -> '{tech_name}'")
        return base_url, tech_name

    def submit_task(self, model_name_display, prompt, aspect_ratio, environment, image_filename=None, image=None):
        """
        Submit one image-generation task and return the API result as a string.

        Args:
            model_name_display: Label selected in the model dropdown.
            prompt: Text prompt sent to the API.
            aspect_ratio: Requested aspect ratio; replaced by the model's first
                supported value when the model config lists supported ratios
                and this one is not among them.
            environment: "prod", "dev" or "test".
            image_filename: Optional file name looked up through
                ``folder_paths.get_full_path("input", ...)``; ignored when
                ``image`` is given.
            image: Optional IMAGE input; only the first item is uploaded, as PNG.

        Returns:
            One-element tuple holding the response ``data`` (as str), the full
            JSON response, or an error message. Exceptions are caught, logged
            and reported through the returned string.
        """
        # Must exist before the try block: the finally clause reads it even
        # when an error is raised before any file is opened.
        file_obj = None
        try:
            base_url, tech_name = self._get_base_url_and_tech_name(environment, model_name_display)
            model_config = self.MODEL_DATA["configs"].get(tech_name)
            if not model_config:
                raise ValueError(f"无法找到模型 '{tech_name}' 的配置。")

            def validate_and_correct_parameter(param_name, user_value, supported_values):
                """Return user_value if allowed, else the first supported value."""
                # No constraint published for this parameter: accept as is.
                if not supported_values:
                    return user_value
                if user_value in supported_values:
                    return user_value
                default_value = supported_values[0]
                logger.warning(
                    f"参数警告!模型 '{tech_name}' 不支持 '{param_name}': '{user_value}'"
                    f"已自动替换为支持的默认值: '{default_value}'。支持的选项: {supported_values}"
                )
                return default_value

            final_ar = validate_and_correct_parameter("宽高比", aspect_ratio, model_config.get("supported_ar", []))

            headers = {'accept': 'application/json'}
            payload = {'prompt': prompt, 'model_name': tech_name, 'aspect_ratio': final_ar,
                       'mode': 'turbo', 'webhook_flag': 'false'}
            files_to_send = {}

            if image is not None:
                # The tensor input takes priority over the file name.
                logger.info("检测到 IMAGE (Tensor) 输入,优先处理。")
                img_tensor = image[0]
                # Scale by 255 and clamp to 0..255 before converting to uint8.
                img_np = np.clip(255. * img_tensor.cpu().numpy(), 0, 255).astype(np.uint8)
                pil_image = Image.fromarray(img_np)
                buffer = io.BytesIO()
                pil_image.save(buffer, format="PNG")
                buffer.seek(0)
                files_to_send['img_file'] = ('image_from_workflow.png', buffer, 'image/png')
            elif image_filename and image_filename.strip():
                logger.info(f"处理文件名: {image_filename}")
                # NOTE(review): confirm that "input" is a folder name known to
                # folder_paths.get_full_path in the targeted ComfyUI version.
                full_path = folder_paths.get_full_path("input", image_filename.strip())
                if not (full_path and os.path.exists(full_path)):
                    return (f"错误: 在ComfyUI的input文件夹中未找到文件 '{image_filename}'",)
                filename = os.path.basename(full_path)
                # guess_type() always returns a 2-tuple, so the fallback has
                # to be applied to the type element, not to the tuple.
                guessed_type, _ = mimetypes.guess_type(full_path)
                mime_type = guessed_type or 'application/octet-stream'
                file_obj = open(full_path, 'rb')
                files_to_send['img_file'] = (filename, file_obj, mime_type)
            else:
                logger.info("未提供任何图像输入,以纯文本模式运行。")

            api_endpoint = f'{base_url}/api/custom/image/submit/task'
            logger.info(f"向端点 {api_endpoint} 发送请求...")
            response = requests.post(
                api_endpoint, headers=headers, data=payload, files=files_to_send, timeout=60
            )
            response.raise_for_status()
            response_json = response.json()
            logger.info(f"任务提交成功,完整响应: {json.dumps(response_json, indent=2, ensure_ascii=False)}")

            if response_json.get('status') is True:
                return (str(response_json.get('data', "错误: 状态为true但缺少data字段")),)
            # Non-success payloads are handed back verbatim for inspection.
            return (json.dumps(response_json, indent=4, ensure_ascii=False),)
        except Exception as e:
            # Node boundary: report the failure through the output string.
            logger.error(f"任务处理失败: {e}")
            return (f"错误: {str(e)}",)
        finally:
            if file_obj:
                file_obj.close()
# NODE_CLASS_MAPPINGS = {
# "ImgSubmitNode": ImgSubmitNode
# }
#
# NODE_DISPLAY_NAME_MAPPINGS = {
# "ImgSubmitNode": "统一生图任务节点"
# }