"""
|
||
Gemini AI服务
|
||
|
||
集成Google Gemini API,提供视频内容分析和分类功能。
|
||
"""
|
||
|
||
# Standard library
import asyncio
import base64
import io
import time
import json
import os
import hashlib
from typing import Dict, List, Any, Optional, Union, Callable
from pathlib import Path
import logging
from dataclasses import dataclass, asdict

# Third-party
import requests

# google-generativeai is optional: when it is missing, legacy mode is
# disabled gracefully instead of failing at import time.
try:
    import google.generativeai as genai
    from google.generativeai.types import HarmCategory, HarmBlockThreshold
    GEMINI_AVAILABLE = True
except ImportError:
    GEMINI_AVAILABLE = False
    genai = None

from PIL import Image
import cv2
import numpy as np

# Project-local dependency-injection helpers
from src.core.di import Injectable, Inject, Service
|
||
@dataclass
class GeminiConfig:
    """Gemini configuration (modelled on demo.py)."""

    # --- Authentication ---
    cloudflare_project_id: str = ""
    cloudflare_gateway_id: str = ""
    google_project_id: str = ""
    # Defaults to None and is filled in __post_init__ (a mutable default is
    # not allowed directly on a dataclass field).
    regions: Optional[List[str]] = None
    access_token: str = ""

    # --- API ---
    model_name: str = "gemini-2.5-flash"
    base_url: str = "https://bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"
    bearer_token: str = "bowong7777"
    timeout: int = 120  # seconds, applied to every HTTP call

    # --- Analysis-result cache ---
    enable_cache: bool = True
    cache_dir: str = ".cache/gemini_analysis"
    cache_expiry: int = 7 * 24 * 3600  # 7 days

    # --- Upload cache ---
    enable_upload_cache: bool = True
    upload_cache_dir: str = ".cache/gemini_uploads"
    upload_cache_expiry: int = 24 * 3600  # 1 day

    # --- Retry policy ---
    max_retries: int = 3
    retry_delay: int = 5  # seconds between retries

    def __post_init__(self) -> None:
        # Supply the default region pool when the caller passes none.
        if self.regions is None:
            self.regions = ["us-central1", "us-east1", "europe-west1"]
|
||
|
||
|
||
@dataclass
class AnalysisProgress:
    """Progress snapshot for one analysis task."""
    step: str              # identifier of the current step
    progress: int          # percentage complete, 0-100
    description: str = ""  # human-readable description of the step
    current_file: str = "" # file currently being processed
    stage: str = "upload"  # one of: upload, analysis, complete
|
||
|
||
|
||
@dataclass
class CacheEntry:
    """On-disk record of one cached analysis result."""
    video_path: str        # local path of the analyzed video
    file_uri: str          # remote URI the video was uploaded to
    prompt: str            # prompt used for the analysis
    result: Dict[str, Any] # parsed analysis result
    timestamp: float       # epoch seconds when the entry was written
    checksum: str          # MD5 of the video file at write time
    model_name: str        # model that produced the result
|
||
|
||
|
||
@dataclass
class UploadCacheEntry:
    """On-disk record of one cached upload (path -> remote URI)."""
    video_path: str   # local path of the uploaded video
    file_uri: str     # remote URI returned by the upload API
    timestamp: float  # epoch seconds when the upload happened
    checksum: str     # MD5 of the file at upload time
    file_size: int    # size in bytes at upload time
