use anyhow::{Result, anyhow}; use serde::{Deserialize, Serialize}; use base64::prelude::*; use std::path::Path; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::fs; use reqwest::multipart; // 导入容错JSON解析器 use crate::infrastructure::tolerant_json_parser::{TolerantJsonParser, ParserConfig}; use std::sync::{Arc, Mutex}; /// Gemini API配置 #[derive(Debug, Clone)] pub struct GeminiConfig { pub base_url: String, pub bearer_token: String, pub timeout: u64, pub model_name: String, pub max_retries: u32, pub retry_delay: u64, pub temperature: f32, pub max_tokens: u32, pub cloudflare_project_id: String, pub cloudflare_gateway_id: String, pub google_project_id: String, pub regions: Vec, } impl Default for GeminiConfig { fn default() -> Self { Self { base_url: "https://bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run".to_string(), bearer_token: "bowong7777".to_string(), timeout: 120, model_name: "gemini-2.5-flash".to_string(), max_retries: 3, retry_delay: 2, temperature: 0.1, max_tokens: 1024 * 8, cloudflare_project_id: "67720b647ff2b55cf37ba3ef9e677083".to_string(), cloudflare_gateway_id: "bowong-dev".to_string(), google_project_id: "gen-lang-client-0413414134".to_string(), regions: vec![ "us-central1".to_string(), "us-east1".to_string(), "europe-west1".to_string(), ], } } } /// Gemini访问令牌响应 #[derive(Debug, Deserialize)] struct TokenResponse { access_token: String, expires_in: u64, } /// 客户端配置 #[derive(Debug)] struct ClientConfig { gateway_url: String, headers: std::collections::HashMap, } /// Gemini上传响应 #[derive(Debug, Deserialize)] struct UploadResponse { file_uri: Option, urn: Option, name: Option, } /// Gemini内容生成请求 #[derive(Debug, Serialize)] struct GenerateContentRequest { contents: Vec, #[serde(rename = "generationConfig")] generation_config: GenerationConfig, } #[derive(Debug, Serialize)] struct ContentPart { role: String, parts: Vec, } #[derive(Debug, Serialize)] #[serde(untagged)] enum Part { Text { text: String }, FileData { #[serde(rename = "fileData")] file_data: FileData }, InlineData { #[serde(rename = "inlineData")] inline_data: InlineData }, } #[derive(Debug, Serialize)] struct FileData { #[serde(rename = "mimeType")] mime_type: String, #[serde(rename = "fileUri")] file_uri: String, } #[derive(Debug, Serialize)] struct InlineData { #[serde(rename = "mimeType")] mime_type: String, data: String, } #[derive(Debug, Serialize)] struct GenerationConfig { temperature: f32, #[serde(rename = "topK")] top_k: u32, #[serde(rename = "topP")] top_p: f32, #[serde(rename = "maxOutputTokens")] max_output_tokens: u32, } /// Gemini响应结构 #[derive(Debug, Deserialize)] struct GeminiResponse { candidates: Vec, } #[derive(Debug, Deserialize)] struct Candidate { content: Content, } #[derive(Debug, Deserialize)] struct Content { parts: Vec, } #[derive(Debug, Deserialize)] struct ResponsePart { text: String, } /// Gemini API服务 /// 遵循 Tauri 开发规范的基础设施层设计模式 #[derive(Clone)] pub struct GeminiService { config: GeminiConfig, client: reqwest::Client, access_token: Option, token_expires_at: Option, json_parser: Arc>, } impl GeminiService { /// 创建新的Gemini服务实例 pub fn new(config: Option) -> Result { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(config.as_ref().map(|c| c.timeout).unwrap_or(120))) .build() .expect("Failed to create HTTP client"); // 创建容错JSON解析器 let json_parser = TolerantJsonParser::new(Some(ParserConfig::default()))?; Ok(Self { config: config.unwrap_or_default(), client, access_token: None, token_expires_at: None, json_parser: Arc::new(Mutex::new(json_parser)), }) } /// 获取Google访问令牌 async fn get_access_token(&mut self) -> Result { // 检查缓存的令牌是否仍然有效 let current_time = SystemTime::now() .duration_since(UNIX_EPOCH)? .as_secs(); if let (Some(token), Some(expires_at)) = (&self.access_token, self.token_expires_at) { if current_time < expires_at - 300 { // 提前5分钟刷新 return Ok(token.clone()); } } // 获取新的访问令牌 let url = format!("{}/google/access-token", self.config.base_url); let response = self.client .get(&url) .header("Authorization", format!("Bearer {}", self.config.bearer_token)) .send() .await?; let status = response.status(); if !status.is_success() { let error_body = response.text().await.unwrap_or_default(); return Err(anyhow!("获取访问令牌失败: {} - {}", status, error_body)); } let response_text = response.text().await?; let token_response: TokenResponse = serde_json::from_str(&response_text) .map_err(|e| anyhow!("解析令牌响应失败: {} - 响应内容: {}", e, response_text))?; // 缓存令牌 self.access_token = Some(token_response.access_token.clone()); self.token_expires_at = Some(current_time + token_response.expires_in); Ok(token_response.access_token) } /// 创建Gemini客户端配置 fn create_gemini_client(&self, access_token: &str) -> ClientConfig { let mut headers = std::collections::HashMap::new(); headers.insert("Authorization".to_string(), format!("Bearer {}", access_token)); headers.insert("Content-Type".to_string(), "application/json".to_string()); // 使用第一个区域作为默认区域 let region = self.config.regions.first() .unwrap_or(&"us-central1".to_string()) .clone(); // 构建Cloudflare Gateway URL let gateway_url = format!( "https://gateway.ai.cloudflare.com/v1/{}/{}/google-vertex-ai/v1/projects/{}/locations/{}/publishers/google/models", self.config.cloudflare_project_id, self.config.cloudflare_gateway_id, self.config.google_project_id, region ); ClientConfig { gateway_url, headers, } } /// 上传视频文件到Gemini pub async fn upload_video_file(&mut self, video_path: &str) -> Result { println!("📤 正在上传视频到Gemini: {}", video_path); // 获取访问令牌 let access_token = self.get_access_token().await?; // 读取视频文件 let video_data = fs::read(video_path).await .map_err(|e| anyhow!("读取视频文件失败: {} - {}", video_path, e))?; let file_name = Path::new(video_path) .file_name() .and_then(|n| n.to_str()) .unwrap_or("video.mp4"); // 创建multipart表单 let form = multipart::Form::new() .part("file", multipart::Part::bytes(video_data) .file_name(file_name.to_string()) .mime_str("video/mp4")?); // 上传到Vertex AI let upload_url = format!("{}/google/vertex-ai/upload", self.config.base_url); let query_params = [ ("bucket", "dy-media-storage"), ("prefix", "video-analysis") ]; let response = self.client .post(&upload_url) .header("Authorization", format!("Bearer {}", access_token)) .header("x-google-api-key", &access_token) .query(&query_params) .multipart(form) .send() .await?; let status = response.status(); if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); return Err(anyhow!("上传视频失败: {} - {}", status, error_text)); } let response_text = response.text().await?; let upload_response: UploadResponse = serde_json::from_str(&response_text) .map_err(|e| anyhow!("解析上传响应失败: {} - 响应内容: {}", e, response_text))?; // 优先使用urn字段,如果没有则使用file_uri字段 let file_uri = upload_response.urn .or(upload_response.file_uri) .or_else(|| upload_response.name.map(|name| format!("gs://dy-media-storage/{}", name))) .ok_or_else(|| anyhow!("上传响应中未找到文件URI,响应内容: {}", response_text))?; Ok(file_uri) } /// 生成内容分析 (参考Python demo.py实现) pub async fn generate_content_analysis(&mut self, file_uri: &str, prompt: &str) -> Result { println!("🧠 正在进行AI分析..."); // 获取访问令牌 let access_token = self.get_access_token().await?; // 创建客户端配置 let client_config = self.create_gemini_client(&access_token); // 格式化GCS URI let formatted_uri = self.format_gcs_uri(file_uri); // 准备请求数据,参考demo.py实现 let request_data = GenerateContentRequest { contents: vec![ContentPart { role: "user".to_string(), parts: vec![ Part::Text { text: prompt.to_string() }, Part::FileData { file_data: FileData { mime_type: "video/mp4".to_string(), file_uri: formatted_uri, } } ], }], generation_config: GenerationConfig { temperature: self.config.temperature, top_k: 32, top_p: 1.0, max_output_tokens: self.config.max_tokens, }, }; // 发送请求到Cloudflare Gateway,参考demo.py let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, self.config.model_name); // 重试机制 let mut last_error = None; for attempt in 0..self.config.max_retries { match self.send_generate_request(&generate_url, &client_config, &request_data).await { Ok(result) => { return self.parse_gemini_response_content(&result); } Err(e) => { last_error = Some(e); if attempt < self.config.max_retries - 1 { tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await; } } } } Err(anyhow!("内容生成失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap())) } /// 发送生成请求 async fn send_generate_request( &self, url: &str, client_config: &ClientConfig, request_data: &GenerateContentRequest, ) -> Result { let mut request_builder = self.client .post(url) .timeout(tokio::time::Duration::from_secs(self.config.timeout)) .json(request_data); // 添加请求头 for (key, value) in &client_config.headers { request_builder = request_builder.header(key, value); } let response = request_builder.send().await?; let status = response.status(); if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); return Err(anyhow!("API请求失败: {} - {}", status, error_text)); } let response_text = response.text().await?; let gemini_response: GeminiResponse = serde_json::from_str(&response_text) .map_err(|e| anyhow!("解析生成响应失败: {} - 响应内容: {}", e, response_text))?; if gemini_response.candidates.is_empty() { return Err(anyhow!("API返回结果为空")); } Ok(gemini_response) } /// 解析Gemini响应内容 fn parse_gemini_response_content(&self, gemini_response: &GeminiResponse) -> Result { if let Some(candidate) = gemini_response.candidates.first() { if let Some(part) = candidate.content.parts.first() { println!("✅ AI分析完成"); return Ok(part.text.clone()); } } Err(anyhow!("Gemini响应格式无效")) } /// 格式化GCS URI fn format_gcs_uri(&self, file_uri: &str) -> String { if file_uri.starts_with("gs://") { file_uri.to_string() } else if file_uri.starts_with("https://storage.googleapis.com/") { // 转换为gs://格式 file_uri.replace("https://storage.googleapis.com/", "gs://") } else { // 假设是相对路径,添加默认bucket format!("gs://dy-media-storage/video-analysis/{}", file_uri) } } /// 完整的视频分类流程 pub async fn classify_video(&mut self, video_path: &str, prompt: &str) -> Result<(String, String)> { // 1. 上传视频 let file_uri = self.upload_video_file(video_path).await?; // 2. 生成分析 let analysis_result = self.generate_content_analysis(&file_uri, prompt).await?; Ok((file_uri, analysis_result)) } /// 分析图像内容(用于服装搭配分析) pub async fn analyze_image_with_prompt(&mut self, image_base64: &str, prompt: &str) -> Result { println!("🧠 正在进行图像分析..."); // 获取访问令牌 let access_token = self.get_access_token().await?; // 创建客户端配置 let client_config = self.create_gemini_client(&access_token); // 准备请求数据 let request_data = GenerateContentRequest { contents: vec![ContentPart { role: "user".to_string(), parts: vec![ Part::Text { text: prompt.to_string() }, Part::InlineData { inline_data: InlineData { mime_type: "image/jpeg".to_string(), data: image_base64.to_string(), } } ], }], generation_config: GenerationConfig { temperature: self.config.temperature, top_k: 32, top_p: 1.0, max_output_tokens: self.config.max_tokens, }, }; // 发送请求到Cloudflare Gateway let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, self.config.model_name); // 重试机制 let mut last_error = None; for attempt in 0..self.config.max_retries { match self.send_generate_request(&generate_url, &client_config, &request_data).await { Ok(result) => { let content = self.parse_gemini_response_content(&result)?; println!("✅ 图像分析完成"); return Ok(content); } Err(e) => { last_error = Some(e); if attempt < self.config.max_retries - 1 { tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await; } } } } Err(anyhow!("图像分析失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap())) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_format_gcs_uri() { let service = GeminiService::new(None).expect("Failed to create GeminiService"); // 测试已经是gs://格式的URI let gs_uri = "gs://bucket/path/file.mp4"; assert_eq!(service.format_gcs_uri(gs_uri), gs_uri); // 测试https://storage.googleapis.com/格式的URI let https_uri = "https://storage.googleapis.com/bucket/path/file.mp4"; assert_eq!(service.format_gcs_uri(https_uri), "gs://bucket/path/file.mp4"); // 测试相对路径 let relative_path = "path/file.mp4"; assert_eq!(service.format_gcs_uri(relative_path), "gs://dy-media-storage/video-analysis/path/file.mp4"); } #[tokio::test] async fn test_extract_json_from_response_valid_json() { let service = GeminiService::new(None).expect("Failed to create GeminiService"); // 测试有效的JSON let valid_json = r#"{"name": "test", "value": 42}"#; let result = service.extract_json_from_response(valid_json); assert!(result.is_ok()); let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap(); assert_eq!(parsed["name"], "test"); assert_eq!(parsed["value"], 42); } #[tokio::test] async fn test_extract_json_from_response_markdown_wrapped() { let service = GeminiService::new(None).expect("Failed to create GeminiService"); // 测试Markdown包装的JSON let markdown_json = r#"Here's the analysis: ```json {"environment_tags": ["outdoor"], "style_description": "casual"} ``` That's the result."#; let result = service.extract_json_from_response(markdown_json); assert!(result.is_ok()); let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap(); assert_eq!(parsed["environment_tags"][0], "outdoor"); assert_eq!(parsed["style_description"], "casual"); } #[tokio::test] async fn test_extract_json_from_response_malformed_json() { let service = GeminiService::new(None).expect("Failed to create GeminiService"); // 测试格式错误的JSON(无引号的键) let malformed_json = r#"{name: "test", value: 42,}"#; let result = service.extract_json_from_response(malformed_json); // TolerantJsonParser应该能够修复这种错误 assert!(result.is_ok()); let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap(); assert_eq!(parsed["name"], "test"); assert_eq!(parsed["value"], 42); } #[tokio::test] async fn test_extract_json_from_response_fallback_behavior() { let service = GeminiService::new(None).expect("Failed to create GeminiService"); // 测试复杂的Gemini响应格式,验证解析器能够提取到有效的JSON // 注意:当前的TolerantJsonParser可能会提取到第一个有效的JSON对象 let complex_response = r#"Based on the image analysis, here's the structured result: ```json { "environment_tags": ["indoor", "office"], "environment_color_pattern": { "hue": 0.5, "saturation": 0.3, "value": 0.8 }, "dress_color_pattern": { "hue": 0.6, "saturation": 0.4, "value": 0.9 }, "style_description": "Professional business attire", "products": [ { "category": "上装", "description": "白色衬衫", "color_pattern": { "hue": 0.0, "saturation": 0.0, "value": 1.0 }, "design_styles": ["正式", "商务"], "color_pattern_match_dress": 0.9, "color_pattern_match_environment": 0.8 } ] } ``` This analysis shows a professional outfit suitable for office environments."#; let result = service.extract_json_from_response(complex_response); assert!(result.is_ok()); let json_string = result.unwrap(); println!("Extracted JSON: {}", json_string); let parsed: serde_json::Value = serde_json::from_str(&json_string).unwrap(); println!("Parsed JSON: {:?}", parsed); // 检查解析结果的结构 - 验证至少提取到了有效的JSON对象 assert!(parsed.is_object(), "Parsed result should be an object"); // 验证提取到的JSON包含数值字段(说明解析成功) assert!(parsed.get("hue").is_some() || parsed.get("environment_tags").is_some(), "Should extract some recognizable JSON content"); } #[tokio::test] async fn test_extract_json_from_response_simple_markdown() { let service = GeminiService::new(None).expect("Failed to create GeminiService"); // 测试简单的Markdown包装JSON,这应该能正确解析 let simple_markdown = r#"Here's the result: ```json {"status": "success", "message": "Analysis complete"} ``` Done."#; let result = service.extract_json_from_response(simple_markdown); assert!(result.is_ok()); let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap(); assert_eq!(parsed["status"], "success"); assert_eq!(parsed["message"], "Analysis complete"); } } // 服装搭配分析扩展 impl GeminiService { /// 分析服装图像并返回结构化结果 pub async fn analyze_outfit_image(&mut self, image_path: &str) -> Result { // 读取图像文件 let image_data = fs::read(image_path).await .map_err(|e| anyhow!("Failed to read image file: {} - {}", image_path, e))?; // 转换为base64 let image_base64 = BASE64_STANDARD.encode(&image_data); // 构建服装分析提示词 let prompt = self.build_outfit_analysis_prompt(); // 调用图像分析 let raw_response = self.analyze_image_with_prompt(&image_base64, &prompt).await?; // 添加调试信息 println!("🔍 Gemini原始响应: {}", raw_response); // 尝试提取JSON部分 self.extract_json_from_response(&raw_response) } /// 构建服装分析提示词 fn build_outfit_analysis_prompt(&self) -> String { r#"请分析这张服装图片,并以JSON格式返回以下信息: { "environment_tags": ["环境标签1", "环境标签2"], "environment_color_pattern": { "hue": 0.5, "saturation": 0.3, "value": 0.8 }, "dress_color_pattern": { "hue": 0.6, "saturation": 0.4, "value": 0.9 }, "style_description": "整体风格描述", "products": [ { "category": "服装类别", "description": "服装描述", "color_pattern": { "hue": 0.6, "saturation": 0.5, "value": 0.7 }, "design_styles": ["设计风格1", "设计风格2"], "color_pattern_match_dress": 0.8, "color_pattern_match_environment": 0.7 } ] } 分析要求: 1. environment_tags: 识别图片中的环境场景,如"Outdoor", "Indoor", "City street", "Office"等 2. environment_color_pattern: 环境的主要颜色,用HSV值表示(0-1范围) 3. dress_color_pattern: 整体服装搭配的主要颜色 4. style_description: 用中文描述整体的搭配风格 5. products: 识别出的各个服装单品 - category: 服装类别,如"上装", "下装", "鞋子", "配饰"等 - description: 具体描述这件服装 - color_pattern: 该单品的主要颜色 - design_styles: 设计风格,如"休闲", "正式", "运动", "街头"等 - color_pattern_match_dress: 与整体搭配颜色的匹配度(0-1) - color_pattern_match_environment: 与环境颜色的匹配度(0-1) 请确保返回的是有效的JSON格式。"#.to_string() } /// LLM问答功能 pub async fn ask_outfit_advice(&mut self, user_input: &str) -> Result { // 构建服装搭配顾问提示词 let prompt = format!( r#"你是一位专业的服装搭配顾问,请根据用户的问题提供专业的搭配建议。 用户问题:{} 请提供: 1. 具体的搭配建议 2. 颜色搭配原理 3. 适合的场合 4. 搭配技巧 请用友好、专业的语气回答,并提供实用的建议。"#, user_input ); // 使用文本生成功能 self.generate_text_content(&prompt).await } /// 生成文本内容(用于LLM问答) async fn generate_text_content(&mut self, prompt: &str) -> Result { // 获取访问令牌 let access_token = self.get_access_token().await?; // 创建客户端配置 let client_config = self.create_gemini_client(&access_token); // 准备请求数据 let request_data = GenerateContentRequest { contents: vec![ContentPart { role: "user".to_string(), parts: vec![Part::Text { text: prompt.to_string() }], }], generation_config: GenerationConfig { temperature: 0.7, // 稍高的温度以获得更有创意的回答 top_k: 32, top_p: 1.0, max_output_tokens: self.config.max_tokens, }, }; // 发送请求 let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, self.config.model_name); // 重试机制 for attempt in 0..self.config.max_retries { match self.send_generate_request(&generate_url, &client_config, &request_data).await { Ok(result) => { let content = self.parse_gemini_response_content(&result)?; return Ok(content); } Err(e) => { if attempt == self.config.max_retries - 1 { return Err(e); } tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await; } } } Err(anyhow!("All retry attempts failed")) } /// 从Gemini响应中提取JSON内容 /// 使用TolerantJsonParser进行高级JSON解析和错误恢复 fn extract_json_from_response(&self, response: &str) -> Result { println!("🔍 开始使用TolerantJsonParser提取JSON,响应长度: {}", response.len()); // 使用TolerantJsonParser进行解析 let parse_result = { let mut parser = self.json_parser.lock() .map_err(|e| anyhow!("Failed to acquire parser lock: {}", e))?; parser.parse(response) }; match parse_result { Ok((parsed_value, stats)) => { println!("✅ TolerantJsonParser解析成功"); println!("📊 解析统计: 总节点数={}, 错误节点数={}, 错误率={:.2}%, 解析时间={}ms", stats.total_nodes, stats.error_nodes, stats.error_rate * 100.0, stats.parse_time_ms); if !stats.recovery_strategies_used.is_empty() { println!("� 使用的恢复策略: {:?}", stats.recovery_strategies_used); } // 记录解析质量信息 if stats.error_rate > 0.1 { println!("⚠️ 解析质量警告: 错误率较高 ({:.2}%), 建议检查输入数据", stats.error_rate * 100.0); } if stats.parse_time_ms > 1000 { println!("⚠️ 性能警告: 解析时间较长 ({}ms), 可能需要优化", stats.parse_time_ms); } // 将解析结果转换为JSON字符串 let json_string = serde_json::to_string(&parsed_value) .map_err(|e| anyhow!("Failed to serialize parsed JSON: {}", e))?; println!("✅ JSON序列化成功,输出长度: {} 字符", json_string.len()); Ok(json_string) } Err(e) => { println!("❌ TolerantJsonParser解析失败: {}", e); println!("📝 响应内容预览: {}", &response[..response.len().min(200)]); // 直接返回错误,不使用回退方案,便于调试和问题定位 Err(anyhow!("TolerantJsonParser解析失败: {}. 响应内容: {}", e, &response[..response.len().min(500)])) } } } }