287 lines
10 KiB
Python
287 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
API Client Module
|
||
API 客户端模块
|
||
|
||
Handles communication with AI video generation APIs.
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
from typing import Dict, Any, Optional, Callable
|
||
|
||
import sys
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from ..config import settings
|
||
from ..utils import setup_logger
|
||
# Module-level logger created via the project's ``setup_logger`` helper
# (imported from ``..utils``; its handlers/format are configured there, not here).
logger = setup_logger(__name__)
|
||
|
||
class APIClient:
    """Client for the AI image-to-video generation task API.

    Every public method returns a result dict of the form
    ``{'status': bool, 'data': Any, 'msg': str}`` instead of raising, so
    callers branch on ``result['status']`` and read ``result['msg']``.
    """

    # Endpoint path, relative to ``base_url``, used to create and query tasks.
    _TASKS_PATH = 'contents/generations/tasks'

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None,
                 request_timeout: float = 30):
        """
        Initialize API client.

        Args:
            api_key: API key for authentication. Falls back to the
                ``AI_VIDEO_API_KEY`` environment variable. There is deliberately
                no built-in default: credentials must not live in source code.
            base_url: Base URL for API endpoints. Falls back to
                ``AI_VIDEO_BASE_URL``, then to the Volcengine Ark endpoint.
                A trailing slash is stripped so joined URLs stay well-formed.
            request_timeout: Per-request HTTP timeout in seconds.
        """
        # SECURITY: a literal API key used to be hard-coded here as the fallback.
        # It has been removed; treat that key as compromised and rotate it.
        self.api_key = api_key or os.getenv('AI_VIDEO_API_KEY')
        self.base_url = (
            base_url or os.getenv('AI_VIDEO_BASE_URL', 'https://ark.cn-beijing.volces.com/api/v3')
        ).rstrip('/')
        self.request_timeout = request_timeout

        # Model configurations: short alias -> API model name and output resolution.
        self.models = {
            'lite': {
                'name': 'doubao-seedance-1-0-lite-i2v-250428',
                'resolution': '720p'
            },
            'pro': {
                'name': 'doubao-seedance-1-0-pro-250528',
                'resolution': '1080p'
            }
        }

    def _headers(self) -> Dict[str, str]:
        """Return the JSON + bearer-token headers shared by every request."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    def submit_task(self, prompt: str, img_url: str, duration: str = '5',
                    model_type: str = 'lite') -> Dict[str, Any]:
        """
        Submit video generation task.

        Args:
            prompt: Text prompt for video generation.
            img_url: http/https URL of the input image. ``file://`` URLs are
                rejected because the remote API cannot read local files.
            duration: Video duration in seconds, ``'5'`` or ``'10'``
                (the ints ``5`` / ``10`` are accepted too).
            model_type: Key into ``self.models`` (``'lite'`` or ``'pro'``).

        Returns:
            Result dict; on success ``data`` is the task (job) ID returned by
            the API. If an unexpected exception occurs, an extra
            ``error_details`` key holds its type, message and traceback.
        """
        result = {'status': False, 'data': None, 'msg': ''}

        # Accept ints as well as strings; the prompt flag uses the text form.
        duration = str(duration)
        if duration not in ('5', '10'):
            result['msg'] = 'Duration must be either 5 or 10'
            return result

        if model_type not in self.models:
            result['msg'] = f'Model type must be one of: {list(self.models.keys())}'
            return result

        if not img_url:
            result['msg'] = 'Image URL is required'
            return result

        # Local files cannot be fetched by the remote service.
        if img_url.startswith('file://'):
            result['msg'] = 'Local files are not supported by the API. Please upload the image to cloud storage first.'
            logger.error(f"Local file URL not supported: {img_url}")
            return result

        if not self.api_key:
            result['msg'] = 'API key is not configured (pass api_key or set AI_VIDEO_API_KEY)'
            return result

        try:
            # Imported lazily so a missing dependency is reported via the result dict.
            import requests

            model_config = self.models[model_type]
            json_data = {
                'model': model_config['name'],
                'content': [
                    {
                        'type': 'text',
                        # Generation options are passed as inline flags in the prompt text.
                        'text': f'{prompt} --resolution {model_config["resolution"]} --dur {duration} --camerafixed false',
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': img_url,
                        },
                    },
                ],
            }

            response = requests.post(
                f'{self.base_url}/{self._TASKS_PATH}',
                headers=self._headers(),
                json=json_data,
                timeout=self.request_timeout,
            )

            if response.status_code != 200:
                error_msg = f"API request failed, status code: {response.status_code}, response: {response.text}"
                logger.error(error_msg)
                result['msg'] = error_msg
                return result

            resp_json = response.json()

            # A successful submission must carry the task id used for polling.
            if 'id' not in resp_json:
                error_msg = f"API response missing id field, response: {resp_json}"
                logger.error(error_msg)
                result['msg'] = error_msg
                return result

            job_id = resp_json['id']
            result['status'] = True
            result['data'] = job_id
            result['msg'] = 'Task submitted successfully'
            logger.info(f"Task submitted successfully, job ID: {job_id}")

        except Exception as e:
            import traceback
            error_details = {
                'error_type': type(e).__name__,
                'error_message': str(e),
                'traceback': traceback.format_exc()
            }
            logger.error(f"Failed to submit task: {error_details['error_type']}: {error_details['error_message']}")
            logger.error(f"Traceback: {error_details['traceback']}")
            result['msg'] = f"{error_details['error_type']}: {error_details['error_message']}"
            result['error_details'] = error_details

        return result

    def query_task_status(self, job_id: str) -> Dict[str, Any]:
        """
        Query task status.

        Args:
            job_id: Task ID returned by ``submit_task``.

        Returns:
            Result dict. ``msg`` is the task's ``status`` field from the API
            response (``'unknown'`` if absent), or an error description if
            the query itself failed. ``status`` is True only when the task
            status is ``'succeeded'``, in which case ``data`` is
            ``content.video_url`` from the response. For ``'failed'`` /
            ``'cancelled'`` an extra ``error`` key carries the response's
            ``error`` field (None if absent).
        """
        result = {'status': False, 'data': None, 'msg': ''}

        if not job_id:
            result['msg'] = 'Job ID is required'
            return result

        if not self.api_key:
            result['msg'] = 'API key is not configured (pass api_key or set AI_VIDEO_API_KEY)'
            return result

        try:
            import requests

            response = requests.get(
                f'{self.base_url}/{self._TASKS_PATH}/{job_id}',
                headers=self._headers(),
                timeout=self.request_timeout,
            )

            # Same diagnostics as submit_task: log and include the response body.
            if response.status_code != 200:
                error_msg = f"API request failed, status code: {response.status_code}, response: {response.text}"
                logger.error(error_msg)
                result['msg'] = error_msg
                return result

            resp_json = response.json()
            task_status = resp_json.get('status', 'unknown')
            result['msg'] = task_status

            if task_status == 'succeeded':
                result['status'] = True
                # 'content' may be missing or null; don't let that raise.
                result['data'] = (resp_json.get('content') or {}).get('video_url')
            elif task_status in ('failed', 'cancelled'):
                result['error'] = resp_json.get('error')
            # Any other status (running, pending, queued, ...) keeps status False / data None.

        except Exception as e:
            logger.error(f"Failed to query task status: {str(e)}")
            result['msg'] = str(e)

        return result

    def wait_for_completion(self, job_id: str, timeout: int = 180, interval: int = 2,
                            progress_callback: Optional[Callable[[str], None]] = None) -> Dict[str, Any]:
        """
        Poll ``query_task_status`` until the task finishes or ``timeout`` expires.

        Args:
            job_id: Task ID to wait for.
            timeout: Maximum wait time in seconds.
            interval: Seconds to sleep between polls.
            progress_callback: Optional callable that receives human-readable
                progress messages.

        Returns:
            Result dict. On success ``status`` is True, ``data`` is the video
            URL and ``msg`` is ``'succeeded'``. Otherwise ``msg`` is
            ``'任务执行失败'`` (failed), ``'任务已取消'`` (cancelled; an
            ``error`` key is added for both) or ``'任务超时'`` (timed out).
        """
        result = {'status': False, 'data': None, 'msg': ''}

        def notify(message: str) -> None:
            if progress_callback:
                progress_callback(message)

        # Monotonic clock: unaffected by wall-clock adjustments during the wait.
        start = time.monotonic()
        deadline = start + timeout

        notify(f"开始查询任务状态,任务ID: {job_id}")

        while time.monotonic() < deadline:
            status_result = self.query_task_status(job_id)
            task_status = status_result['msg']

            if status_result['status']:
                result['status'] = True
                result['data'] = status_result['data']
                result['msg'] = 'succeeded'
                notify("[完成] 视频生成完成!")
                break

            if task_status == 'failed':
                result['msg'] = '任务执行失败'
                result['error'] = status_result.get('error')
                notify("[失败] 任务执行失败")
                break

            if task_status == 'cancelled':
                # Terminal state: stop polling instead of waiting for the timeout.
                result['msg'] = '任务已取消'
                result['error'] = status_result.get('error')
                notify("[取消] 任务已取消")
                break

            # Real elapsed time, including the latency of each status request.
            elapsed = int(time.monotonic() - start)
            remaining = max(0, timeout - elapsed)

            if task_status == 'running':
                progress_msg = f"[运行中] 任务运行中,已等待{elapsed}秒,预计剩余{remaining}秒..."
                logger.info(progress_msg)
                notify(progress_msg)

                # Send detailed progress via JSON-RPC. The import stays lazy so the
                # sibling module is only required once a task is actually running.
                from .json_rpc import create_progress_reporter
                progress = create_progress_reporter()
                progress.update(
                    step="running",
                    progress=min(100, (elapsed / timeout) * 100),
                    message=progress_msg,
                    details={
                        "elapsed_seconds": elapsed,
                        "remaining_seconds": remaining,
                        "total_timeout": timeout
                    }
                )
            elif task_status in ('pending', 'queued'):
                progress_msg = f"[排队中] 任务排队中,已等待{elapsed}秒,预计剩余{remaining}秒..."
                logger.info(progress_msg)
                notify(progress_msg)
            else:
                # Unknown status, or the status query itself failed: keep waiting.
                logger.info(f"未知状态: {task_status},继续等待...")
                notify(f"[未知] 状态: {task_status},继续等待...")

            # Never sleep past the deadline.
            time.sleep(max(0.0, min(interval, deadline - time.monotonic())))

        # Loop exhausted without reaching a terminal state.
        if not result['status'] and result['msg'] == '':
            result['msg'] = '任务超时'
            notify(f"[超时] 任务查询超时({timeout}秒)")

        return result
|