use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use base64::prelude::*;

use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::fs;
use reqwest::multipart;

// Import the fault-tolerant JSON parser
use crate::infrastructure::tolerant_json_parser::{TolerantJsonParser, ParserConfig};
use std::sync::{Arc, Mutex};

/// Gemini API configuration
#[derive(Debug, Clone)]
pub struct GeminiConfig {
    pub base_url: String,
    pub bearer_token: String,
    pub timeout: u64,
    pub model_name: String,
    pub max_retries: u32,
    pub retry_delay: u64,
    pub temperature: f32,
    pub max_tokens: u32,
    pub cloudflare_project_id: String,
    pub cloudflare_gateway_id: String,
    pub google_project_id: String,
    pub regions: Vec<String>,
}

impl Default for GeminiConfig {
    fn default() -> Self {
        Self {
            base_url: "https://bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run".to_string(),
            bearer_token: "bowong7777".to_string(),
            timeout: 120,
            model_name: "gemini-2.5-flash".to_string(),
            max_retries: 3,
            retry_delay: 2,
            temperature: 0.1,
            max_tokens: 1024 * 8,
            cloudflare_project_id: "67720b647ff2b55cf37ba3ef9e677083".to_string(),
            cloudflare_gateway_id: "bowong-dev".to_string(),
            google_project_id: "gen-lang-client-0413414134".to_string(),
            regions: vec![
                "us-central1".to_string(),
                "us-east1".to_string(),
                "europe-west1".to_string(),
            ],
        }
    }
}

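// Usage sketch (illustrative only; the model name below is a placeholder):
// individual defaults can be overridden with struct-update syntax before the
// config is handed to `GeminiService::new`, e.g.
//
//     let config = GeminiConfig {
//         model_name: "another-gemini-model".to_string(),
//         max_retries: 5,
//         ..GeminiConfig::default()
//     };
//     let service = GeminiService::new(Some(config))?;
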
/// Gemini access-token response
#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    expires_in: u64,
}

/// Client configuration
#[derive(Debug)]
struct ClientConfig {
    gateway_url: String,
    headers: std::collections::HashMap<String, String>,
}

/// Gemini upload response
#[derive(Debug, Deserialize)]
struct UploadResponse {
    file_uri: Option<String>,
    urn: Option<String>,
    name: Option<String>,
}

/// Gemini content-generation request
#[derive(Debug, Serialize)]
struct GenerateContentRequest {
    contents: Vec<ContentPart>,
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}

#[derive(Debug, Serialize)]
struct ContentPart {
    role: String,
    parts: Vec<Part>,
}

#[derive(Debug, Serialize)]
#[serde(untagged)]
enum Part {
    Text { text: String },
    FileData {
        #[serde(rename = "fileData")]
        file_data: FileData,
    },
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: InlineData,
    },
}

#[derive(Debug, Serialize)]
struct FileData {
    #[serde(rename = "mimeType")]
    mime_type: String,
    #[serde(rename = "fileUri")]
    file_uri: String,
}

#[derive(Debug, Serialize)]
struct InlineData {
    #[serde(rename = "mimeType")]
    mime_type: String,
    data: String,
}

#[derive(Debug, Serialize)]
struct GenerationConfig {
    temperature: f32,
    #[serde(rename = "topK")]
    top_k: u32,
    #[serde(rename = "topP")]
    top_p: f32,
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: u32,
}

/// Gemini response structures
#[derive(Debug, Deserialize)]
struct GeminiResponse {
    candidates: Vec<Candidate>,
}

#[derive(Debug, Deserialize)]
struct Candidate {
    content: Content,
}

#[derive(Debug, Deserialize)]
struct Content {
    parts: Vec<ResponsePart>,
}

#[derive(Debug, Deserialize)]
struct ResponsePart {
    text: String,
}

/// Gemini API service
/// Follows the infrastructure-layer design pattern from the Tauri development guidelines
#[derive(Clone)]
pub struct GeminiService {
    config: GeminiConfig,
    client: reqwest::Client,
    access_token: Option<String>,
    token_expires_at: Option<u64>,
    json_parser: Arc<Mutex<TolerantJsonParser>>,
}

impl GeminiService {
    /// Create a new Gemini service instance
    pub fn new(config: Option<GeminiConfig>) -> Result<Self> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(config.as_ref().map(|c| c.timeout).unwrap_or(120)))
            .build()
            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;

        // Create the fault-tolerant JSON parser
        let json_parser = TolerantJsonParser::new(Some(ParserConfig::default()))?;

        Ok(Self {
            config: config.unwrap_or_default(),
            client,
            access_token: None,
            token_expires_at: None,
            json_parser: Arc::new(Mutex::new(json_parser)),
        })
    }

    /// Fetch a Google access token
    async fn get_access_token(&mut self) -> Result<String> {
        // Check whether the cached token is still valid
        let current_time = SystemTime::now()
            .duration_since(UNIX_EPOCH)?
            .as_secs();

        if let (Some(token), Some(expires_at)) = (&self.access_token, self.token_expires_at) {
            if current_time < expires_at - 300 { // refresh 5 minutes early
                return Ok(token.clone());
            }
        }

        // Request a fresh access token
        let url = format!("{}/google/access-token", self.config.base_url);

        let response = self.client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.config.bearer_token))
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            return Err(anyhow!("获取访问令牌失败: {} - {}", status, error_body));
        }

        let response_text = response.text().await?;
        let token_response: TokenResponse = serde_json::from_str(&response_text)
            .map_err(|e| anyhow!("解析令牌响应失败: {} - 响应内容: {}", e, response_text))?;

        // Cache the token
        self.access_token = Some(token_response.access_token.clone());
        self.token_expires_at = Some(current_time + token_response.expires_in);

        Ok(token_response.access_token)
    }

    /// Create the Gemini client configuration
    fn create_gemini_client(&self, access_token: &str) -> ClientConfig {
        let mut headers = std::collections::HashMap::new();
        headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        // Use the first configured region as the default
        let region = self.config.regions.first()
            .unwrap_or(&"us-central1".to_string())
            .clone();

        // Build the Cloudflare Gateway URL
        let gateway_url = format!(
            "https://gateway.ai.cloudflare.com/v1/{}/{}/google-vertex-ai/v1/projects/{}/locations/{}/publishers/google/models",
            self.config.cloudflare_project_id,
            self.config.cloudflare_gateway_id,
            self.config.google_project_id,
            region
        );

        ClientConfig {
            gateway_url,
            headers,
        }
    }

    /// Upload a video file to Gemini
    pub async fn upload_video_file(&mut self, video_path: &str) -> Result<String> {
        println!("📤 正在上传视频到Gemini: {}", video_path);

        // Fetch the access token
        let access_token = self.get_access_token().await?;

        // Read the video file
        let video_data = fs::read(video_path).await
            .map_err(|e| anyhow!("读取视频文件失败: {} - {}", video_path, e))?;

        let file_name = Path::new(video_path)
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("video.mp4");

        // Build the multipart form
        let form = multipart::Form::new()
            .part("file", multipart::Part::bytes(video_data)
                .file_name(file_name.to_string())
                .mime_str("video/mp4")?);

        // Upload to Vertex AI
        let upload_url = format!("{}/google/vertex-ai/upload", self.config.base_url);
        let query_params = [
            ("bucket", "dy-media-storage"),
            ("prefix", "video-analysis")
        ];

        let response = self.client
            .post(&upload_url)
            .header("Authorization", format!("Bearer {}", access_token))
            .header("x-google-api-key", &access_token)
            .query(&query_params)
            .multipart(form)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow!("上传视频失败: {} - {}", status, error_text));
        }

        let response_text = response.text().await?;
        let upload_response: UploadResponse = serde_json::from_str(&response_text)
            .map_err(|e| anyhow!("解析上传响应失败: {} - 响应内容: {}", e, response_text))?;

        // Prefer the urn field, fall back to file_uri, then to name
        let file_uri = upload_response.urn
            .or(upload_response.file_uri)
            .or_else(|| upload_response.name.map(|name| format!("gs://dy-media-storage/{}", name)))
            .ok_or_else(|| anyhow!("上传响应中未找到文件URI,响应内容: {}", response_text))?;

        Ok(file_uri)
    }

    /// Generate the content analysis (mirrors the Python demo.py implementation)
    pub async fn generate_content_analysis(&mut self, file_uri: &str, prompt: &str) -> Result<String> {
        println!("🧠 正在进行AI分析...");

        // Fetch the access token
        let access_token = self.get_access_token().await?;

        // Create the client configuration
        let client_config = self.create_gemini_client(&access_token);

        // Normalize the GCS URI
        let formatted_uri = self.format_gcs_uri(file_uri);

        // Prepare the request payload, following demo.py
        let request_data = GenerateContentRequest {
            contents: vec![ContentPart {
                role: "user".to_string(),
                parts: vec![
                    Part::Text { text: prompt.to_string() },
                    Part::FileData {
                        file_data: FileData {
                            mime_type: "video/mp4".to_string(),
                            file_uri: formatted_uri,
                        }
                    }
                ],
            }],
            generation_config: GenerationConfig {
                temperature: self.config.temperature,
                top_k: 32,
                top_p: 1.0,
                max_output_tokens: self.config.max_tokens,
            },
        };

        // Send the request through the Cloudflare Gateway, as in demo.py
        let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, self.config.model_name);

        // Retry loop
        let mut last_error = None;

        for attempt in 0..self.config.max_retries {
            match self.send_generate_request(&generate_url, &client_config, &request_data).await {
                Ok(result) => {
                    return self.parse_gemini_response_content(&result);
                }
                Err(e) => {
                    last_error = Some(e);

                    if attempt < self.config.max_retries - 1 {
                        tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await;
                    }
                }
            }
        }

        Err(anyhow!("内容生成失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap()))
    }

    /// Send a generateContent request
    async fn send_generate_request(
        &self,
        url: &str,
        client_config: &ClientConfig,
        request_data: &GenerateContentRequest,
    ) -> Result<GeminiResponse> {
        let mut request_builder = self.client
            .post(url)
            .timeout(tokio::time::Duration::from_secs(self.config.timeout))
            .json(request_data);

        // Attach the request headers
        for (key, value) in &client_config.headers {
            request_builder = request_builder.header(key, value);
        }

        let response = request_builder.send().await?;
        let status = response.status();

        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow!("API请求失败: {} - {}", status, error_text));
        }

        let response_text = response.text().await?;
        let gemini_response: GeminiResponse = serde_json::from_str(&response_text)
            .map_err(|e| anyhow!("解析生成响应失败: {} - 响应内容: {}", e, response_text))?;

        if gemini_response.candidates.is_empty() {
            return Err(anyhow!("API返回结果为空"));
        }

        Ok(gemini_response)
    }

    /// Parse the text content out of a Gemini response
    fn parse_gemini_response_content(&self, gemini_response: &GeminiResponse) -> Result<String> {
        if let Some(candidate) = gemini_response.candidates.first() {
            if let Some(part) = candidate.content.parts.first() {
                println!("✅ AI分析完成");
                return Ok(part.text.clone());
            }
        }

        Err(anyhow!("Gemini响应格式无效"))
    }

    /// Normalize a GCS URI
    fn format_gcs_uri(&self, file_uri: &str) -> String {
        if file_uri.starts_with("gs://") {
            file_uri.to_string()
        } else if file_uri.starts_with("https://storage.googleapis.com/") {
            // Convert to the gs:// form
            file_uri.replace("https://storage.googleapis.com/", "gs://")
        } else {
            // Assume a relative path and prepend the default bucket
            format!("gs://dy-media-storage/video-analysis/{}", file_uri)
        }
    }

    /// Complete video classification flow
    pub async fn classify_video(&mut self, video_path: &str, prompt: &str) -> Result<(String, String)> {
        // 1. Upload the video
        let file_uri = self.upload_video_file(video_path).await?;

        // 2. Generate the analysis
        let analysis_result = self.generate_content_analysis(&file_uri, prompt).await?;

        Ok((file_uri, analysis_result))
    }

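    // Usage sketch for the full flow above (illustrative only; assumes a Tokio
    // runtime, a reachable backend, and a local file at the given placeholder path):
    //
    //     let mut service = GeminiService::new(None)?;
    //     let (file_uri, analysis) = service
    //         .classify_video("videos/sample.mp4", "请分析这个视频的内容")
    //         .await?;
    //     println!("{} => {}", file_uri, analysis);
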
    /// Analyze image content (used for outfit-matching analysis)
    pub async fn analyze_image_with_prompt(&mut self, image_base64: &str, prompt: &str) -> Result<String> {
        println!("🧠 正在进行图像分析...");

        // Fetch the access token
        let access_token = self.get_access_token().await?;

        // Create the client configuration
        let client_config = self.create_gemini_client(&access_token);

        // Prepare the request payload
        let request_data = GenerateContentRequest {
            contents: vec![ContentPart {
                role: "user".to_string(),
                parts: vec![
                    Part::Text { text: prompt.to_string() },
                    Part::InlineData {
                        inline_data: InlineData {
                            mime_type: "image/jpeg".to_string(),
                            data: image_base64.to_string(),
                        }
                    }
                ],
            }],
            generation_config: GenerationConfig {
                temperature: self.config.temperature,
                top_k: 32,
                top_p: 1.0,
                max_output_tokens: self.config.max_tokens,
            },
        };

        // Send the request through the Cloudflare Gateway
        let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, self.config.model_name);

        // Retry loop
        let mut last_error = None;

        for attempt in 0..self.config.max_retries {
            match self.send_generate_request(&generate_url, &client_config, &request_data).await {
                Ok(result) => {
                    let content = self.parse_gemini_response_content(&result)?;
                    println!("✅ 图像分析完成");
                    return Ok(content);
                }
                Err(e) => {
                    last_error = Some(e);

                    if attempt < self.config.max_retries - 1 {
                        tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await;
                    }
                }
            }
        }

        Err(anyhow!("图像分析失败,已重试{}次: {}", self.config.max_retries, last_error.unwrap()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_format_gcs_uri() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");

        // URI already in gs:// form
        let gs_uri = "gs://bucket/path/file.mp4";
        assert_eq!(service.format_gcs_uri(gs_uri), gs_uri);

        // URI in https://storage.googleapis.com/ form
        let https_uri = "https://storage.googleapis.com/bucket/path/file.mp4";
        assert_eq!(service.format_gcs_uri(https_uri), "gs://bucket/path/file.mp4");

        // Relative path
        let relative_path = "path/file.mp4";
        assert_eq!(service.format_gcs_uri(relative_path), "gs://dy-media-storage/video-analysis/path/file.mp4");
    }

    #[tokio::test]
    async fn test_extract_json_from_response_valid_json() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");

        // Well-formed JSON
        let valid_json = r#"{"name": "test", "value": 42}"#;
        let result = service.extract_json_from_response(valid_json);
        assert!(result.is_ok());

        let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(parsed["name"], "test");
        assert_eq!(parsed["value"], 42);
    }

    #[tokio::test]
    async fn test_extract_json_from_response_markdown_wrapped() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");

        // JSON wrapped in Markdown
        let markdown_json = r#"Here's the analysis:
```json
{"environment_tags": ["outdoor"], "style_description": "casual"}
```
That's the result."#;

        let result = service.extract_json_from_response(markdown_json);
        assert!(result.is_ok());

        let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(parsed["environment_tags"][0], "outdoor");
        assert_eq!(parsed["style_description"], "casual");
    }

    #[tokio::test]
    async fn test_extract_json_from_response_malformed_json() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");

        // Malformed JSON (unquoted keys, trailing comma)
        let malformed_json = r#"{name: "test", value: 42,}"#;
        let result = service.extract_json_from_response(malformed_json);

        // TolerantJsonParser should be able to repair this kind of error
        assert!(result.is_ok());

        let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(parsed["name"], "test");
        assert_eq!(parsed["value"], 42);
    }

    #[tokio::test]
    async fn test_extract_json_from_response_fallback_behavior() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");

        // A complex Gemini-style response; verifies the parser can extract valid JSON.
        // Note: the current TolerantJsonParser may pick up the first valid JSON object.
        let complex_response = r#"Based on the image analysis, here's the structured result:

```json
{
    "environment_tags": ["indoor", "office"],
    "environment_color_pattern": {
        "hue": 0.5,
        "saturation": 0.3,
        "value": 0.8
    },
    "dress_color_pattern": {
        "hue": 0.6,
        "saturation": 0.4,
        "value": 0.9
    },
    "style_description": "Professional business attire",
    "products": [
        {
            "category": "上装",
            "description": "白色衬衫",
            "color_pattern": {
                "hue": 0.0,
                "saturation": 0.0,
                "value": 1.0
            },
            "design_styles": ["正式", "商务"],
            "color_pattern_match_dress": 0.9,
            "color_pattern_match_environment": 0.8
        }
    ]
}
```

This analysis shows a professional outfit suitable for office environments."#;

        let result = service.extract_json_from_response(complex_response);
        assert!(result.is_ok());

        let json_string = result.unwrap();
        println!("Extracted JSON: {}", json_string);

        let parsed: serde_json::Value = serde_json::from_str(&json_string).unwrap();
        println!("Parsed JSON: {:?}", parsed);

        // Check the structure of the result - at minimum a valid JSON object was extracted
        assert!(parsed.is_object(), "Parsed result should be an object");

        // Verify the extracted JSON contains recognizable fields (i.e. parsing succeeded)
        assert!(parsed.get("hue").is_some() || parsed.get("environment_tags").is_some(),
                "Should extract some recognizable JSON content");
    }

    #[tokio::test]
    async fn test_extract_json_from_response_simple_markdown() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");

        // Simple Markdown-wrapped JSON; this should parse cleanly
        let simple_markdown = r#"Here's the result:
```json
{"status": "success", "message": "Analysis complete"}
```
Done."#;

        let result = service.extract_json_from_response(simple_markdown);
        assert!(result.is_ok());

        let parsed: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(parsed["status"], "success");
        assert_eq!(parsed["message"], "Analysis complete");
    }

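    // Added sketch: checks how create_gemini_client assembles the Cloudflare
    // Gateway URL and auth header from the default config (the token is a dummy
    // value; no network access is involved).
    #[test]
    fn test_create_gemini_client_gateway_url() {
        let service = GeminiService::new(None).expect("Failed to create GeminiService");
        let client_config = service.create_gemini_client("dummy-token");

        assert_eq!(
            client_config.gateway_url,
            "https://gateway.ai.cloudflare.com/v1/67720b647ff2b55cf37ba3ef9e677083/bowong-dev/google-vertex-ai/v1/projects/gen-lang-client-0413414134/locations/us-central1/publishers/google/models"
        );
        assert_eq!(
            client_config.headers.get("Authorization"),
            Some(&"Bearer dummy-token".to_string())
        );
    }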
}

// Outfit-matching analysis extensions
impl GeminiService {
    /// Analyze an outfit image and return a structured result
    pub async fn analyze_outfit_image(&mut self, image_path: &str) -> Result<String> {
        // Read the image file
        let image_data = fs::read(image_path).await
            .map_err(|e| anyhow!("Failed to read image file: {} - {}", image_path, e))?;

        // Encode as base64
        let image_base64 = BASE64_STANDARD.encode(&image_data);

        // Build the outfit-analysis prompt
        let prompt = self.build_outfit_analysis_prompt();

        // Run the image analysis
        let raw_response = self.analyze_image_with_prompt(&image_base64, &prompt).await?;

        // Debug output
        println!("🔍 Gemini原始响应: {}", raw_response);

        // Try to extract the JSON portion
        self.extract_json_from_response(&raw_response)
    }

    /// Build the outfit-analysis prompt
    fn build_outfit_analysis_prompt(&self) -> String {
        r#"请分析这张服装图片,并以JSON格式返回以下信息:

{
    "environment_tags": ["环境标签1", "环境标签2"],
    "environment_color_pattern": {
        "hue": 0.5,
        "saturation": 0.3,
        "value": 0.8
    },
    "dress_color_pattern": {
        "hue": 0.6,
        "saturation": 0.4,
        "value": 0.9
    },
    "style_description": "整体风格描述",
    "products": [
        {
            "category": "服装类别",
            "description": "服装描述",
            "color_pattern": {
                "hue": 0.6,
                "saturation": 0.5,
                "value": 0.7
            },
            "design_styles": ["设计风格1", "设计风格2"],
            "color_pattern_match_dress": 0.8,
            "color_pattern_match_environment": 0.7
        }
    ]
}

分析要求:
1. environment_tags: 识别图片中的环境场景,如"Outdoor", "Indoor", "City street", "Office"等
2. environment_color_pattern: 环境的主要颜色,用HSV值表示(0-1范围)
3. dress_color_pattern: 整体服装搭配的主要颜色
4. style_description: 用中文描述整体的搭配风格
5. products: 识别出的各个服装单品
   - category: 服装类别,如"上装", "下装", "鞋子", "配饰"等
   - description: 具体描述这件服装
   - color_pattern: 该单品的主要颜色
   - design_styles: 设计风格,如"休闲", "正式", "运动", "街头"等
   - color_pattern_match_dress: 与整体搭配颜色的匹配度(0-1)
   - color_pattern_match_environment: 与环境颜色的匹配度(0-1)

请确保返回的是有效的JSON格式。"#.to_string()
    }

    /// LLM Q&A feature
    pub async fn ask_outfit_advice(&mut self, user_input: &str) -> Result<String> {
        // Build the outfit-advisor prompt
        let prompt = format!(
            r#"你是一位专业的服装搭配顾问,请根据用户的问题提供专业的搭配建议。

用户问题:{}

请提供:
1. 具体的搭配建议
2. 颜色搭配原理
3. 适合的场合
4. 搭配技巧

请用友好、专业的语气回答,并提供实用的建议。"#,
            user_input
        );

        // Use the text-generation path
        self.generate_text_content(&prompt).await
    }

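    // Usage sketch for the outfit features above (illustrative only; assumes a
    // Tokio runtime and a reachable backend; the image path and question are
    // placeholders):
    //
    //     let mut service = GeminiService::new(None)?;
    //     let outfit_json = service.analyze_outfit_image("photos/outfit.jpg").await?;
    //     let advice = service.ask_outfit_advice("这套穿搭适合面试吗?").await?;
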
    /// Generate text content (used for LLM Q&A)
    async fn generate_text_content(&mut self, prompt: &str) -> Result<String> {
        // Fetch the access token
        let access_token = self.get_access_token().await?;

        // Create the client configuration
        let client_config = self.create_gemini_client(&access_token);

        // Prepare the request payload
        let request_data = GenerateContentRequest {
            contents: vec![ContentPart {
                role: "user".to_string(),
                parts: vec![Part::Text { text: prompt.to_string() }],
            }],
            generation_config: GenerationConfig {
                temperature: 0.7, // slightly higher temperature for more creative answers
                top_k: 32,
                top_p: 1.0,
                max_output_tokens: self.config.max_tokens,
            },
        };

        // Send the request
        let generate_url = format!("{}/{}:generateContent", client_config.gateway_url, self.config.model_name);

        // Retry loop
        for attempt in 0..self.config.max_retries {
            match self.send_generate_request(&generate_url, &client_config, &request_data).await {
                Ok(result) => {
                    let content = self.parse_gemini_response_content(&result)?;
                    return Ok(content);
                }
                Err(e) => {
                    if attempt == self.config.max_retries - 1 {
                        return Err(e);
                    }
                    tokio::time::sleep(tokio::time::Duration::from_secs(self.config.retry_delay)).await;
                }
            }
        }

        Err(anyhow!("All retry attempts failed"))
    }

    /// Extract JSON content from a Gemini response
    /// Uses TolerantJsonParser for advanced JSON parsing and error recovery
    fn extract_json_from_response(&self, response: &str) -> Result<String> {
        println!("🔍 开始使用TolerantJsonParser提取JSON,响应长度: {}", response.len());

        // Parse with TolerantJsonParser
        let parse_result = {
            let mut parser = self.json_parser.lock()
                .map_err(|e| anyhow!("Failed to acquire parser lock: {}", e))?;
            parser.parse(response)
        };

        match parse_result {
            Ok((parsed_value, stats)) => {
                println!("✅ TolerantJsonParser解析成功");
                println!("📊 解析统计: 总节点数={}, 错误节点数={}, 错误率={:.2}%, 解析时间={}ms",
                         stats.total_nodes, stats.error_nodes, stats.error_rate * 100.0, stats.parse_time_ms);

                if !stats.recovery_strategies_used.is_empty() {
                    println!("使用的恢复策略: {:?}", stats.recovery_strategies_used);
                }

                // Log parse-quality information
                if stats.error_rate > 0.1 {
                    println!("⚠️ 解析质量警告: 错误率较高 ({:.2}%), 建议检查输入数据", stats.error_rate * 100.0);
                }

                if stats.parse_time_ms > 1000 {
                    println!("⚠️ 性能警告: 解析时间较长 ({}ms), 可能需要优化", stats.parse_time_ms);
                }

                // Serialize the parsed value back to a JSON string
                let json_string = serde_json::to_string(&parsed_value)
                    .map_err(|e| anyhow!("Failed to serialize parsed JSON: {}", e))?;

                println!("✅ JSON序列化成功,输出长度: {} 字符", json_string.len());
                Ok(json_string)
            }
            Err(e) => {
                println!("❌ TolerantJsonParser解析失败: {}", e);
                // Build character-safe previews so the slice cannot split a UTF-8 code point
                let preview: String = response.chars().take(200).collect();
                println!("📝 响应内容预览: {}", preview);

                // Return the error directly without a fallback, to keep debugging straightforward
                Err(anyhow!("TolerantJsonParser解析失败: {}. 响应内容: {}", e,
                            response.chars().take(500).collect::<String>()))
            }
        }
    }
}