ADD 增加持久化到DB节点

This commit is contained in:
康宇佳 2025-02-27 14:28:18 +08:00
parent 4de4ad938b
commit 473ca10397
2 changed files with 110 additions and 17 deletions

View File

@ -3,20 +3,27 @@ import json
import os import os
import shutil import shutil
import traceback import traceback
import urllib.request
import uuid import uuid
from datetime import datetime from datetime import datetime
import server
import cv2 import cv2
import ffmpy
import numpy as np import numpy as np
import torch import torch
import yaml import yaml
from ultralytics import YOLO
from comfy import model_management from comfy import model_management
from qcloud_cos import CosConfig, CosClientError, CosServiceError from qcloud_cos import CosConfig, CosClientError, CosServiceError
from qcloud_cos import CosS3Client from qcloud_cos import CosS3Client
from sqlalchemy import Column, Integer, func, DateTime, ForeignKey, String, create_engine
from sqlalchemy.orm import sessionmaker
from ultralytics import YOLO
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
from .test_single_image import test_node from .test_single_image import test_node
import ffmpy
video_extensions = ["webm", "mp4", "mkv", "gif", "mov"] video_extensions = ["webm", "mp4", "mkv", "gif", "mov"]
@ -147,10 +154,10 @@ class FaceExtract:
template.fill(20) template.fill(20)
for a, a1 in zip(list(range(int(x1), int(x2))), list(range(face_size))): for a, a1 in zip(list(range(int(x1), int(x2))), list(range(face_size))):
for b, b1 in zip( for b, b1 in zip(
list(range(int(y1), int(y2))), list(range(face_size)) list(range(int(y1), int(y2))), list(range(face_size))
): ):
if (a >= 0 and a < r.orig_img.shape[0]) and ( if (a >= 0 and a < r.orig_img.shape[0]) and (
b >= 0 and b < r.orig_img.shape[1] b >= 0 and b < r.orig_img.shape[1]
): ):
template[a1][b1] = r.orig_img[a][b] template[a1][b1] = r.orig_img[a][b]
print(int(x1), int(x2), int(y1), int(y2)) print(int(x1), int(x2), int(y1), int(y2))
@ -194,11 +201,11 @@ class COSDownload:
for i in range(0, 10): for i in range(0, 10):
try: try:
with open( with open(
os.path.join( os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config.yaml" os.path.dirname(os.path.abspath(__file__)), "config.yaml"
), ),
encoding="utf-8", encoding="utf-8",
mode="r+", mode="r+",
) as f: ) as f:
yaml_config = yaml.load(f, Loader=yaml.FullLoader) yaml_config = yaml.load(f, Loader=yaml.FullLoader)
config = CosConfig( config = CosConfig(
@ -253,11 +260,11 @@ class COSUpload:
for i in range(0, 10): for i in range(0, 10):
try: try:
with open( with open(
os.path.join( os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config.yaml" os.path.dirname(os.path.abspath(__file__)), "config.yaml"
), ),
encoding="utf-8", encoding="utf-8",
mode="r+", mode="r+",
) as f: ) as f:
yaml_config = yaml.load(f, Loader=yaml.FullLoader) yaml_config = yaml.load(f, Loader=yaml.FullLoader)
config = CosConfig( config = CosConfig(
@ -282,7 +289,20 @@ class COSUpload:
) )
break break
except CosClientError or CosServiceError as e: except CosClientError or CosServiceError as e:
print(e) raise RuntimeError("上传失败")
data = {"prompt_id": "",
"video_url": "https://{}.cos.{}.myqcloud.com/{}".format(yaml_config['bucket'], yaml_config['region'],
'/'.join([yaml_config['subfolder'],
path.split('/')[
-1] if '/' in path else
path.split('\\')[-1], ]))
}
headers = {'Content-Type': 'application/json'}
try:
req = urllib.request.Request("", data=json.dumps(data).encode("utf-8"), headers=headers)
response = urllib.request.urlopen(req)
except:
raise RuntimeError("上报MQ状态失败")
return ( return (
"/".join( "/".join(
[ [
@ -293,6 +313,76 @@ class COSUpload:
) )
class Task(Base):
__tablename__ = 'task'
id = Column(Integer, primary_key=True)
gmt_create = Column(DateTime(timezone=True), server_default=func.now())
gmt_modified = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
prompt_id = Column(String, index=True, nullable=False)
result = Column(String, nullable=True)
job_id = Column(String, index=True, nullable=False)
def __repr__(self):
return f"{self.id},{self.gmt_create},{self.gmt_modified},{self.prompt_id},{self.result}"
class LogToDB:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"job_id": ("STRING",{"forceInput": True}),
"log": ("STRING",{"forceInput": True}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "log2db"
OUTPUT_NODE = True
OUTPUT_IS_LIST = (True,)
# OUTPUT_NODE = False
CATEGORY = "不忘科技-自定义节点🚩"
def log2db(self, job_id, log, unique_id):
# 获取comfy服务器队列信息
(_, prompt_id, prompt, extra_data, outputs_to_execute) = next(
iter(server.PromptServer.instance.prompt_queue.currently_running.values()))
engine = create_engine(
"mysql+pymysql://root:*k3&5xxG6oqHJM@sh-cdb-1xspb808.sql.tencentcdb.com:28795/comfy",
echo=True
)
# Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
# 查询
tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
print(prompt)
result = {
"curr_node_id": str(unique_id),
"last_node_id": list(prompt.keys())[-1],
"node_output": str(log)
}
if len(tasks) == 0:
# 不存在插入
task = Task(prompt_id=prompt_id, job_id=job_id, result=json.dumps(result))
session.add(task)
elif len(tasks) == 1:
# 存在更新
session.query(Task).filter(Task.prompt_id == prompt_id).update({"result": json.dumps(result)})
else:
# 异常报错
raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
session.commit()
return {"ui": {"text": json.dumps(result)}, "result": (json.dumps(result),)}
class VideoCut: class VideoCut:
"""FFMPEG视频剪辑 -- !有卡顿问题 暂废弃""" """FFMPEG视频剪辑 -- !有卡顿问题 暂废弃"""
@ -426,7 +516,6 @@ from tencentcloud.common import credential
from tencentcloud.vod.v20180717 import vod_client, models from tencentcloud.vod.v20180717 import vod_client, models
import requests import requests
from pathlib import Path from pathlib import Path
import tempfile
class VodToLocalNode: class VodToLocalNode:
@ -540,6 +629,7 @@ NODE_CLASS_MAPPINGS = {
"COSDownload": COSDownload, "COSDownload": COSDownload,
"VideoCutCustom": VideoCut, "VideoCutCustom": VideoCut,
"VodToLocal": VodToLocalNode, "VodToLocal": VodToLocalNode,
"LogToDB": LogToDB
} }
# A dictionary that contains the friendly/humanly readable titles for the nodes # A dictionary that contains the friendly/humanly readable titles for the nodes
@ -550,4 +640,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"COSDownload": "COS下载", "COSDownload": "COS下载",
"VideoCutCustom": "视频剪裁", "VideoCutCustom": "视频剪裁",
"VodToLocal": "腾讯云VOD下载", "VodToLocal": "腾讯云VOD下载",
"LogToDB": "状态持久化DB"
} }

View File

@ -8,3 +8,5 @@ ultralytics
cos-python-sdk-v5 cos-python-sdk-v5
tencentcloud-sdk-python tencentcloud-sdk-python
requests requests
sqlalchemy
pymysql