mixvideo-v2/cargos/tvai/examples/model_management_demo.rs

219 lines
7.1 KiB
Rust

//! Topaz Video AI 模型管理演示
//!
//! 展示如何使用模型管理器检查、下载和管理 AI 模型
use std::path::Path;
use tvai::*;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🧠 Topaz Video AI 模型管理演示");
println!("==============================\n");
// 检查 Topaz 安装
let topaz_path = match detect_topaz_installation() {
Some(path) => {
println!("✅ 找到 Topaz Video AI: {}", path.display());
path
}
None => {
println!("❌ 未找到 Topaz Video AI 安装");
return Ok(());
}
};
// 创建模型管理器
let model_manager = match ModelManager::new(&topaz_path) {
Ok(manager) => {
println!("✅ 模型管理器初始化成功");
manager
}
Err(e) => {
println!("❌ 模型管理器初始化失败: {}", e);
return Ok(());
}
};
// 演示各种功能
demo_model_status(&model_manager).await?;
demo_model_recommendations(&model_manager).await?;
demo_download_guide(&model_manager).await?;
demo_model_testing(&model_manager).await?;
println!("\n🎉 模型管理演示完成!");
Ok(())
}
/// 演示模型状态检查
async fn demo_model_status(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
println!("📊 1. 模型状态检查");
println!("------------------");
// 获取所有模型
let all_models = model_manager.get_all_models()?;
println!("总模型数量: {}", all_models.len());
// 统计各类型模型
let upscale_count = all_models.iter().filter(|m| m.model_type == ModelType::Upscale).count();
let interpolation_count = all_models.iter().filter(|m| m.model_type == ModelType::Interpolation).count();
let other_count = all_models.iter().filter(|m| m.model_type == ModelType::Other).count();
println!(" 🔍 超分辨率模型: {}", upscale_count);
println!(" 🎬 帧插值模型: {}", interpolation_count);
println!(" 🔧 其他模型: {}", other_count);
// 检查下载状态
let downloaded_models = model_manager.get_downloaded_models()?;
let missing_models = model_manager.get_missing_models()?;
println!("\n📥 下载状态:");
println!(" ✅ 已下载: {} 个模型", downloaded_models.len());
println!(" ❌ 缺失: {} 个模型", missing_models.len());
if !downloaded_models.is_empty() {
println!("\n✅ 已下载的模型:");
for model_name in &downloaded_models {
if let Some(model) = all_models.iter().find(|m| m.short_name == *model_name) {
println!(" {} - {}",
model.short_name,
model.display_name.as_deref().unwrap_or("N/A"));
}
}
}
if !missing_models.is_empty() {
println!("\n❌ 缺失的模型 (前10个):");
for model in missing_models.iter().take(10) {
println!(" {} - {}",
model.short_name,
model.display_name.as_deref().unwrap_or("N/A"));
}
if missing_models.len() > 10 {
println!(" ... 还有 {} 个模型", missing_models.len() - 10);
}
}
Ok(())
}
/// 演示模型推荐
async fn demo_model_recommendations(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
println!("\n💡 2. 模型推荐");
println!("-------------");
let use_cases = vec![
("general", "通用处理"),
("high_quality", "高质量处理"),
("fast", "快速处理"),
("gaming", "游戏内容"),
("old_video", "老视频修复"),
("portrait", "人像视频"),
];
for (use_case, description) in use_cases {
let recommended = model_manager.get_recommended_models(use_case)?;
println!("🎯 {} ({}):", description, use_case);
if recommended.is_empty() {
println!(" (无推荐模型)");
} else {
for model_name in recommended {
let is_downloaded = model_manager.is_model_downloaded(&model_name)?;
let status = if is_downloaded { "" } else { "" };
println!(" {} {}", status, model_name);
}
}
println!();
}
Ok(())
}
/// 演示下载指南生成
async fn demo_download_guide(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
println!("📋 3. 生成下载指南");
println!("------------------");
let guide_path = Path::new("model_download_guide_demo.md");
match model_manager.generate_download_guide(guide_path) {
Ok(()) => {
println!("✅ 下载指南已生成: {}", guide_path.display());
// 显示指南的前几行
if let Ok(content) = std::fs::read_to_string(guide_path) {
let lines: Vec<&str> = content.lines().take(15).collect();
println!("\n📄 指南预览:");
for line in lines {
println!(" {}", line);
}
if content.lines().count() > 15 {
println!(" ... (更多内容请查看文件)");
}
}
}
Err(e) => {
println!("❌ 生成指南失败: {}", e);
}
}
Ok(())
}
/// 演示模型测试
async fn demo_model_testing(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
println!("\n🧪 4. 模型测试");
println!("-------------");
let test_video = Path::new("target/demo.mp4");
if !test_video.exists() {
println!("❌ 测试视频不存在: {}", test_video.display());
println!("💡 请确保测试视频文件存在以进行模型测试");
return Ok(());
}
println!("✅ 找到测试视频: {}", test_video.display());
// 获取一些常用模型进行测试
let test_models = vec!["iris", "amq", "prob", "chr"];
println!("\n🔍 测试常用模型可用性:");
for model_name in test_models {
print!(" 测试 {} ... ", model_name);
match model_manager.attempt_model_download(model_name, test_video) {
Ok(true) => {
println!("✅ 可用或已触发下载");
}
Ok(false) => {
println!("❌ 模型不可用");
}
Err(e) => {
println!("❌ 测试失败: {}", e);
}
}
}
println!("\n💡 提示:");
println!(" - 如果模型显示'不可用',需要先在 Topaz 应用中下载");
println!(" - 某些模型可能需要特定的输入格式才能正确加载");
println!(" - 建议使用生成的下载指南进行手动下载");
Ok(())
}
/// 创建进度回调
fn create_progress_callback(operation_name: &str) -> ProgressCallback {
let name = operation_name.to_string();
Box::new(move |progress| {
let percentage = (progress * 100.0) as u32;
print!("\r{}: {}%", name, percentage);
if progress >= 1.0 {
println!();
}
})
}