ADD 增加持久化到DB节点

This commit is contained in:
康宇佳 2025-02-27 14:28:18 +08:00
parent 4de4ad938b
commit 473ca10397
2 changed files with 110 additions and 17 deletions

View File

@ -3,20 +3,27 @@ import json
import os
import shutil
import traceback
import urllib.request
import uuid
from datetime import datetime
import server
import cv2
import ffmpy
import numpy as np
import torch
import yaml
from ultralytics import YOLO
from comfy import model_management
from qcloud_cos import CosConfig, CosClientError, CosServiceError
from qcloud_cos import CosS3Client
from sqlalchemy import Column, Integer, func, DateTime, ForeignKey, String, create_engine
from sqlalchemy.orm import sessionmaker
from ultralytics import YOLO
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
from .test_single_image import test_node
import ffmpy
video_extensions = ["webm", "mp4", "mkv", "gif", "mov"]
@ -147,10 +154,10 @@ class FaceExtract:
template.fill(20)
for a, a1 in zip(list(range(int(x1), int(x2))), list(range(face_size))):
for b, b1 in zip(
list(range(int(y1), int(y2))), list(range(face_size))
list(range(int(y1), int(y2))), list(range(face_size))
):
if (a >= 0 and a < r.orig_img.shape[0]) and (
b >= 0 and b < r.orig_img.shape[1]
b >= 0 and b < r.orig_img.shape[1]
):
template[a1][b1] = r.orig_img[a][b]
print(int(x1), int(x2), int(y1), int(y2))
@ -194,11 +201,11 @@ class COSDownload:
for i in range(0, 10):
try:
with open(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config.yaml"
),
encoding="utf-8",
mode="r+",
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config.yaml"
),
encoding="utf-8",
mode="r+",
) as f:
yaml_config = yaml.load(f, Loader=yaml.FullLoader)
config = CosConfig(
@ -253,11 +260,11 @@ class COSUpload:
for i in range(0, 10):
try:
with open(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config.yaml"
),
encoding="utf-8",
mode="r+",
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config.yaml"
),
encoding="utf-8",
mode="r+",
) as f:
yaml_config = yaml.load(f, Loader=yaml.FullLoader)
config = CosConfig(
@ -282,7 +289,20 @@ class COSUpload:
)
break
except CosClientError or CosServiceError as e:
print(e)
raise RuntimeError("上传失败")
data = {"prompt_id": "",
"video_url": "https://{}.cos.{}.myqcloud.com/{}".format(yaml_config['bucket'], yaml_config['region'],
'/'.join([yaml_config['subfolder'],
path.split('/')[
-1] if '/' in path else
path.split('\\')[-1], ]))
}
headers = {'Content-Type': 'application/json'}
try:
req = urllib.request.Request("", data=json.dumps(data).encode("utf-8"), headers=headers)
response = urllib.request.urlopen(req)
except:
raise RuntimeError("上报MQ状态失败")
return (
"/".join(
[
@ -293,6 +313,76 @@ class COSUpload:
)
class Task(Base):
__tablename__ = 'task'
id = Column(Integer, primary_key=True)
gmt_create = Column(DateTime(timezone=True), server_default=func.now())
gmt_modified = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
prompt_id = Column(String, index=True, nullable=False)
result = Column(String, nullable=True)
job_id = Column(String, index=True, nullable=False)
def __repr__(self):
return f"{self.id},{self.gmt_create},{self.gmt_modified},{self.prompt_id},{self.result}"
class LogToDB:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"job_id": ("STRING",{"forceInput": True}),
"log": ("STRING",{"forceInput": True}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "log2db"
OUTPUT_NODE = True
OUTPUT_IS_LIST = (True,)
# OUTPUT_NODE = False
CATEGORY = "不忘科技-自定义节点🚩"
def log2db(self, job_id, log, unique_id):
# 获取comfy服务器队列信息
(_, prompt_id, prompt, extra_data, outputs_to_execute) = next(
iter(server.PromptServer.instance.prompt_queue.currently_running.values()))
engine = create_engine(
"mysql+pymysql://root:*k3&5xxG6oqHJM@sh-cdb-1xspb808.sql.tencentcdb.com:28795/comfy",
echo=True
)
# Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
# 查询
tasks = session.query(Task).filter(Task.prompt_id == prompt_id).all()
print(prompt)
result = {
"curr_node_id": str(unique_id),
"last_node_id": list(prompt.keys())[-1],
"node_output": str(log)
}
if len(tasks) == 0:
# 不存在插入
task = Task(prompt_id=prompt_id, job_id=job_id, result=json.dumps(result))
session.add(task)
elif len(tasks) == 1:
# 存在更新
session.query(Task).filter(Task.prompt_id == prompt_id).update({"result": json.dumps(result)})
else:
# 异常报错
raise RuntimeError("状态数据库prompt_id不唯一, 无法记录状态!")
session.commit()
return {"ui": {"text": json.dumps(result)}, "result": (json.dumps(result),)}
class VideoCut:
"""FFMPEG视频剪辑 -- !有卡顿问题 暂废弃"""
@ -426,7 +516,6 @@ from tencentcloud.common import credential
from tencentcloud.vod.v20180717 import vod_client, models
import requests
from pathlib import Path
import tempfile
class VodToLocalNode:
@ -540,6 +629,7 @@ NODE_CLASS_MAPPINGS = {
"COSDownload": COSDownload,
"VideoCutCustom": VideoCut,
"VodToLocal": VodToLocalNode,
"LogToDB": LogToDB
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
@ -550,4 +640,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"COSDownload": "COS下载",
"VideoCutCustom": "视频剪裁",
"VodToLocal": "腾讯云VOD下载",
"LogToDB": "状态持久化DB"
}

View File

@ -7,4 +7,6 @@ opencv-python
ultralytics
cos-python-sdk-v5
tencentcloud-sdk-python
requests
requests
sqlalchemy
pymysql