mxivideo/python_core/ai_video/api_client.py

286 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
API Client Module
API 客户端模块
Handles communication with AI video generation APIs.
"""
import os
import time
from typing import Dict, Any, Optional, Callable
import sys
from ..config import settings
from ..utils import setup_logger
# Module-level logger, named after this module and built by the project's
# setup_logger helper; the APIClient methods below log through it.
logger = setup_logger(__name__)
class APIClient:
    """Client for the AI video generation (image-to-video) HTTP API.

    The public methods report their outcome as a plain dict of the form
    ``{'status': bool, 'data': ..., 'msg': str}``. Errors raised while
    talking to the API are caught and returned through ``msg``.
    """

    # Task states after which polling can never succeed, mapped to
    # (result message, progress-callback message).
    _TERMINAL_FAILURES = {
        'failed': ('任务执行失败', "[失败] 任务执行失败"),
        'cancelled': ('任务已取消', "[取消] 任务已取消"),
    }

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """
        Initialize API client.

        Args:
            api_key: API key for authentication. Falls back to the
                ``AI_VIDEO_API_KEY`` environment variable, then to a
                built-in default.
            base_url: Base URL for API endpoints. Falls back to the
                ``AI_VIDEO_BASE_URL`` environment variable, then to a
                built-in default.
        """
        # NOTE(security): the fallback below is a credential committed to
        # source control. It is kept only so existing deployments keep
        # working; rotate it and supply the key via AI_VIDEO_API_KEY instead.
        self.api_key = api_key or os.getenv('AI_VIDEO_API_KEY', '21575c22-14aa-40ca-8aa8-f00ca27a3a17')
        self.base_url = base_url or os.getenv('AI_VIDEO_BASE_URL', 'https://ark.cn-beijing.volces.com/api/v3')
        # Model configurations: API model name and the resolution flag that
        # submit_task appends to the prompt.
        self.models = {
            'lite': {
                'name': 'doubao-seedance-1-0-lite-i2v-250428',
                'resolution': '720p',
            },
            'pro': {
                'name': 'doubao-seedance-1-0-pro-250528',
                'resolution': '1080p',
            },
        }

    def _auth_headers(self) -> Dict[str, str]:
        """Return the JSON + bearer-token headers sent with every request."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    def submit_task(self, prompt: str, img_url: str, duration: str = '5', model_type: str = 'lite') -> Dict[str, Any]:
        """
        Submit video generation task.

        Args:
            prompt: Text prompt for video generation
            img_url: URL of the input image. ``file://`` URLs are rejected
                because the API cannot read local files.
            duration: Video duration, ``'5'`` or ``'10'`` (the integers 5 and
                10 are accepted as well)
            model_type: Model type (``'lite'`` or ``'pro'``)

        Returns:
            Result dict. On success ``status`` is True and ``data`` holds the
            task id. On failure ``msg`` describes the problem; if an
            exception was caught, ``error_details`` carries its type,
            message and traceback.
        """
        result = {'status': False, 'data': None, 'msg': ''}
        # The value is only interpolated into the prompt, so accept ints too.
        duration = str(duration)
        if duration not in ('5', '10'):
            result['msg'] = 'Duration must be either 5 or 10'
            return result
        if model_type not in self.models:
            result['msg'] = f'Model type must be one of: {list(self.models)}'
            return result
        # Handle local file URLs
        if img_url.startswith('file://'):
            result['msg'] = 'Local files are not supported by the API. Please upload the image to cloud storage first.'
            logger.error(f"Local file URL not supported: {img_url}")
            return result
        try:
            # Imported lazily, inside the try, so a missing dependency is
            # reported through the result dict rather than raised.
            import requests
            model_config = self.models[model_type]
            json_data = {
                'model': model_config['name'],
                'content': [
                    {
                        'type': 'text',
                        # Generation options travel as flags appended to the prompt.
                        'text': f'{prompt} --resolution {model_config["resolution"]} --dur {duration} --camerafixed false',
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': img_url,
                        },
                    },
                ],
            }
            response = requests.post(
                f'{self.base_url}/contents/generations/tasks',
                headers=self._auth_headers(),
                json=json_data,
                timeout=30,
            )
            # Check HTTP status code
            if response.status_code != 200:
                error_msg = f"API request failed, status code: {response.status_code}, response: {response.text}"
                logger.error(error_msg)
                result['msg'] = error_msg
                return result
            resp_json = response.json()
            # Check if response contains id field
            if 'id' not in resp_json:
                error_msg = f"API response missing id field, response: {resp_json}"
                logger.error(error_msg)
                result['msg'] = error_msg
                return result
            job_id = resp_json['id']
            result['status'] = True
            result['data'] = job_id
            result['msg'] = 'Task submitted successfully'
            logger.info(f"Task submitted successfully, job ID: {job_id}")
        except Exception as e:
            import traceback
            error_details = {
                'error_type': type(e).__name__,
                'error_message': str(e),
                'traceback': traceback.format_exc(),
            }
            logger.error(f"Failed to submit task: {error_details['error_type']}: {error_details['error_message']}")
            logger.error(f"Traceback: {error_details['traceback']}")
            result['msg'] = f"{error_details['error_type']}: {error_details['error_message']}"
            result['error_details'] = error_details
        return result

    def query_task_status(self, job_id: str) -> Dict[str, Any]:
        """
        Query task status.

        Args:
            job_id: Task ID to query

        Returns:
            Result dict. ``msg`` carries the task status reported by the API
            (``'unknown'`` if absent) or an error description. ``status`` is
            True only when the task has succeeded, in which case ``data`` is
            the ``video_url`` from the response content (None if missing).
        """
        result = {'status': False, 'data': None, 'msg': ''}
        try:
            import requests
            response = requests.get(
                f'{self.base_url}/contents/generations/tasks/{job_id}',
                headers=self._auth_headers(),
                timeout=30,
            )
            if response.status_code != 200:
                result['msg'] = f"API request failed, status code: {response.status_code}"
                return result
            resp_json = response.json()
            # Parse response
            task_status = resp_json.get('status', 'unknown')
            result['msg'] = task_status
            if task_status == 'succeeded':
                result['status'] = True
                # `or {}` also covers an explicit "content": null.
                result['data'] = (resp_json.get('content') or {}).get('video_url')
            # Any other state (failed, cancelled, running, pending, queued...)
            # keeps the defaults: status False, data None.
        except Exception as e:
            logger.error(f"Failed to query task status: {str(e)}")
            result['msg'] = str(e)
        return result

    def wait_for_completion(self, job_id: str, timeout: int = 180, interval: int = 2,
                            progress_callback: Optional[Callable[[str], None]] = None) -> Dict[str, Any]:
        """
        Wait for task completion with progress updates.

        Polls query_task_status until the task succeeds, reaches a terminal
        failure state (failed or cancelled), or the timeout expires.

        Args:
            job_id: Task ID to wait for
            timeout: Maximum wait time in seconds
            interval: Check interval in seconds
            progress_callback: Optional callback for progress updates

        Returns:
            Result dict. On success ``status`` is True, ``data`` is the value
            returned by query_task_status and ``msg`` is ``'succeeded'``.
            Otherwise ``msg`` says whether the task failed, was cancelled,
            or timed out.
        """
        result = {'status': False, 'data': None, 'msg': ''}
        # Monotonic clock: immune to wall-clock adjustments, and it lets the
        # reported elapsed time include the time spent inside each query.
        started = time.monotonic()
        deadline = started + timeout

        def notify(message: str) -> None:
            if progress_callback:
                progress_callback(message)

        notify(f"开始查询任务状态任务ID: {job_id}")
        while time.monotonic() < deadline:
            status_result = self.query_task_status(job_id)
            state = status_result['msg']
            if status_result['status']:
                # Task completed successfully
                result['status'] = True
                result['data'] = status_result['data']
                result['msg'] = 'succeeded'
                notify("[完成] 视频生成完成!")
                break
            if state in self._TERMINAL_FAILURES:
                # Failed or cancelled: further polling cannot succeed.
                result['msg'], failure_note = self._TERMINAL_FAILURES[state]
                notify(failure_note)
                break

            elapsed = int(time.monotonic() - started)
            remaining = max(0, timeout - elapsed)
            if state == 'running':
                progress_msg = f"[运行中] 任务运行中,已等待{elapsed}秒,预计剩余{remaining}秒..."
                logger.info(progress_msg)
                notify(progress_msg)
                # Send detailed progress via JSON-RPC
                from .json_rpc import create_progress_reporter
                progress = create_progress_reporter()
                progress.update(
                    step="running",
                    progress=min(100, (elapsed / timeout) * 100),
                    message=progress_msg,
                    details={
                        "elapsed_seconds": elapsed,
                        "remaining_seconds": remaining,
                        "total_timeout": timeout,
                    },
                )
            elif state in ('pending', 'queued'):
                progress_msg = f"[排队中] 任务排队中,已等待{elapsed}秒,预计剩余{remaining}秒..."
                logger.info(progress_msg)
                notify(progress_msg)
            else:
                # Unknown status (including transient query errors), continue waiting
                logger.info(f"未知状态: {state},继续等待...")
                notify(f"[未知] 状态: {state},继续等待...")
            # Never sleep past the deadline.
            time.sleep(max(0, min(interval, deadline - time.monotonic())))

        if not result['status'] and result['msg'] == '':
            result['msg'] = '任务超时'
            notify(f"[超时] 任务查询超时({timeout}秒)")
        return result