diff --git a/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs b/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs index b994359..e4d0a6d 100644 --- a/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs +++ b/apps/desktop/src-tauri/src/infrastructure/gemini_service.rs @@ -12,6 +12,11 @@ pub struct GeminiConfig { pub base_url: String, pub bearer_token: String, pub timeout: u64, + pub model_name: String, + pub max_retries: u32, + pub retry_delay: u64, + pub temperature: f32, + pub max_tokens: u32, } impl Default for GeminiConfig { @@ -20,6 +25,11 @@ impl Default for GeminiConfig { base_url: "https://bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run".to_string(), bearer_token: "bowong7777".to_string(), timeout: 120, + model_name: "gemini-2.5-flash".to_string(), + max_retries: 3, + retry_delay: 2, + temperature: 0.1, + max_tokens: 2048, } } } @@ -31,6 +41,13 @@ struct TokenResponse { expires_in: u64, } +/// 客户端配置 +#[derive(Debug)] +struct ClientConfig { + gateway_url: String, + headers: std::collections::HashMap, +} + /// Gemini上传响应 #[derive(Debug, Deserialize)] struct UploadResponse { @@ -186,6 +203,18 @@ impl GeminiService { Ok(token_response.access_token) } + /// 创建Gemini客户端配置 + fn create_gemini_client(&self, access_token: &str) -> ClientConfig { + let mut headers = std::collections::HashMap::new(); + headers.insert("Authorization".to_string(), format!("Bearer {}", access_token)); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + ClientConfig { + gateway_url: format!("{}/google/vertex-ai", self.config.base_url), + headers, + } + } + /// 上传视频文件到Gemini pub async fn upload_video_file(&mut self, video_path: &str) -> Result { println!("📤 开始上传视频文件: {}", video_path); @@ -262,7 +291,7 @@ impl GeminiService { Ok(file_uri) } - /// 生成内容分析 + /// 生成内容分析 (参考Python demo.py实现) pub async fn generate_content_analysis(&mut self, file_uri: &str, prompt: &str) -> Result { println!("🧠 开始生成内容分析..."); println!("📁 文件URI: {}", file_uri); @@ -271,11 +300,14 @@ impl GeminiService { // 获取访问令牌 let access_token = self.get_access_token().await?; + // 创建客户端配置 + let client_config = self.create_gemini_client(&access_token); + // 格式化GCS URI let formatted_uri = self.format_gcs_uri(file_uri); println!("🔗 格式化后的URI: {}", formatted_uri); - // 准备请求数据 + // 准备请求数据,参考demo.py实现 let request_data = GenerateContentRequest { contents: vec![ContentPart { role: "user".to_string(), @@ -290,28 +322,62 @@ impl GeminiService { ], }], generation_config: GenerationConfig { - temperature: 0.1, + temperature: self.config.temperature, top_k: 32, top_p: 1.0, - max_output_tokens: 2048, + max_output_tokens: self.config.max_tokens, }, }; println!("📦 请求数据: {}", serde_json::to_string_pretty(&request_data).unwrap_or_default()); - // 发送生成请求 - let generate_url = format!("{}/google/vertex-ai/generate", self.config.base_url); + // 发送请求到Cloudflare Gateway,参考demo.py + let generate_url = format!("{}:generateContent", client_config.gateway_url); println!("📡 生成URL: {}", generate_url); - let response = self.client - .post(&generate_url) - .header("Authorization", format!("Bearer {}", access_token)) - .header("x-google-api-key", &access_token) - .header("Content-Type", "application/json") - .json(&request_data) - .send() - .await?; + // 重试机制 + let mut last_error = None; + for attempt in 0..self.config.max_retries { + println!("🔄 尝试 {}/{}", attempt + 1, self.config.max_retries); + match self.send_generate_request(&generate_url, &client_config, &request_data).await { + Ok(result) => { + println!("✅ 成功获取Gemini分析结果"); + return self.parse_gemini_response_content(&result); + } + Err(e) => { + last_error = Some(e); + println!("⚠️ 尝试 {}/{} 失败: {}", attempt + 1, self.config.max_retries, last_error.as_ref().unwrap()); + + if attempt < self.config.max_retries - 1 { + println!("⏳ 等待 {} 秒后重试...", self.config.retry_delay); + tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await; + } + } + } + } + + Err(anyhow!("内容生成失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap())) + } + + /// 发送生成请求 + async fn send_generate_request( + &self, + url: &str, + client_config: &ClientConfig, + request_data: &GenerateContentRequest, + ) -> Result { + let mut request_builder = self.client + .post(url) + .timeout(tokio::time::Duration::from_secs(self.config.timeout)) + .json(request_data); + + // 添加请求头 + for (key, value) in &client_config.headers { + request_builder = request_builder.header(key, value); + } + + let response = request_builder.send().await?; let status = response.status(); let headers = response.headers().clone(); @@ -321,7 +387,7 @@ impl GeminiService { if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); println!("❌ 生成失败响应体: {}", error_text); - return Err(anyhow!("生成内容分析失败: {} - {}", status, error_text)); + return Err(anyhow!("API请求失败: {} - {}", status, error_text)); } let response_text = response.text().await?; @@ -330,6 +396,15 @@ impl GeminiService { let gemini_response: GeminiResponse = serde_json::from_str(&response_text) .map_err(|e| anyhow!("解析生成响应失败: {} - 响应内容: {}", e, response_text))?; + if gemini_response.candidates.is_empty() { + return Err(anyhow!("API返回结果为空")); + } + + Ok(gemini_response) + } + + /// 解析Gemini响应内容 + fn parse_gemini_response_content(&self, gemini_response: &GeminiResponse) -> Result { if let Some(candidate) = gemini_response.candidates.first() { if let Some(part) = candidate.content.parts.first() { println!("✅ 内容分析完成,响应长度: {} 字符", part.text.len());