|
||
|
||
|
||
@Service("gemini_service")
|
||
class GeminiService:
|
||
"""
|
||
Gemini AI服务
|
||
|
||
提供基于Google Gemini API的视频内容分析和分类功能。
|
||
支持两种模式:
|
||
1. 传统模式:使用google-generativeai库
|
||
2. 新模式:使用Cloudflare Gateway + Vertex AI (参考demo.py)
|
||
"""
|
||
|
||
def __init__(self,
|
||
config: Dict[str, Any] = Inject("config"),
|
||
logger: logging.Logger = Inject("logger")):
|
||
self.config = config
|
||
self.logger = logger
|
||
|
||
# 传统Gemini配置
|
||
self.gemini_config = config.get("gemini", {})
|
||
self.api_key = self.gemini_config.get("api_key", "")
|
||
self.model_name = self.gemini_config.get("model", "gemini-2.5-flash")
|
||
self.max_tokens = self.gemini_config.get("max_tokens", 1000)
|
||
self.temperature = self.gemini_config.get("temperature", 0.1)
|
||
|
||
# 新模式配置 - 参考demo.py
|
||
self.new_mode_config = GeminiConfig(
|
||
cloudflare_project_id=self.gemini_config.get("cloudflare_project_id", ""),
|
||
cloudflare_gateway_id=self.gemini_config.get("cloudflare_gateway_id", ""),
|
||
google_project_id=self.gemini_config.get("google_project_id", ""),
|
||
regions=self.gemini_config.get("regions", ["us-central1", "us-east1", "europe-west1"]),
|
||
model_name=self.model_name,
|
||
base_url=self.gemini_config.get("base_url", "https://bowongai-dev--bowong-ai-video-gemini-fastapi-webapp.modal.run"),
|
||
bearer_token=self.gemini_config.get("bearer_token", "bowong7777"),
|
||
timeout=self.gemini_config.get("timeout", 120),
|
||
enable_cache=self.gemini_config.get("enable_cache", True),
|
||
cache_dir=self.gemini_config.get("cache_dir", ".cache/gemini_analysis"),
|
||
enable_upload_cache=self.gemini_config.get("enable_upload_cache", True),
|
||
upload_cache_dir=self.gemini_config.get("upload_cache_dir", ".cache/gemini_uploads"),
|
||
upload_cache_expiry=self.gemini_config.get("upload_cache_expiry", 86400),
|
||
max_retries=self.gemini_config.get("max_retries", 3),
|
||
retry_delay=self.gemini_config.get("retry_delay", 5)
|
||
)
|
||
|
||
# 检查是否启用新模式
|
||
self.use_new_mode = self.gemini_config.get("use_new_mode", False)
|
||
|
||
# 缓存相关
|
||
self._access_token = None
|
||
self._token_expires_at = None
|
||
|
||
# 确保缓存目录存在
|
||
if self.new_mode_config.enable_cache:
|
||
os.makedirs(self.new_mode_config.cache_dir, exist_ok=True)
|
||
|
||
# 确保上传缓存目录存在
|
||
if self.new_mode_config.enable_upload_cache:
|
||
os.makedirs(self.new_mode_config.upload_cache_dir, exist_ok=True)
|
||
|
||
# 速率限制配置(仅传统模式使用)
|
||
self.rate_limit_config = self.gemini_config.get("rate_limit", {})
|
||
self.requests_per_minute = self.rate_limit_config.get("requests_per_minute", 60)
|
||
self.requests_per_day = self.rate_limit_config.get("requests_per_day", 1500)
|
||
|
||
# 请求历史记录(用于速率限制)
|
||
self.request_history = []
|
||
|
||
# 初始化客户端(传统模式)
|
||
self.model = None
|
||
if not self.use_new_mode:
|
||
self._initialize_client()
|
||
|
||
def _initialize_client(self) -> None:
|
||
"""初始化Gemini客户端"""
|
||
if not GEMINI_AVAILABLE:
|
||
self.logger.error("Google Generative AI库未安装,请运行: pip install google-generativeai")
|
||
return
|
||
|
||
if not self.api_key:
|
||
self.logger.warning("未配置Gemini API密钥,AI分类功能将不可用")
|
||
return
|
||
|
||
try:
|
||
# 配置API密钥
|
||
|
||
# 创建模型实例
|
||
generation_config = {
|
||
"temperature": self.temperature,
|
||
"top_p": 0.95,
|
||
"top_k": 64,
|
||
"max_output_tokens": self.max_tokens,
|
||
}
|
||
|
||
safety_settings = {
|
||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||
}
|
||
|
||
|
||
self.logger.info(f"Gemini客户端初始化成功,模型: {self.model_name}")
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Gemini客户端初始化失败: {e}")
|
||
self.model = None
|
||
|
||
def is_available(self) -> bool:
|
||
"""检查Gemini服务是否可用"""
|
||
if self.use_new_mode:
|
||
# 新模式:检查配置是否完整
|
||
return (self.new_mode_config.cloudflare_project_id and
|
||
self.new_mode_config.cloudflare_gateway_id and
|
||
self.new_mode_config.google_project_id and
|
||
self.new_mode_config.bearer_token)
|
||
else:
|
||
# 传统模式
|
||
return GEMINI_AVAILABLE and self.model is not None
|
||
|
||
async def get_access_token(self) -> str:
|
||
"""
|
||
获取Google访问令牌,参考demo.py实现
|
||
"""
|
||
# 检查缓存的令牌是否仍然有效
|
||
if (self._access_token and self._token_expires_at and
|
||
time.time() < self._token_expires_at - 300): # 提前5分钟刷新
|
||
return self._access_token
|
||
|
||
try:
|
||
headers = {
|
||
"Authorization": f"Bearer {self.new_mode_config.bearer_token}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
response = requests.get(
|
||
f"{self.new_mode_config.base_url}/google/access-token",
|
||
headers=headers,
|
||
timeout=self.new_mode_config.timeout
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
raise Exception(f"获取访问令牌失败: {response.status_code} - {response.text}")
|
||
|
||
token_data = response.json()
|
||
self._access_token = token_data["access_token"]
|
||
self._token_expires_at = time.time() + token_data.get("expires_in", 3600)
|
||
|
||
self.logger.info("✅ 成功获取Google访问令牌")
|
||
return self._access_token
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"❌ 获取访问令牌失败: {e}")
|
||
raise Exception(f"获取访问令牌失败: {str(e)}")
|
||
|
||
def _create_gemini_client(self, access_token: str) -> Dict[str, Any]:
|
||
"""
|
||
创建Gemini客户端配置,参考demo.py实现
|
||
"""
|
||
import random
|
||
|
||
# 随机选择区域
|
||
region = random.choice(self.new_mode_config.regions)
|
||
|
||
gateway_url = (
|
||
f"https://gateway.ai.cloudflare.com/v1/"
|
||
f"{self.new_mode_config.cloudflare_project_id}/"
|
||
f"{self.new_mode_config.cloudflare_gateway_id}/"
|
||
f"google-vertex-ai/v1/projects/"
|
||
f"{self.new_mode_config.google_project_id}/"
|
||
f"locations/{region}/publishers/google/models"
|
||
)
|
||
|
||
return {
|
||
"gateway_url": gateway_url,
|
||
"access_token": access_token,
|
||
"headers": {
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {access_token}"
|
||
}
|
||
}
|
||
|
||
async def _check_rate_limit(self) -> bool:
|
||
"""检查速率限制"""
|
||
current_time = time.time()
|
||
|
||
# 清理过期的请求记录
|
||
self.request_history = [
|
||
req_time for req_time in self.request_history
|
||
if current_time - req_time < 86400 # 24小时
|
||
]
|
||
|
||
# 检查每日限制
|
||
if len(self.request_history) >= self.requests_per_day:
|
||
self.logger.warning("已达到Gemini API每日请求限制")
|
||
return False
|
||
|
||
# 检查每分钟限制
|
||
recent_requests = [
|
||
req_time for req_time in self.request_history
|
||
if current_time - req_time < 60 # 1分钟
|
||
]
|
||
|
||
if len(recent_requests) >= self.requests_per_minute:
|
||
self.logger.warning("已达到Gemini API每分钟请求限制")
|
||
return False
|
||
|
||
return True
|
||
|
||
def _calculate_file_checksum(self, file_path: str) -> str:
|
||
"""计算文件校验和"""
|
||
hash_md5 = hashlib.md5()
|
||
with open(file_path, "rb") as f:
|
||
for chunk in iter(lambda: f.read(4096), b""):
|
||
hash_md5.update(chunk)
|
||
return hash_md5.hexdigest()
|
||
|
||
def _generate_cache_key(self, video_path: str, prompt: str, model_name: str) -> str:
|
||
"""生成缓存键"""
|
||
key_data = f"{video_path}:{prompt}:{model_name}"
|
||
return hashlib.md5(key_data.encode()).hexdigest()
|
||
|
||
def _check_analysis_cache(self, video_path: str, prompt: str) -> Optional[Dict[str, Any]]:
|
||
"""检查分析缓存"""
|
||
if not self.new_mode_config.enable_cache:
|
||
return None
|
||
|
||
try:
|
||
cache_key = self._generate_cache_key(video_path, prompt, self.new_mode_config.model_name)
|
||
cache_file = os.path.join(self.new_mode_config.cache_dir, f"{cache_key}.json")
|
||
|
||
if not os.path.exists(cache_file):
|
||
return None
|
||
|
||
with open(cache_file, 'r', encoding='utf-8') as f:
|
||
cache_entry_data = json.load(f)
|
||
cache_entry = CacheEntry(**cache_entry_data)
|
||
|
||
# 检查缓存是否过期
|
||
if time.time() - cache_entry.timestamp > self.new_mode_config.cache_expiry:
|
||
os.unlink(cache_file)
|
||
self.logger.info(f"⏰ 缓存已过期: {Path(video_path).name}")
|
||
return None
|
||
|
||
# 检查文件是否发生变化
|
||
current_checksum = self._calculate_file_checksum(video_path)
|
||
if current_checksum != cache_entry.checksum:
|
||
os.unlink(cache_file)
|
||
self.logger.info(f"🔄 文件已变更: {Path(video_path).name}")
|
||
return None
|
||
|
||
self.logger.info(f"🎯 使用缓存的分析结果: {Path(video_path).name}")
|
||
return cache_entry.result
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"检查分析缓存失败: {e}")
|
||
return None
|
||
|
||
def _save_analysis_cache(self, video_path: str, file_uri: str, prompt: str, result: Dict[str, Any]) -> None:
|
||
"""保存分析结果到缓存"""
|
||
if not self.new_mode_config.enable_cache:
|
||
return
|
||
|
||
try:
|
||
cache_key = self._generate_cache_key(video_path, prompt, self.new_mode_config.model_name)
|
||
cache_file = os.path.join(self.new_mode_config.cache_dir, f"{cache_key}.json")
|
||
|
||
checksum = self._calculate_file_checksum(video_path)
|
||
cache_entry = CacheEntry(
|
||
video_path=video_path,
|
||
file_uri=file_uri,
|
||
prompt=prompt,
|
||
result=result,
|
||
timestamp=time.time(),
|
||
checksum=checksum,
|
||
model_name=self.new_mode_config.model_name
|
||
)
|
||
|
||
with open(cache_file, 'w', encoding='utf-8') as f:
|
||
json.dump(asdict(cache_entry), f, ensure_ascii=False, indent=2)
|
||
|
||
self.logger.info(f"💾 分析结果已缓存: {Path(video_path).name}")
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"保存分析缓存失败: {e}")
|
||
|
||
def _generate_upload_cache_key(self, video_path: str) -> str:
|
||
"""生成上传缓存键"""
|
||
# 使用文件路径和修改时间生成唯一键
|
||
file_stat = os.stat(video_path)
|
||
key_data = f"{video_path}:{file_stat.st_mtime}:{file_stat.st_size}"
|
||
return hashlib.md5(key_data.encode()).hexdigest()
|
||
|
||
def _check_upload_cache(self, video_path: str) -> Optional[str]:
|
||
"""检查上传缓存"""
|
||
if not self.new_mode_config.enable_upload_cache:
|
||
return None
|
||
|
||
try:
|
||
cache_key = self._generate_upload_cache_key(video_path)
|
||
cache_file = os.path.join(self.new_mode_config.upload_cache_dir, f"{cache_key}.json")
|
||
|
||
if not os.path.exists(cache_file):
|
||
return None
|
||
|
||
with open(cache_file, 'r', encoding='utf-8') as f:
|
||
cache_entry_data = json.load(f)
|
||
cache_entry = UploadCacheEntry(**cache_entry_data)
|
||
|
||
# 检查缓存是否过期
|
||
if time.time() - cache_entry.timestamp > self.new_mode_config.upload_cache_expiry:
|
||
os.unlink(cache_file)
|
||
self.logger.info(f"⏰ 上传缓存已过期: {Path(video_path).name}")
|
||
return None
|
||
|
||
# 检查文件是否发生变化
|
||
current_checksum = self._calculate_file_checksum(video_path)
|
||
current_size = os.path.getsize(video_path)
|
||
|
||
if (current_checksum != cache_entry.checksum or
|
||
current_size != cache_entry.file_size):
|
||
os.unlink(cache_file)
|
||
self.logger.info(f"🔄 文件已变更,清除上传缓存: {Path(video_path).name}")
|
||
return None
|
||
|
||
self.logger.info(f"🎯 使用缓存的上传URI: {Path(video_path).name} -> {cache_entry.file_uri}")
|
||
return cache_entry.file_uri
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"检查上传缓存失败: {e}")
|
||
return None
|
||
|
||
def _save_upload_cache(self, video_path: str, file_uri: str) -> None:
|
||
"""保存上传缓存"""
|
||
if not self.new_mode_config.enable_upload_cache:
|
||
return
|
||
|
||
try:
|
||
cache_key = self._generate_upload_cache_key(video_path)
|
||
cache_file = os.path.join(self.new_mode_config.upload_cache_dir, f"{cache_key}.json")
|
||
|
||
checksum = self._calculate_file_checksum(video_path)
|
||
file_size = os.path.getsize(video_path)
|
||
|
||
cache_entry = UploadCacheEntry(
|
||
video_path=video_path,
|
||
file_uri=file_uri,
|
||
timestamp=time.time(),
|
||
checksum=checksum,
|
||
file_size=file_size
|
||
)
|
||
|
||
with open(cache_file, 'w', encoding='utf-8') as f:
|
||
json.dump(asdict(cache_entry), f, ensure_ascii=False, indent=2)
|
||
|
||
self.logger.info(f"💾 上传URI已缓存: {Path(video_path).name} -> {file_uri}")
|
||
|
||
except Exception as e:
|
||
self.logger.warning(f"保存上传缓存失败: {e}")
|
||
|
||
async def _upload_video_file_new_mode(self, video_path: str) -> str:
|
||
"""
|
||
上传视频文件到Gemini,参考demo.py实现
|
||
支持上传缓存,避免重复上传相同文件
|
||
"""
|
||
try:
|
||
# 检查上传缓存
|
||
cached_uri = self._check_upload_cache(video_path)
|
||
if cached_uri:
|
||
return cached_uri
|
||
|
||
# 检查文件大小
|
||
file_size = os.path.getsize(video_path)
|
||
max_size = 100 * 1024 * 1024 # 100MB限制
|
||
|
||
if file_size > max_size:
|
||
raise Exception(f"视频文件过大 ({file_size / 1024 / 1024:.1f}MB),请使用小于100MB的文件")
|
||
|
||
# 获取访问令牌
|
||
access_token = await self.get_access_token()
|
||
|
||
# 准备FormData
|
||
with open(video_path, 'rb') as f:
|
||
video_data = f.read()
|
||
|
||
# 使用新的上传API
|
||
files = {
|
||
'file': (Path(video_path).name, video_data, 'video/mp4')
|
||
}
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {access_token}",
|
||
"x-google-api-key": access_token,
|
||
}
|
||
|
||
# 上传到Vertex AI
|
||
upload_url = f"{self.new_mode_config.base_url}/google/vertex-ai/upload"
|
||
params = {
|
||
"bucket": "dy-media-storage",
|
||
"prefix": "video-analysis"
|
||
}
|
||
|
||
response = requests.post(
|
||
upload_url,
|
||
files=files,
|
||
headers=headers,
|
||
params=params,
|
||
timeout=self.new_mode_config.timeout
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
raise Exception(f"文件上传失败: {response.status_code} - {response.text}")
|
||
|
||
upload_result = response.json()
|
||
file_uri = upload_result.get('urn') or upload_result.get('uri')
|
||
|
||
if not file_uri:
|
||
raise Exception("上传成功但未获取到文件URI")
|
||
|
||
self.logger.info(f"✅ 视频上传成功: {Path(video_path).name} -> {file_uri}")
|
||
|
||
# 保存到上传缓存
|
||
self._save_upload_cache(video_path, file_uri)
|
||
|
||
return file_uri
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"❌ 视频上传失败: {e}")
|
||
raise Exception(f"视频上传失败: {str(e)}")
|
||
|
||
async def _generate_content_new_mode(self, file_uri: str, prompt: str) -> Dict[str, Any]:
|
||
"""
|
||
生成内容分析,参考demo.py实现
|
||
"""
|
||
try:
|
||
# 获取访问令牌
|
||
access_token = await self.get_access_token()
|
||
|
||
# 创建客户端配置
|
||
client_config = self._create_gemini_client(access_token)
|
||
|
||
# 格式化GCS URI
|
||
formatted_uri = self._format_gcs_uri(file_uri)
|
||
|
||
# 准备请求数据,参考demo.py实现
|
||
request_data = {
|
||
"contents": [
|
||
{
|
||
"role": "user",
|
||
"parts": [
|
||
{
|
||
"text": prompt
|
||
},
|
||
{
|
||
"fileData": {
|
||
"mimeType": "video/mp4",
|
||
"fileUri": formatted_uri
|
||
}
|
||
}
|
||
]
|
||
}
|
||
],
|
||
"generationConfig": {
|
||
"temperature": self.temperature,
|
||
"topK": 32,
|
||
"topP": 1,
|
||
"maxOutputTokens": self.max_tokens
|
||
}
|
||
}
|
||
|
||
# 发送请求到Cloudflare Gateway
|
||
generate_url = f"{client_config['gateway_url']}/{self.new_mode_config.model_name}:generateContent"
|
||
|
||
self.logger.info(f"📤 发送 Gemini API 请求: {formatted_uri}")
|
||
|
||
# 重试机制
|
||
last_exception = None
|
||
for attempt in range(self.new_mode_config.max_retries):
|
||
try:
|
||
response = requests.post(
|
||
generate_url,
|
||
headers=client_config["headers"],
|
||
json=request_data,
|
||
timeout=self.new_mode_config.timeout
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
|
||
if 'candidates' not in result or not result['candidates']:
|
||
raise Exception("API返回结果为空")
|
||
|
||
self.logger.info("✅ 成功获取Gemini分析结果")
|
||
return result
|
||
else:
|
||
error_msg = f"API请求失败: {response.status_code} - {response.text}"
|
||
self.logger.warning(f"⚠️ 尝试 {attempt + 1}/{self.new_mode_config.max_retries}: {error_msg}")
|
||
|
||
if attempt == self.new_mode_config.max_retries - 1:
|
||
raise Exception(error_msg)
|
||
|
||
await asyncio.sleep(self.new_mode_config.retry_delay)
|
||
|
||
except requests.exceptions.Timeout as e:
|
||
last_exception = e
|
||
self.logger.warning(f"⚠️ 请求超时,尝试 {attempt + 1}/{self.new_mode_config.max_retries}")
|
||
if attempt < self.new_mode_config.max_retries - 1:
|
||
await asyncio.sleep(self.new_mode_config.retry_delay)
|
||
except Exception as e:
|
||
last_exception = e
|
||
self.logger.warning(f"⚠️ 请求失败,尝试 {attempt + 1}/{self.new_mode_config.max_retries}: {e}")
|
||
if attempt < self.new_mode_config.max_retries - 1:
|
||
await asyncio.sleep(self.new_mode_config.retry_delay)
|
||
|
||
raise Exception(f"内容生成失败,已重试{self.new_mode_config.max_retries}次: {last_exception}")
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"❌ 内容生成失败: {e}")
|
||
raise Exception(f"内容生成失败: {str(e)}")
|
||
|
||
def _format_gcs_uri(self, file_uri: str) -> str:
|
||
"""格式化GCS URI"""
|
||
if file_uri.startswith('gs://'):
|
||
return file_uri
|
||
elif file_uri.startswith('https://storage.googleapis.com/'):
|
||
# 转换为gs://格式
|
||
path = file_uri.replace('https://storage.googleapis.com/', '')
|
||
return f"gs://{path}"
|
||
else:
|
||
# 假设已经是正确格式
|
||
return file_uri
|
||
|
||
def _parse_analysis_result_new_mode(self, api_result: Dict[str, Any], video_path: str) -> Dict[str, Any]:
|
||
"""
|
||
解析分析结果,参考demo.py实现
|
||
|
||
|
||
"""
|
||
try:
|
||
# 提取文本内容
|
||
candidates = api_result.get('candidates', [])
|
||
if not candidates:
|
||
raise Exception("无有效的分析结果")
|
||
|
||
content = candidates[0].get('content', {})
|
||
parts = content.get('parts', [])
|
||
|
||
if not parts:
|
||
raise Exception("分析结果为空")
|
||
|
||
analysis_text = parts[0].get('text', '')
|
||
|
||
if not analysis_text:
|
||
raise Exception("未获取到分析文本")
|
||
|
||
self.logger.info(f"✅ 成功获取响应文本,长度: {len(analysis_text)}")
|
||
|
||
# 尝试解析JSON格式的结果
|
||
analysis_data = None
|
||
try:
|
||
# 清理文本,移除可能的markdown标记
|
||
cleaned_text = analysis_text.strip()
|
||
if cleaned_text.startswith('```json'):
|
||
cleaned_text = cleaned_text[7:]
|
||
if cleaned_text.endswith('```'):
|
||
cleaned_text = cleaned_text[:-3]
|
||
cleaned_text = cleaned_text.strip()
|
||
|
||
# 直接尝试解析JSON
|
||
if cleaned_text.startswith('{') or cleaned_text.startswith('['):
|
||
analysis_data = json.loads(cleaned_text)
|
||
self.logger.info("✅ 成功解析JSON格式的分析结果")
|
||
else:
|
||
# 使用正则表达式提取JSON部分
|
||
import re
|
||
json_match = re.search(r'\{.*\}', cleaned_text, re.DOTALL)
|
||
if json_match:
|
||
json_str = json_match.group()
|
||
analysis_data = json.loads(json_str)
|
||
self.logger.info("✅ 成功解析JSON格式的分析结果")
|
||
else:
|
||
raise json.JSONDecodeError("No JSON found", "", 0)
|
||
|
||
except json.JSONDecodeError:
|
||
# JSON解析失败,使用文本格式
|
||
self.logger.info("📝 使用文本格式的分析结果")
|
||
analysis_data = {
|
||
"content_analysis": {
|
||
"summary": analysis_text[:500] + "..." if len(analysis_text) > 500 else analysis_text,
|
||
"full_text": analysis_text
|
||
}
|
||
}
|
||
|
||
# 提取新增字段
|
||
"""
|
||
{{
|
||
"category": "分类结果",
|
||
"confidence": 0.85,
|
||
"reasoning": "详细的分类理由,包括商品匹配情况和内容特征",
|
||
"features": ["观察到的关键特征1", "关键特征2", "关键特征3"],
|
||
"product_match": true/false,
|
||
"quality_score": 0.9
|
||
}}
|
||
"""
|
||
product_match = analysis_data.get("product_match", True) # 默认为True保持兼容性
|
||
quality_score = analysis_data.get("quality_score", 1.0) # 默认为1.0保持兼容性
|
||
|
||
# 如果商品不匹配或质量太低,强制分类为废弃素材
|
||
category = analysis_data.get("category", "unclassified")
|
||
confidence = analysis_data.get("confidence", 0.8)
|
||
|
||
if not product_match or quality_score < 0.5:
|
||
category = "废弃素材"
|
||
confidence = max(confidence, 0.8) # 提高废弃素材的置信度
|
||
self.logger.info(f"商品不匹配或质量不合格,分类为废弃素材: product_match={product_match}, quality_score={quality_score}")
|
||
|
||
# 构建标准化结果格式
|
||
result = {
|
||
"success": True,
|
||
"category": category,
|
||
"confidence": confidence,
|
||
"reasoning": analysis_data.get("reasoning", "AI分析结果"),
|
||
"features": analysis_data.get("features", []),
|
||
"product_match": product_match,
|
||
"quality_score": quality_score,
|
||
"video_info": {
|
||
"file_name": Path(video_path).name,
|
||
"file_path": str(video_path),
|
||
"file_size": os.path.getsize(video_path),
|
||
"analysis_time": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||
"model_used": self.new_mode_config.model_name
|
||
},
|
||
"analysis_result": analysis_data,
|
||
"metadata": {
|
||
"response_length": len(analysis_text),
|
||
"candidates_count": len(candidates),
|
||
"mode": "new_mode"
|
||
},
|
||
"raw_response": analysis_text
|
||
}
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"❌ 结果解析失败: {e}")
|
||
return {
|
||
"success": False,
|
||
"error": str(e),
|
||
"category": "unclassified",
|
||
"confidence": 0.0
|
||
}
|
||
|
||
|
||
|
||
async def _make_request(self, prompt: str, images: List[Image.Image] = None) -> Optional[str]:
|
||
"""发送请求到Gemini API"""
|
||
if not self.is_available():
|
||
raise RuntimeError("Gemini服务不可用")
|
||
|
||
# 检查速率限制
|
||
if not await self._check_rate_limit():
|
||
raise RuntimeError("已达到API速率限制")
|
||
|
||
try:
|
||
# 准备输入内容
|
||
content = [prompt]
|
||
|
||
if images:
|
||
for image in images:
|
||
content.append(image)
|
||
|
||
# 发送请求
|
||
response = await asyncio.to_thread(
|
||
self.model.generate_content,
|
||
content
|
||
)
|
||
|
||
# 记录请求时间
|
||
self.request_history.append(time.time())
|
||
|
||
# 检查响应
|
||
if response.candidates and len(response.candidates) > 0:
|
||
candidate = response.candidates[0]
|
||
if hasattr(candidate, 'content') and candidate.content.parts:
|
||
return candidate.content.parts[0].text
|
||
|
||
self.logger.warning("Gemini API返回空响应")
|
||
return None
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"Gemini API请求失败: {e}")
|
||
raise
|
||
|
||
|
||
|
||
async def analyze_video(self, video_path: str, prompt: str = None, product_title: str = "", **kwargs) -> Dict[str, Any]:
|
||
return await self.analyze_video_content(video_path, product_title=product_title, **kwargs)
|
||
|
||
async def analyze_video_content(self, video_path: str, product_title: str = "", **kwargs) -> Dict[str, Any]:
|
||
"""分析视频内容"""
|
||
try:
|
||
return await self._analyze_video_new_mode(video_path, product_title=product_title, **kwargs)
|
||
except Exception as e:
|
||
self.logger.error(f"视频分析失败: {e}")
|
||
return {
|
||
"success": False,
|
||
"error": str(e),
|
||
"category": "unclassified",
|
||
"confidence": 0.0
|
||
}
|
||
|
||
async def _analyze_video_new_mode(self, video_path: str, product_title: str = "", **kwargs) -> Dict[str, Any]:
|
||
try:
|
||
# 构建分析提示词
|
||
prompt = ""
|
||
|
||
# 检查缓存
|
||
cached_result = self._check_analysis_cache(video_path, prompt)
|
||
if cached_result:
|
||
return cached_result
|
||
|
||
# 上传视频文件
|
||
file_uri = await self._upload_video_file_new_mode(video_path)
|
||
|
||
# 发送分析请求
|
||
result = await self._generate_content_new_mode(file_uri, prompt)
|
||
|
||
# 解析结果
|
||
"""
|
||
{
|
||
"success": True,
|
||
"category": category,
|
||
"confidence": confidence,
|
||
"reasoning": analysis_data.get("reasoning", "AI分析结果"),
|
||
"features": analysis_data.get("features", []),
|
||
"product_match": product_match,
|
||
"quality_score": quality_score,
|
||
"video_info": {
|
||
"file_name": Path(video_path).name,
|
||
"file_path": str(video_path),
|
||
"file_size": os.path.getsize(video_path),
|
||
"analysis_time": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||
"model_used": self.new_mode_config.model_name
|
||
},
|
||
"analysis_result": analysis_data,
|
||
"metadata": {
|
||
"response_length": len(analysis_text),
|
||
"candidates_count": len(candidates),
|
||
"mode": "new_mode"
|
||
},
|
||
"raw_response": analysis_text
|
||
}
|
||
"""
|
||
parsed_result = self._parse_analysis_result_new_mode(result, video_path)
|
||
|
||
# 保存到缓存
|
||
self._save_analysis_cache(video_path, file_uri, prompt, parsed_result)
|
||
|
||
return parsed_result
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"新模式视频分析失败: {e}")
|
||
raise
|
||
|
||
|
||
async def batch_analyze_videos(self, video_paths: List[str], **kwargs) -> Dict[str, Dict[str, Any]]:
|
||
"""批量分析视频"""
|
||
self.logger.info(f"开始批量分析 {len(video_paths)} 个视频")
|
||
|
||
results = {}
|
||
|
||
# 控制并发数以避免API限制
|
||
max_concurrent = kwargs.get("max_concurrent", 3)
|
||
semaphore = asyncio.Semaphore(max_concurrent)
|
||
|
||
async def analyze_single(video_path: str):
|
||
async with semaphore:
|
||
try:
|
||
result = await self.analyze_video_content(video_path, **kwargs)
|
||
return video_path, result
|
||
except Exception as e:
|
||
self.logger.error(f"分析视频失败 {video_path}: {e}")
|
||
return video_path, {
|
||
"success": False,
|
||
"error": str(e),
|
||
"category": "unclassified",
|
||
"confidence": 0.0
|
||
}
|
||
|
||
# 创建并发任务
|
||
tasks = [analyze_single(video_path) for video_path in video_paths]
|
||
|
||
# 等待所有任务完成
|
||
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
||
# 处理结果
|
||
for task_result in completed_tasks:
|
||
if isinstance(task_result, Exception):
|
||
self.logger.error(f"批量分析任务失败: {task_result}")
|
||
else:
|
||
video_path, result = task_result
|
||
results[video_path] = result
|
||
|
||
success_count = len([r for r in results.values() if r.get("success", False)])
|
||
self.logger.info(f"批量分析完成: 成功 {success_count}/{len(video_paths)}")
|
||
|
||
return results
|