# ComfyUI nodes and helper utilities for the Jimeng (即梦) AI image-to-video service.
import io
import json
import os
import subprocess
import tempfile
import time
import uuid
from time import sleep
from typing import Any, Dict, Optional, Tuple

import cv2
import folder_paths
import numpy as np
import requests
import torch
from loguru import logger
from PIL import Image
from torchvision.transforms import transforms
from tqdm import tqdm

from ..utils.config_utils import config
from ..utils.object_storage import UploadResult, get_provider


class JMUtils:
    """Client helpers for the Jimeng (即梦) AI video-generation service.

    Bundles everything the nodes in this module need:

    - uploading an input image to object storage,
    - submitting a generation task and querying its status,
    - downloading the generated video,
    - converting between tensors, images and video frames.

    Uploads go through the storage provider returned by ``get_provider``.
    """

    # Endpoint and model used for every generation task.
    _TASKS_URL = "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"
    _MODEL = "doubao-seedance-1-0-pro-250528"

    # Chunk size for streamed downloads.
    _DOWNLOAD_CHUNK_SIZE = 64 * 1024

    # How many frames before the reported last frame are tried when the
    # reported last frame cannot be decoded.
    _LAST_FRAME_LOOKBACK = 10

    def __init__(self):
        """Load the API key and storage settings from the project config.

        Raises:
            ValueError: If the API key or the COS bucket name is not configured.
        """
        try:
            self.api_key = config.get_config("jm_api_key")
            if not self.api_key:
                raise ValueError("即梦API密钥未配置")

            # Bucket for uploaded input material; falls back to the
            # ``cos_sucai_bucket_name`` setting when the COS config has none.
            cos_config = config.get_cos_config()
            self.cos_bucket_name = cos_config.get("bucket_name") or config.get_config(
                "cos_sucai_bucket_name"
            )
            if not self.cos_bucket_name:
                raise ValueError("COS素材存储桶未配置")

            self.storage_provider = get_provider("cos")

            logger.info(f"即梦工具初始化成功,使用存储桶: {self.cos_bucket_name}")
        except Exception as e:
            logger.error(f"即梦工具初始化失败: {e}")
            raise

    def _auth_headers(self) -> Dict[str, str]:
        """Return the JSON + bearer-token headers sent with every API call."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _error_message(payload: Any, default: str = "未知错误") -> str:
        """Extract ``payload["error"]["message"]`` without assuming its shape."""
        error = payload.get("error") if isinstance(payload, dict) else None
        if isinstance(error, dict):
            return str(error.get("message") or default)
        if error:
            return str(error)
        return default

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete ``path``; a file that cannot be removed is not treated as fatal."""
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup of a temporary file.
            pass

    def submit_task(
        self, prompt: str, img_url: str, duration: str = "10", resolution: str = "720p"
    ) -> Dict[str, Any]:
        """Submit an image-to-video generation task.

        Args:
            prompt: Generation prompt.
            img_url: URL of the input image.
            duration: Video duration in seconds.
            resolution: Video resolution.

        Returns:
            Dict with keys:
                - status: True when the task was accepted.
                - data: The task id on success; the input image URL when the
                  API answered without a task id; None on validation or
                  request errors.
                - msg: Human-readable message.
        """
        try:
            if not prompt or not prompt.strip():
                return {"status": False, "data": None, "msg": "提示词不能为空"}
            if not img_url or not img_url.strip():
                return {"status": False, "data": None, "msg": "图像URL不能为空"}

            # Generation options are passed as ``--flag value`` suffixes of the
            # text prompt.
            payload = {
                "model": self._MODEL,
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt.strip()} --resolution {resolution} --dur {duration} --camerafixed false",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_url.strip(),
                        },
                    },
                ],
            }

            logger.info(
                f"即梦任务提交中: prompt='{prompt[:50]}...', resolution={resolution}, duration={duration}"
            )

            response = requests.post(
                self._TASKS_URL,
                headers=self._auth_headers(),
                json=payload,
                timeout=30,
            )
            response.raise_for_status()

            resp_json = response.json()
            logger.info(
                f"即梦任务提交响应: {json.dumps(resp_json, ensure_ascii=False)}"
            )

            job_id = resp_json.get("id") if isinstance(resp_json, dict) else None
            if not job_id:
                return {
                    "status": False,
                    "data": img_url,
                    "msg": f"任务提交失败: {self._error_message(resp_json)}",
                }

            logger.info(f"即梦任务提交成功,任务ID: {job_id}")
            return {"data": job_id, "status": True, "msg": "任务提交成功"}

        except requests.RequestException as e:
            logger.error(f"即梦API请求失败: {e}")
            return {"data": None, "status": False, "msg": f"网络请求失败: {str(e)}"}
        except Exception as e:
            logger.error(f"即梦任务提交异常: {e}")
            return {"data": None, "status": False, "msg": str(e)}

    def query_status(self, job_id: str) -> Dict[str, Any]:
        """Query the state of a generation task.

        Args:
            job_id: Task id returned by ``submit_task``.

        Returns:
            Dict with keys:
                - status: True only when the task succeeded and a video URL
                  is available.
                - data: The video URL when ``status`` is True, else None.
                - msg: Status or error message.
        """
        resp_dict = {"status": False, "data": None, "msg": ""}

        try:
            if not job_id or not job_id.strip():
                resp_dict["msg"] = "任务ID不能为空"
                return resp_dict

            response = requests.get(
                f"{self._TASKS_URL}/{job_id.strip()}",
                headers=self._auth_headers(),
                timeout=15,
            )
            response.raise_for_status()

            resp_json = response.json()
            if not isinstance(resp_json, dict):
                resp_json = {}
            task_status = resp_json.get("status", "unknown")

            if task_status == "succeeded":
                content = resp_json.get("content")
                video_url = content.get("video_url") if isinstance(content, dict) else None
                if video_url:
                    resp_dict["status"] = True
                    resp_dict["data"] = video_url
                    resp_dict["msg"] = "任务完成"
                else:
                    # Never report success without a usable URL: callers treat
                    # a truthy status as "data holds the video".
                    resp_dict["msg"] = "任务完成但未找到视频URL"

            elif task_status in ["failed", "error"]:
                error_msg = "任务失败"
                if resp_json.get("error") is not None:
                    error_msg += f": {self._error_message(resp_json)}"
                resp_dict["msg"] = error_msg

            elif task_status in ["pending", "running", "processing"]:
                resp_dict["msg"] = f"任务进行中: {task_status}"

            else:
                resp_dict["msg"] = f"未知任务状态: {task_status}"

        except requests.RequestException as e:
            logger.error(f"即梦状态查询网络错误: {e}")
            resp_dict["msg"] = f"网络请求失败: {str(e)}"
        except Exception as e:
            logger.error(f"即梦状态查询异常: {e}")
            resp_dict["msg"] = str(e)

        return resp_dict

    def upload_io_to_cos(
        self, file: io.IOBase, mime_type: str = "image/png"
    ) -> Dict[str, Any]:
        """Upload the contents of a file-like object to COS.

        Args:
            file: Readable binary stream; it is read from its current position
                and rewound to the start afterwards when it is seekable.
            mime_type: MIME type; its two halves become the key's folder and
                file extension.

        Returns:
            Dict with keys:
                - status: True on success.
                - data: URL of the uploaded object.
                - msg: Result message.
        """
        resp_data = {"status": True, "data": "", "msg": ""}

        try:
            parts = mime_type.split("/")
            category = parts[0] or "file"
            suffix = parts[1] if len(parts) > 1 and parts[1] else "bin"

            object_key = f"tk/{category}/{uuid.uuid4()}.{suffix}"
            logger.info(f"开始上传文件到COS: {object_key}")

            file_content = file.read()
            # Rewind so the caller can reuse the stream; non-seekable streams
            # (pipes, sockets) are simply left where they are.
            if file.seekable():
                file.seek(0)

            result: UploadResult = self.storage_provider.upload_bytes(
                file_content, object_key, bucket_name=self.cos_bucket_name
            )

            if result.success:
                if result.url:
                    resp_data["data"] = result.url
                else:
                    # The provider returned no URL: build the default COS one.
                    region = config.get_cos_config().get("region", "ap-beijing")
                    resp_data["data"] = (
                        f"https://{self.cos_bucket_name}.cos.{region}.myqcloud.com/{object_key}"
                    )
                resp_data["msg"] = "上传成功"
                logger.info(f"文件上传成功: {resp_data['data']}")
            else:
                resp_data["status"] = False
                resp_data["msg"] = result.message or "上传失败"
                logger.error(f"文件上传失败: {resp_data['msg']}")

        except Exception as e:
            logger.error(f"上传文件时发生异常: {e}")
            resp_data["status"] = False
            resp_data["msg"] = str(e)

        return resp_data

    def tensor_to_io(self, tensor: torch.Tensor) -> io.BytesIO:
        """Encode an image tensor as PNG and return it as an in-memory stream.

        Args:
            tensor: Image tensor with values in [0, 1] (out-of-range values are
                clipped). Accepted layouts:
                - (H, W) grayscale
                - (H, W, C) with C in {1, 3, 4}
                - (B, H, W, C); only the first image of the batch is encoded

        Returns:
            io.BytesIO: PNG bytes, positioned at the start.

        Raises:
            ValueError: If the tensor layout is unsupported or encoding fails.
        """
        try:
            if tensor.dim() == 4:
                if tensor.shape[0] > 1:
                    logger.warning(
                        f"tensor_to_io received a batch of {tensor.shape[0]} images; only the first is used"
                    )
                # Indexing (rather than squeeze) also works for batches > 1.
                tensor = tensor[0]
            elif tensor.dim() not in (2, 3):
                raise ValueError(f"不支持的张量维度: {tensor.dim()}D")

            tensor = tensor.detach().cpu()
            if tensor.dtype in (torch.float16, torch.bfloat16):
                # numpy has no bfloat16, and float32 keeps the scaling exact enough.
                tensor = tensor.float()

            numpy_array = np.clip(tensor.numpy(), 0.0, 1.0)
            image_array = (numpy_array * 255).astype(np.uint8)

            if image_array.ndim == 3:
                channels = image_array.shape[2]
                if channels == 1:
                    image_array = image_array[:, :, 0]
                elif channels not in (3, 4):
                    raise ValueError(f"不支持的通道数: {channels}")

            # Pillow infers L / RGB / RGBA from the uint8 array shape.
            img = Image.fromarray(image_array)

            image_data = io.BytesIO()
            img.save(image_data, format="PNG", optimize=True)
            image_data.seek(0)

            logger.debug(f"张量转换为PNG成功,大小: {len(image_data.getvalue())} bytes")
            return image_data

        except Exception as e:
            logger.error(f"张量转换失败: {e}")
            raise ValueError(f"张量转换为图像失败: {str(e)}") from e

    def read_video_last_frame_to_tensor(self, video_path: str) -> torch.Tensor:
        """Read the last decodable frame of a video as an RGB tensor.

        Args:
            video_path: Path of the video file.

        Returns:
            torch.Tensor: float32 tensor of shape [1, H, W, C] with RGB channel
            order and values in [0, 1].

        Raises:
            FileNotFoundError: If the video cannot be opened.
            ValueError: If the video reports no frames or no frame near the end
                can be decoded.
        """
        cap = cv2.VideoCapture(video_path)
        frame = None
        try:
            if not cap.isOpened():
                raise FileNotFoundError(f"无法打开视频文件: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise ValueError("视频文件为空或无法确定帧数")

            # The reported frame count can be off, so when the nominal last
            # index does not decode, step back a few frames before giving up.
            newest = total_frames - 1
            oldest = max(-1, newest - self._LAST_FRAME_LOOKBACK)
            for index in range(newest, oldest, -1):
                cap.set(cv2.CAP_PROP_POS_FRAMES, index)
                ok, candidate = cap.read()
                if ok and candidate is not None:
                    frame = candidate
                    break
        finally:
            cap.release()

        if frame is None:
            raise ValueError("无法读取视频的最后一帧,可能视频已损坏")

        # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return torch.from_numpy(frame_rgb.astype(np.float32) / 255.0).unsqueeze(0)

    def _stream_to_file(self, url: str, dest_path: str, timeout: float) -> None:
        """Stream ``url`` into ``dest_path``, showing a tqdm progress bar."""
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            expected = int(response.headers.get("content-length", 0))
            encoded = response.headers.get("content-encoding")
            written = 0

            with open(dest_path, "wb") as out, tqdm(
                desc=url.split("/")[-1],
                total=expected,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(self._DOWNLOAD_CHUNK_SIZE):
                    written += out.write(chunk)
                    bar.update(len(chunk))

        # Content-Length counts encoded bytes, so it is only comparable when
        # the body was not transfer-compressed.
        if expected and not encoded and written != expected:
            raise IOError(f"下载不完整: 期望 {expected} 字节, 实际 {written} 字节")

    def download_video(
        self,
        url: str,
        timeout: float = 30,
        retries: int = 3,
        path: Optional[str] = None,
    ) -> Tuple[str, torch.Tensor]:
        """Download a video and decode its last frame.

        Args:
            url: Video URL.
            timeout: Per-request timeout in seconds.
            retries: Maximum number of attempts (at least 1).
            path: Destination file; a temporary ``.mp4`` is created when omitted.

        Returns:
            Tuple of (local file path, last frame as a [1, H, W, C] tensor).

        Raises:
            ValueError: If ``retries`` is smaller than 1.
            Exception: The error of the final attempt when every attempt fails.
        """
        if retries < 1:
            raise ValueError("retries 必须大于等于 1")

        # Allocate the destination once so retries overwrite the same file
        # instead of leaving one temporary file behind per attempt.
        if path:
            target_path = path
        else:
            handle, target_path = tempfile.mkstemp(suffix=".mp4")
            os.close(handle)

        for attempt in range(1, retries + 1):
            try:
                logger.info(f"开始下载视频 (尝试 {attempt}/{retries})...")
                self._stream_to_file(url, target_path, timeout)
                logger.info(f"视频下载完成: {target_path}")
                return target_path, self.read_video_last_frame_to_tensor(target_path)
            except Exception as e:
                logger.warning(f"下载错误 (尝试 {attempt}/{retries}): {str(e)}")
                if attempt == retries:
                    if not path:
                        self._remove_quietly(target_path)
                    raise
                # Linear back-off: 2s, 4s, ...
                time.sleep(attempt * 2)

    def jpg_to_tensor(self, image_path):
        """Load an image file as an RGB tensor.

        Args:
            image_path: Path of the image file.

        Returns:
            torch.Tensor: float32 tensor of shape [1, H, W, 3], values in [0, 1].
        """
        try:
            # ``convert`` returns a fully loaded copy, so the file handle can
            # be closed right away.
            with Image.open(image_path) as image:
                rgb_image = image.convert("RGB")
            return torch.from_numpy(np.array(rgb_image).astype(np.float32) / 255.0)[None,]
        except Exception as e:
            logger.error(f"转换失败: {str(e)}")
            raise

    def _count_video_frames(self, video_path: str) -> int:
        """Return the frame count of ``video_path`` using ffprobe, then ffmpeg."""
        probe = subprocess.run(
            [
                "ffprobe",
                "-v",
                "error",
                "-select_streams",
                "v:0",
                "-show_entries",
                "stream=nb_frames",
                "-of",
                "default=nokey=1:noprint_wrappers=1",
                video_path,
            ],
            capture_output=True,
            text=True,
            check=True,
        )
        reported = probe.stdout.strip()
        if reported.isdigit():
            return int(reported)

        # ffprobe gave no number: decode the whole file and read the last
        # ``frame=`` progress counter ffmpeg printed.
        logger.info("无法获取准确帧数,尝试直接解码...")
        decode = subprocess.run(
            ["ffmpeg", "-i", video_path, "-f", "null", "-"],
            capture_output=True,
            text=True,
        )
        # Progress updates are separated by carriage returns, so split on those too.
        for line in reversed(decode.stderr.replace("\r", "\n").split("\n")):
            if "frame=" not in line:
                continue
            tokens = line.split("frame=")[-1].split()
            if tokens and tokens[0].isdigit():
                return int(tokens[0])
        raise ValueError("无法确定视频帧数")

    def get_last_15th_frame_tensor(self, video_url):
        """Download a video and return its 15th-from-last frame as a tensor.

        The video and the extracted frame only exist in temporary files that
        are removed before returning.

        Args:
            video_url: URL of the video.

        Returns:
            torch.Tensor: float32 tensor of shape [1, H, W, 3], values in [0, 1].

        Raises:
            ValueError: If the frame count cannot be determined.
            subprocess.CalledProcessError: If ffprobe or ffmpeg fails.
        """
        video_path, _ = self.download_video(video_url)
        try:
            frame_count = self._count_video_frames(video_path)
            target_frame = max(0, frame_count - 15)
            logger.info(f"视频总帧数: {frame_count}, 目标帧: {target_frame}")

            with tempfile.TemporaryDirectory() as work_dir:
                # ffmpeg's image muxer expands %03d itself; the single selected
                # frame is written as number 001.
                output_pattern = os.path.join(work_dir, "frame_%03d.jpg")
                cmd_extract = [
                    "ffmpeg",
                    "-ss",
                    "00:00:00",
                    "-i",
                    video_path,
                    "-vframes",
                    "1",
                    "-vf",
                    # The comma must stay backslash-escaped for the filter parser.
                    rf"select=eq(n\,{target_frame})",
                    "-vsync",
                    "0",
                    "-an",
                    "-y",
                    output_pattern,
                ]
                subprocess.run(cmd_extract, capture_output=True, check=True)

                return self.jpg_to_tensor(os.path.join(work_dir, "frame_001.jpg"))
        finally:
            # The video was only needed for frame extraction.
            self._remove_quietly(video_path)


class JMGestureCorrect:
    """ComfyUI node that re-poses a person image to a front-facing, full-body view.

    The input image is uploaded, a 5-second Jimeng video is generated from a
    fixed "stand straight, face the camera" prompt, and the 15th-from-last
    frame of that video is returned as the corrected image.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: an image and the output resolution choice."""
        return {
            "required": {
                "image": ("IMAGE",),
                # A choice input is a 1-tuple wrapping the option list; without
                # the trailing comma the parentheses are a no-op and the spec
                # degrades to a bare list.
                "resolution": (["720p", "1080p"],),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("正面图",)
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/图片/姿态"

    def gen(self, image: torch.Tensor, resolution: str):
        """Generate the front-facing image.

        Args:
            image: Input image tensor.
            resolution: Resolution passed to the generation task.

        Returns:
            1-tuple holding the extracted frame tensor.

        Raises:
            Exception: If the upload, the task submission or the generation
                fails, or if no result arrives within the polling budget.
        """
        wait_time = 240  # polling budget: one query per ``interval`` seconds
        interval = 2
        client = JMUtils()

        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            raise Exception(f"上传失败: {upload_data.get('msg', '')}")
        image_url = upload_data["data"]

        prompt = "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame"
        submit_data = client.submit_task(
            prompt, image_url, duration="5", resolution=resolution
        )
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败: {submit_data.get('msg', '')}")
        job_id = submit_data["data"]

        video_url = None
        for attempt, _ in enumerate(range(0, wait_time, interval), start=1):
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                video_url = query["data"]
                break

            # A failed task is only signalled through the message text.
            message = query["msg"]
            if any(marker in message for marker in ("error", "失败", "fail")):
                raise Exception("即梦任务失败 {}".format(message))
            sleep(interval)

        if not video_url:
            raise Exception("即梦任务等待超时")

        return (client.get_last_15th_frame_tensor(video_url),)


class JMCustom:
    """ComfyUI node that generates a Jimeng video from an image and a prompt.

    The image is uploaded, a generation task is submitted and polled, and the
    finished video is saved into the ComfyUI output directory.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: image, prompt, duration, resolution, wait time."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": (
                    "STRING",
                    {
                        "default": "Stand straight ahead, facing the camera, showing your full body, maintaining a proper posture, keeping the camera still, and ensuring that your head and feet are all within the frame",
                        "multiline": True,
                    },
                ),
                "duration": ("INT", {"default": 5, "min": 2, "max": 10}),
                # A choice input is a 1-tuple wrapping the option list; without
                # the trailing comma the parentheses are a no-op and the spec
                # degrades to a bare list.
                "resolution": (["720p", "1080p"],),
                "wait_time": ("INT", {"default": 180, "min": 60, "max": 600}),
            }
        }

    RETURN_TYPES = (
        "STRING",
        "IMAGE",
    )
    RETURN_NAMES = ("视频存储路径", "视频最后一帧")
    FUNCTION = "gen"
    CATEGORY = "不忘科技-自定义节点🚩/视频/即梦"

    def gen(
        self,
        image: torch.Tensor,
        prompt: str,
        duration: int,
        resolution: str,
        wait_time: int,
    ):
        """Generate a video and return its path and last frame.

        Args:
            image: Input image tensor.
            prompt: Generation prompt.
            duration: Video duration in seconds.
            resolution: Resolution passed to the generation task.
            wait_time: Polling budget; one status query is made per 2 seconds
                of budget.

        Returns:
            Tuple of (saved video path, last frame tensor).

        Raises:
            Exception: If the upload, the task submission or the generation
                fails, or if no result arrives within the polling budget.
        """
        interval = 2
        client = JMUtils()

        upload_data = client.upload_io_to_cos(client.tensor_to_io(image))
        if not upload_data["status"]:
            raise Exception(f"上传失败: {upload_data.get('msg', '')}")
        image_url = upload_data["data"]

        submit_data = client.submit_task(
            prompt, image_url, str(duration), resolution=resolution
        )
        if not submit_data["status"]:
            raise Exception(f"即梦任务提交失败: {submit_data.get('msg', '')}")
        job_id = submit_data["data"]

        video_url = None
        for attempt, _ in enumerate(range(0, wait_time, interval), start=1):
            logger.info(f"查询即梦结果 {attempt}")
            query = client.query_status(job_id)
            if query["status"]:
                video_url = query["data"]
                break

            # A failed task is only signalled through the message text.
            message = query["msg"]
            if any(marker in message for marker in ("error", "失败", "fail")):
                raise Exception("即梦任务失败 {}".format(message))
            sleep(interval)

        if not video_url:
            raise Exception("即梦任务等待超时")

        # Save under a random name so repeated runs never overwrite each other.
        output_dir = folder_paths.get_output_directory()
        video_path, last_scene = client.download_video(
            video_url, path=os.path.join(output_dir, f"{uuid.uuid4()}.mp4")
        )
        return (
            video_path,
            last_scene,
        )