219 lines
7.1 KiB
Rust
219 lines
7.1 KiB
Rust
//! Topaz Video AI 模型管理演示
|
|
//!
|
|
//! 展示如何使用模型管理器检查、下载和管理 AI 模型
|
|
|
|
use std::path::Path;
|
|
use tvai::*;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("🧠 Topaz Video AI 模型管理演示");
|
|
println!("==============================\n");
|
|
|
|
// 检查 Topaz 安装
|
|
let topaz_path = match detect_topaz_installation() {
|
|
Some(path) => {
|
|
println!("✅ 找到 Topaz Video AI: {}", path.display());
|
|
path
|
|
}
|
|
None => {
|
|
println!("❌ 未找到 Topaz Video AI 安装");
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
// 创建模型管理器
|
|
let model_manager = match ModelManager::new(&topaz_path) {
|
|
Ok(manager) => {
|
|
println!("✅ 模型管理器初始化成功");
|
|
manager
|
|
}
|
|
Err(e) => {
|
|
println!("❌ 模型管理器初始化失败: {}", e);
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
// 演示各种功能
|
|
demo_model_status(&model_manager).await?;
|
|
demo_model_recommendations(&model_manager).await?;
|
|
demo_download_guide(&model_manager).await?;
|
|
demo_model_testing(&model_manager).await?;
|
|
|
|
println!("\n🎉 模型管理演示完成!");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 演示模型状态检查
|
|
async fn demo_model_status(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("📊 1. 模型状态检查");
|
|
println!("------------------");
|
|
|
|
// 获取所有模型
|
|
let all_models = model_manager.get_all_models()?;
|
|
println!("总模型数量: {}", all_models.len());
|
|
|
|
// 统计各类型模型
|
|
let upscale_count = all_models.iter().filter(|m| m.model_type == ModelType::Upscale).count();
|
|
let interpolation_count = all_models.iter().filter(|m| m.model_type == ModelType::Interpolation).count();
|
|
let other_count = all_models.iter().filter(|m| m.model_type == ModelType::Other).count();
|
|
|
|
println!(" 🔍 超分辨率模型: {} 个", upscale_count);
|
|
println!(" 🎬 帧插值模型: {} 个", interpolation_count);
|
|
println!(" 🔧 其他模型: {} 个", other_count);
|
|
|
|
// 检查下载状态
|
|
let downloaded_models = model_manager.get_downloaded_models()?;
|
|
let missing_models = model_manager.get_missing_models()?;
|
|
|
|
println!("\n📥 下载状态:");
|
|
println!(" ✅ 已下载: {} 个模型", downloaded_models.len());
|
|
println!(" ❌ 缺失: {} 个模型", missing_models.len());
|
|
|
|
if !downloaded_models.is_empty() {
|
|
println!("\n✅ 已下载的模型:");
|
|
for model_name in &downloaded_models {
|
|
if let Some(model) = all_models.iter().find(|m| m.short_name == *model_name) {
|
|
println!(" {} - {}",
|
|
model.short_name,
|
|
model.display_name.as_deref().unwrap_or("N/A"));
|
|
}
|
|
}
|
|
}
|
|
|
|
if !missing_models.is_empty() {
|
|
println!("\n❌ 缺失的模型 (前10个):");
|
|
for model in missing_models.iter().take(10) {
|
|
println!(" {} - {}",
|
|
model.short_name,
|
|
model.display_name.as_deref().unwrap_or("N/A"));
|
|
}
|
|
if missing_models.len() > 10 {
|
|
println!(" ... 还有 {} 个模型", missing_models.len() - 10);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 演示模型推荐
|
|
async fn demo_model_recommendations(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("\n💡 2. 模型推荐");
|
|
println!("-------------");
|
|
|
|
let use_cases = vec![
|
|
("general", "通用处理"),
|
|
("high_quality", "高质量处理"),
|
|
("fast", "快速处理"),
|
|
("gaming", "游戏内容"),
|
|
("old_video", "老视频修复"),
|
|
("portrait", "人像视频"),
|
|
];
|
|
|
|
for (use_case, description) in use_cases {
|
|
let recommended = model_manager.get_recommended_models(use_case)?;
|
|
println!("🎯 {} ({}):", description, use_case);
|
|
|
|
if recommended.is_empty() {
|
|
println!(" (无推荐模型)");
|
|
} else {
|
|
for model_name in recommended {
|
|
let is_downloaded = model_manager.is_model_downloaded(&model_name)?;
|
|
let status = if is_downloaded { "✅" } else { "❌" };
|
|
println!(" {} {}", status, model_name);
|
|
}
|
|
}
|
|
println!();
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 演示下载指南生成
|
|
async fn demo_download_guide(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("📋 3. 生成下载指南");
|
|
println!("------------------");
|
|
|
|
let guide_path = Path::new("model_download_guide_demo.md");
|
|
|
|
match model_manager.generate_download_guide(guide_path) {
|
|
Ok(()) => {
|
|
println!("✅ 下载指南已生成: {}", guide_path.display());
|
|
|
|
// 显示指南的前几行
|
|
if let Ok(content) = std::fs::read_to_string(guide_path) {
|
|
let lines: Vec<&str> = content.lines().take(15).collect();
|
|
println!("\n📄 指南预览:");
|
|
for line in lines {
|
|
println!(" {}", line);
|
|
}
|
|
if content.lines().count() > 15 {
|
|
println!(" ... (更多内容请查看文件)");
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
println!("❌ 生成指南失败: {}", e);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 演示模型测试
|
|
async fn demo_model_testing(model_manager: &ModelManager) -> Result<(), Box<dyn std::error::Error>> {
|
|
println!("\n🧪 4. 模型测试");
|
|
println!("-------------");
|
|
|
|
let test_video = Path::new("target/demo.mp4");
|
|
|
|
if !test_video.exists() {
|
|
println!("❌ 测试视频不存在: {}", test_video.display());
|
|
println!("💡 请确保测试视频文件存在以进行模型测试");
|
|
return Ok(());
|
|
}
|
|
|
|
println!("✅ 找到测试视频: {}", test_video.display());
|
|
|
|
// 获取一些常用模型进行测试
|
|
let test_models = vec!["iris", "amq", "prob", "chr"];
|
|
|
|
println!("\n🔍 测试常用模型可用性:");
|
|
|
|
for model_name in test_models {
|
|
print!(" 测试 {} ... ", model_name);
|
|
|
|
match model_manager.attempt_model_download(model_name, test_video) {
|
|
Ok(true) => {
|
|
println!("✅ 可用或已触发下载");
|
|
}
|
|
Ok(false) => {
|
|
println!("❌ 模型不可用");
|
|
}
|
|
Err(e) => {
|
|
println!("❌ 测试失败: {}", e);
|
|
}
|
|
}
|
|
}
|
|
|
|
println!("\n💡 提示:");
|
|
println!(" - 如果模型显示'不可用',需要先在 Topaz 应用中下载");
|
|
println!(" - 某些模型可能需要特定的输入格式才能正确加载");
|
|
println!(" - 建议使用生成的下载指南进行手动下载");
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// 创建进度回调
|
|
fn create_progress_callback(operation_name: &str) -> ProgressCallback {
|
|
let name = operation_name.to_string();
|
|
Box::new(move |progress| {
|
|
let percentage = (progress * 100.0) as u32;
|
|
print!("\r{}: {}%", name, percentage);
|
|
if progress >= 1.0 {
|
|
println!();
|
|
}
|
|
})
|
|
}
|