Update TVAI functionality across desktop app and cargo modules

This commit is contained in:
imeepos 2025-08-12 19:38:46 +08:00
parent 0a742f1e6b
commit d0845c3933
7 changed files with 103 additions and 23 deletions

View File

@ -198,6 +198,7 @@ pub async fn quick_upscale_video_command(
input_path: String,
output_path: String,
scale_factor: f32,
model: Option<String>,
) -> Result<String, String> {
let task_id = Uuid::new_v4().to_string();
@ -242,12 +243,38 @@ pub async fn quick_upscale_video_command(
app_handle_clone.emit("tvai_task_updated", &task_id_clone).unwrap();
let start_time = std::time::Instant::now();
// 解析模型参数
let upscale_model = if let Some(model_str) = model {
match model_str.as_str() {
"aaa-9" => Some(UpscaleModel::Aaa9),
"ahq-12" => Some(UpscaleModel::Ahq12),
"alq-13" => Some(UpscaleModel::Alq13),
"alqs-2" => Some(UpscaleModel::Alqs2),
"amq-13" => Some(UpscaleModel::Amq13),
"amqs-2" => Some(UpscaleModel::Amqs2),
"ghq-5" => Some(UpscaleModel::Ghq5),
"iris-2" => Some(UpscaleModel::Iris2),
"iris-3" => Some(UpscaleModel::Iris3),
"nyx-3" => Some(UpscaleModel::Nyx3),
"prob-4" => Some(UpscaleModel::Prob4),
"thf-4" => Some(UpscaleModel::Thf4),
"thd-3" => Some(UpscaleModel::Thd3),
"thm-2" => Some(UpscaleModel::Thm2),
"rhea-1" => Some(UpscaleModel::Rhea1),
"rxl-1" => Some(UpscaleModel::Rxl1),
_ => None, // 使用默认模型
}
} else {
None // 使用默认模型
};
// 执行处理
let result = quick_upscale_video(
let result = tvai::quick_upscale_video_with_model(
Path::new(&input_path),
Path::new(&output_path),
scale_factor,
upscale_model,
).await;
let processing_time = start_time.elapsed();

View File

@ -38,6 +38,7 @@ export function TvaiExample() {
// 高级参数
const [selectedModel, setSelectedModel] = useState<UpscaleModel>('iris-3');
const [selectedInterpolationModel, setSelectedInterpolationModel] = useState<'apo-8' | 'apf-1' | 'chr-2' | 'chf-3'>('apo-8');
const [compression, setCompression] = useState(0.0);
const [blend, setBlend] = useState(0.0);
@ -112,7 +113,7 @@ export function TvaiExample() {
try {
if (processingType === 'video') {
await tvaiService.quickUpscaleVideo(inputPath, outputPath, scaleFactor);
await tvaiService.quickUpscaleVideo(inputPath, outputPath, scaleFactor, selectedModel);
} else if (processingType === 'interpolation') {
// 插帧处理 - 需要先获取视频信息来确定帧率
try {
@ -396,7 +397,7 @@ export function TvaiExample() {
key={preset.key}
onClick={() => {
setSelectedPreset(preset.key);
setSelectedModel(preset.model as UpscaleModel);
setSelectedInterpolationModel(preset.model as 'apo-8' | 'apf-1' | 'chr-2' | 'chf-3');
}}
className={`p-3 rounded-lg border text-sm transition-all ${
selectedPreset === preset.key
@ -463,8 +464,14 @@ export function TvaiExample() {
{processingType === 'interpolation' ? '插帧模型' : 'AI 放大模型'}
</label>
<select
value={selectedModel}
onChange={(e) => setSelectedModel(e.target.value as UpscaleModel)}
value={processingType === 'interpolation' ? selectedInterpolationModel : selectedModel}
onChange={(e) => {
if (processingType === 'interpolation') {
setSelectedInterpolationModel(e.target.value as 'apo-8' | 'apf-1' | 'chr-2' | 'chf-3');
} else {
setSelectedModel(e.target.value as UpscaleModel);
}
}}
className="w-full p-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent"
>
{processingType === 'interpolation' ? (

View File

@ -25,7 +25,7 @@ export interface UseTvaiReturn {
* TVAI Hook
*/
export function useTvai(options: UseTvaiOptions = {}): UseTvaiReturn {
const { autoRefresh = true, refreshInterval = 5000 } = options;
const { autoRefresh = true, refreshInterval = 10000 } = options; // 增加轮询间隔,主要依赖事件推送
const [tasks, setTasks] = useState<TvaiTask[]>([]);
const [isLoading, setIsLoading] = useState(false);
@ -95,16 +95,38 @@ export function useTvai(options: UseTvaiOptions = {}): UseTvaiReturn {
return tasks.filter(task => task.status === 'Failed').length;
}, [tasks]);
// 处理任务事件
// 处理任务事件 - 实时推送机制
useEffect(() => {
const handleTaskCreated = (taskId: string) => {
console.log('Task created:', taskId);
refreshTasks();
const handleTaskCreated = async (taskId: string) => {
console.log('Task created (real-time):', taskId);
// 立即刷新任务列表以获取新任务
await refreshTasks();
};
const handleTaskUpdated = (taskId: string) => {
console.log('Task updated:', taskId);
refreshTasks();
const handleTaskUpdated = async (taskId: string) => {
console.log('Task updated (real-time):', taskId);
// 获取单个任务的最新状态并更新
try {
const updatedTask = await tvaiService.getTvaiTaskStatus(taskId);
if (updatedTask) {
setTasks(prevTasks => {
const taskIndex = prevTasks.findIndex(task => task.id === taskId);
if (taskIndex >= 0) {
// 更新现有任务
const newTasks = [...prevTasks];
newTasks[taskIndex] = updatedTask;
return newTasks;
} else {
// 添加新任务(如果不存在)
return [...prevTasks, updatedTask];
}
});
}
} catch (error) {
console.error('Failed to update task:', error);
// 如果单个任务更新失败,回退到全量刷新
await refreshTasks();
}
};
// 添加事件监听器
@ -118,13 +140,23 @@ export function useTvai(options: UseTvaiOptions = {}): UseTvaiReturn {
};
}, [refreshTasks]);
// 自动刷新
// 智能轮询 - 只在有运行中的任务时才轮询
useEffect(() => {
if (!autoRefresh) return;
const interval = setInterval(refreshTasks, refreshInterval);
const hasRunningTasks = tasks.some(task =>
task.status === 'Pending' || task.status === 'Processing'
);
if (!hasRunningTasks) return; // 没有运行中的任务时不进行轮询
const interval = setInterval(() => {
console.log('Polling for task updates (fallback mechanism)');
refreshTasks();
}, refreshInterval);
return () => clearInterval(interval);
}, [autoRefresh, refreshInterval, refreshTasks]);
}, [autoRefresh, refreshInterval, refreshTasks, tasks]);
// 初始加载
useEffect(() => {

View File

@ -120,12 +120,14 @@ export class TvaiServiceImpl implements TvaiService {
async quickUpscaleVideo(
inputPath: string,
outputPath: string,
scaleFactor: number
scaleFactor: number,
model?: UpscaleModel
): Promise<string> {
return await invoke('quick_upscale_video_command', {
inputPath,
outputPath,
scaleFactor,
model,
});
}

View File

@ -170,9 +170,10 @@ export interface TvaiService {
// 快速处理
quickUpscaleVideo(
inputPath: string,
outputPath: string,
scaleFactor: number
inputPath: string,
outputPath: string,
scaleFactor: number,
model?: UpscaleModel
): Promise<string>;
quickUpscaleImage(

View File

@ -58,6 +58,7 @@ pub type Result<T> = std::result::Result<T, TvaiError>;
// Quick processing functions
pub use video::quick_upscale_video;
pub use video::quick_upscale_video_with_model;
pub use video::quick_interpolate_video;
pub use image::quick_upscale_image;
pub use video::auto_enhance_video;

View File

@ -373,6 +373,16 @@ pub async fn quick_upscale_video(
input: &Path,
output: &Path,
scale: f32,
) -> Result<ProcessResult, TvaiError> {
quick_upscale_video_with_model(input, output, scale, None).await
}
/// Quick video upscaling function with model selection
pub async fn quick_upscale_video_with_model(
input: &Path,
output: &Path,
scale: f32,
model: Option<UpscaleModel>,
) -> Result<ProcessResult, TvaiError> {
// Detect Topaz installation
let topaz_path = crate::utils::detect_topaz_installation()
@ -387,10 +397,10 @@ pub async fn quick_upscale_video(
// Create processor
let mut processor = TvaiProcessor::new(config)?;
// Create default upscaling parameters
// Create upscaling parameters with specified or default model
let params = VideoUpscaleParams {
scale_factor: scale,
model: crate::config::UpscaleModel::Iris3, // Best general purpose model
model: model.unwrap_or(crate::config::UpscaleModel::Iris3), // Use specified model or default to Iris3
compression: 0.0,
blend: 0.0,
quality_preset: crate::config::QualityPreset::HighQuality,