mxivideo/src/services/textVideoAgentAPI.ts

566 lines
14 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
 * Text Video Agent API client library.
 *
 * Typed wrapper around the HTTP API hosted at
 * https://bowongai-dev--text-video-agent-fastapi-app.modal.run
 */

/** Service root used when `TextVideoAgentAPI` is constructed without an explicit base URL. */
const API_BASE_URL = 'https://bowongai-dev--text-video-agent-fastapi-app.modal.run'

/** Envelope shared by the service's JSON responses. */
export interface APIResponse<T = any> {
  /** Truthy when the call succeeded. */
  status: boolean
  /** Human-readable message; used as the error text when `status` is false. */
  msg: string
  /** Endpoint-specific payload. */
  data?: T
}

/** Response of the file-upload endpoint. */
export interface FileUploadResponse extends APIResponse<string> {
  /** URL of the uploaded file. */
  data?: string
}

/** Body sent to the task-creation endpoint. */
export interface TaskRequest {
  /** Prompt used to generate the image. */
  prompt: string
  /** Task category, e.g. `tea`, `chop`, `lady`, `vlog`. */
  task_type?: string
  /** URL of a reference image. */
  img_url?: string
  /** Aspect ratio of generated images/videos; 9:16 by default. */
  ar?: string
}

/** A task status record. */
export interface TaskStatus {
  task_id: string
  status: 'pending' | 'running' | 'completed' | 'failed'
  /** Progress indicator — scale (0-1 vs 0-100) not visible here, TODO confirm. */
  progress?: number
  result?: any
  /** Failure message, presumably populated when `status` is `'failed'`. */
  error?: string
}

/** Parameters for Midjourney image generation. */
export interface ImageGenerationParams {
  prompt: string
  /** Optional reference image uploaded with the request. */
  img_file?: File
  /** Maximum wait, in seconds; 120 when omitted (applied server-side, not by this client — TODO confirm). */
  max_wait_time?: number
  /** Polling interval, in seconds; 2 when omitted (applied server-side — TODO confirm). */
  poll_interval?: number
}

/** Parameters for video generation. */
export interface VideoGenerationParams {
  prompt: string
  /** URL of the source image. */
  img_url?: string
  /** Source image uploaded with the request. */
  img_file?: File
  /** Clip length in seconds, as a string; 5 by default. */
  duration?: string
  /** Maximum wait, in seconds; 300 by default. */
  max_wait_time?: number
  /** Polling interval, in seconds; 5 by default. */
  poll_interval?: number
}

/** Parameters for image description (image-to-prompt). */
export interface ImageDescribeParams {
  /** URL of the image to describe. */
  image_url?: string
  /** Image file to describe. */
  img_file?: File
  /** Maximum wait, in seconds; 120 when omitted (applied server-side — TODO confirm). */
  max_wait_time?: number
  /** Polling interval, in seconds; 2 when omitted (applied server-side — TODO confirm). */
  poll_interval?: number
}
/**
 * Minimal `fetch`-based HTTP helper bound to one base URL.
 *
 * Every method resolves with the parsed JSON response body. It rejects when
 * the network call fails, when the status is not 2xx (the message carries the
 * status code and an excerpt of the response body), or when the body is not
 * valid JSON. Failures are also logged with `console.error`.
 */
class HTTPClient {
  private baseURL: string

  /**
   * @param baseURL Service root, optionally with a path prefix. Trailing
   *   slashes are stripped so that joining with an endpoint such as
   *   `'/health'` never yields a `//health` path.
   */
  constructor(baseURL: string) {
    this.baseURL = baseURL.replace(/\/+$/, '')
  }

  /**
   * Return `endpoint` with the non-null entries of `params` appended as a
   * URL-encoded query string, joined with `&` if `endpoint` already has one.
   */
  private static withQuery(endpoint: string, params?: Record<string, any>): string {
    if (!params) {
      return endpoint
    }
    const search = new URLSearchParams()
    for (const [key, value] of Object.entries(params)) {
      if (value !== undefined && value !== null) {
        search.append(key, String(value))
      }
    }
    const query = search.toString()
    if (!query) {
      return endpoint
    }
    return `${endpoint}${endpoint.includes('?') ? '&' : '?'}${query}`
  }

  /**
   * Send a request to `baseURL + endpoint` and parse the JSON response.
   *
   * @throws Error `HTTP error! status: <code>` (plus a body excerpt) on non-2xx.
   */
  private async request<T>(
    endpoint: string,
    options: RequestInit = {}
  ): Promise<T> {
    const url = `${this.baseURL}${endpoint}`
    try {
      // Hand `options` to fetch unchanged: spreading `options.headers` into a
      // plain object would silently drop headers supplied as a `Headers`
      // instance or as `[name, value]` pairs.
      const response = await fetch(url, options)
      if (!response.ok) {
        // Surface the server's explanation; reading the body is best-effort.
        const detail = await response.text().catch(() => '')
        const excerpt = detail ? ` - ${detail.slice(0, 500)}` : ''
        throw new Error(`HTTP error! status: ${response.status}${excerpt}`)
      }
      return (await response.json()) as T
    } catch (error) {
      console.error(`API request failed: ${url}`, error)
      throw error
    }
  }

  /** GET `endpoint`, appending the non-null entries of `params` as query parameters. */
  async get<T>(endpoint: string, params?: Record<string, any>): Promise<T> {
    return this.request<T>(HTTPClient.withQuery(endpoint, params))
  }

  /** POST `data` as the raw request body; `options` may add headers etc. */
  async post<T>(
    endpoint: string,
    data?: any,
    options: RequestInit = {}
  ): Promise<T> {
    return this.request<T>(endpoint, {
      method: 'POST',
      ...options,
      body: data,
    })
  }

  /** POST `data` serialized as JSON. */
  async postJSON<T>(endpoint: string, data?: any): Promise<T> {
    return this.post<T>(endpoint, JSON.stringify(data), {
      headers: {
        'Content-Type': 'application/json',
      },
    })
  }

  /**
   * POST a multipart form. No Content-Type is set on purpose: fetch derives
   * `multipart/form-data` with the correct boundary from a FormData body.
   */
  async postFormData<T>(endpoint: string, formData: FormData): Promise<T> {
    return this.post<T>(endpoint, formData)
  }

  /** POST the non-null entries of `data` as `application/x-www-form-urlencoded`. */
  async postFormUrlEncoded<T>(
    endpoint: string,
    data: Record<string, any>
  ): Promise<T> {
    const formData = new URLSearchParams()
    for (const [key, value] of Object.entries(data)) {
      if (value !== undefined && value !== null) {
        formData.append(key, String(value))
      }
    }
    return this.post<T>(endpoint, formData, {
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
      },
    })
  }
}
/**
 * Text Video Agent API client.
 *
 * Thin wrappers around the service endpoints (health, prompts, file upload,
 * Midjourney image generation/description, Jimeng "JM" video generation, task
 * management) plus higher-level helpers that add retries, polling and an
 * end-to-end image(+video) flow.
 */
export class TextVideoAgentAPI {
  private client: HTTPClient

  /**
   * @param baseURL Service root; defaults to `API_BASE_URL`.
   */
  constructor(baseURL: string = API_BASE_URL) {
    this.client = new HTTPClient(baseURL)
  }

  // ==================== Internal helpers ====================

  /**
   * Build `endpoint?key=value&...` from the non-null entries of `params`.
   * Used by endpoints that take query parameters next to a multipart body, so
   * the HTTP client's private base URL never has to be read from here.
   */
  private static withQuery(endpoint: string, params: Record<string, unknown>): string {
    const search = new URLSearchParams()
    for (const [key, value] of Object.entries(params)) {
      if (value !== undefined && value !== null) {
        search.append(key, String(value))
      }
    }
    const query = search.toString()
    return query ? `${endpoint}?${query}` : endpoint
  }

  /**
   * Append a numeric form field unless it is `undefined`/`null`. `0` is sent,
   * matching how the query-string / url-encoded endpoints treat it.
   */
  private static appendNumber(form: FormData, key: string, value: number | undefined | null): void {
    if (value !== undefined && value !== null) {
      form.append(key, String(value))
    }
  }

  /** Resolve after `ms` milliseconds. */
  private static delay(ms: number): Promise<void> {
    return new Promise(resolve => setTimeout(resolve, ms))
  }

  /**
   * Call `attempt` up to `maxRetries` times and resolve with the first
   * response whose `status` is truthy.
   *
   * Both failure modes — a rejected promise and a `status: false` response —
   * are followed by a linear backoff (`baseDelayMs * attemptNumber`) before the
   * next attempt, so a server answering `status: false` is not hammered.
   *
   * @throws the last failure once all attempts are exhausted.
   */
  private async retryWithBackoff(
    label: string,
    attempt: () => Promise<APIResponse>,
    maxRetries: number,
    baseDelayMs: number
  ): Promise<APIResponse> {
    let lastError: Error | null = null
    for (let i = 0; i < maxRetries; i++) {
      try {
        const result = await attempt()
        if (result.status) {
          return result
        }
        lastError = new Error(result.msg || 'Generation failed')
      } catch (error) {
        lastError = error instanceof Error ? error : new Error(String(error))
        console.warn(`${label} generation attempt ${i + 1} failed:`, error)
      }
      // Wait before the next attempt, but not after the final one.
      if (i < maxRetries - 1) {
        await TextVideoAgentAPI.delay(baseDelayMs * (i + 1))
      }
    }
    throw lastError ?? new Error('All retry attempts failed')
  }

  // ==================== Basics ====================

  /** Service health check. */
  async healthCheck(): Promise<APIResponse> {
    return this.client.get('/health')
  }

  /** Fetch the sample prompt, optionally for a specific task type. */
  async getSamplePrompt(taskType?: string): Promise<APIResponse> {
    return this.client.get('/api/prompt/default', { task_type: taskType })
  }

  // ==================== Files ====================

  /** Upload a file (to COS) and return its URL in `data`. */
  async uploadFile(file: File): Promise<FileUploadResponse> {
    const formData = new FormData()
    formData.append('file', file)
    return this.client.postFormData('/api/file/upload', formData)
  }

  // ==================== Midjourney images ====================

  /** Midjourney health check. */
  async mjHealthCheck(): Promise<APIResponse> {
    return this.client.get('/api/mj/health')
  }

  /**
   * Generate an image and wait for the result (blocking call).
   * Prompt and timing go in the query string; the optional reference image is
   * sent as multipart form data.
   */
  async generateImageSync(params: ImageGenerationParams): Promise<APIResponse> {
    const formData = new FormData()
    if (params.img_file) {
      formData.append('img_file', params.img_file)
    }
    const endpoint = TextVideoAgentAPI.withQuery('/api/mj/sync/image', {
      prompt: params.prompt,
      max_wait_time: params.max_wait_time,
      poll_interval: params.poll_interval,
    })
    return this.client.postFormData(endpoint, formData)
  }

  /** Generate an image with every field sent as form data (the sync endpoint is recommended). */
  async generateImage(params: ImageGenerationParams): Promise<APIResponse> {
    const formData = new FormData()
    formData.append('prompt', params.prompt)
    if (params.img_file) {
      formData.append('img_file', params.img_file)
    }
    TextVideoAgentAPI.appendNumber(formData, 'max_wait_time', params.max_wait_time)
    TextVideoAgentAPI.appendNumber(formData, 'poll_interval', params.poll_interval)
    return this.client.postFormData('/api/mj/generate-image', formData)
  }

  /** Submit an image-generation task without waiting; poll with `queryImageTaskStatus`. */
  async generateImageAsync(prompt: string, imgFile?: File): Promise<APIResponse> {
    const formData = new FormData()
    if (imgFile) {
      formData.append('img_file', imgFile)
    }
    const endpoint = TextVideoAgentAPI.withQuery('/api/mj/async/generate/image', { prompt })
    return this.client.postFormData(endpoint, formData)
  }

  /** Query the status of an async image task. */
  async queryImageTaskStatus(taskId: string): Promise<APIResponse> {
    return this.client.get('/api/mj/async/query/status', { task_id: taskId })
  }

  /**
   * Describe an image given by URL.
   * @throws Error when `params.image_url` is missing.
   */
  async describeImageByUrl(params: ImageDescribeParams): Promise<APIResponse> {
    if (!params.image_url) {
      throw new Error('image_url is required')
    }
    return this.client.postFormUrlEncoded('/api/mj/sync/img/describe', {
      image_url: params.image_url,
      max_wait_time: params.max_wait_time,
      poll_interval: params.poll_interval,
    })
  }

  /**
   * Describe an uploaded image file.
   * @throws Error when `params.img_file` is missing.
   */
  async describeImageByFile(params: ImageDescribeParams): Promise<APIResponse> {
    if (!params.img_file) {
      throw new Error('img_file is required')
    }
    const formData = new FormData()
    formData.append('img_file', params.img_file)
    TextVideoAgentAPI.appendNumber(formData, 'max_wait_time', params.max_wait_time)
    TextVideoAgentAPI.appendNumber(formData, 'poll_interval', params.poll_interval)
    return this.client.postFormData('/api/mj/sync/file/img/describe', formData)
  }

  // ==================== Jimeng (JM) video ====================

  /** Jimeng health check. */
  async jmHealthCheck(): Promise<APIResponse> {
    return this.client.get('/api/jm/health')
  }

  /**
   * Generate a video from an image URL and wait for the result (blocking).
   * Defaults: duration '5', max_wait_time 300, poll_interval 5.
   * @throws Error when `params.img_url` is missing.
   */
  async generateVideoSync(params: VideoGenerationParams): Promise<APIResponse> {
    if (!params.img_url) {
      throw new Error('img_url is required for sync video generation')
    }
    return this.client.postFormUrlEncoded('/api/jm/generate-video', {
      prompt: params.prompt,
      img_url: params.img_url,
      duration: params.duration || '5',
      max_wait_time: params.max_wait_time || 300,
      poll_interval: params.poll_interval || 5,
    })
  }

  /** Submit a video-generation task without waiting; poll with `queryVideoTaskStatus`. */
  async generateVideoAsync(params: VideoGenerationParams): Promise<APIResponse> {
    const formData = new FormData()
    formData.append('prompt', params.prompt)
    if (params.img_url) {
      formData.append('img_url', params.img_url)
    }
    if (params.img_file) {
      formData.append('img_file', params.img_file)
    }
    formData.append('duration', params.duration || '5')
    return this.client.postFormData('/api/jm/async/generate/video', formData)
  }

  /** Query the status of an async video task. */
  async queryVideoTaskStatus(taskId: string): Promise<APIResponse> {
    return this.client.get('/api/jm/async/query/status', { task_id: taskId })
  }

  // ==================== Task management ====================

  /** Create a task without blocking. */
  async createTask(request: TaskRequest): Promise<APIResponse> {
    return this.client.postJSON('/api/task/create/task', request)
  }

  /** Fetch a task's result, waiting for it (blocking). */
  async getTaskResultSync(taskId: string): Promise<APIResponse> {
    // Encode so an id containing '/', '?', '#' or '..' stays one path segment.
    return this.client.get(`/api/task/${encodeURIComponent(taskId)}`)
  }

  /** Fetch a task's current status (returns immediately). */
  async getTaskStatusAsync(taskId: string): Promise<APIResponse> {
    return this.client.get(`/api/task/status/${encodeURIComponent(taskId)}`)
  }

  // ==================== Higher-level helpers ====================

  /**
   * `generateImageSync` with retries and linear backoff.
   *
   * @param maxRetries Total number of attempts (default 3).
   * @param retryDelayMs Base backoff in ms; attempt k waits `k * retryDelayMs` (default 2000).
   * @throws the last error, or an Error built from the last `msg`.
   */
  async generateImageWithRetry(
    params: ImageGenerationParams,
    maxRetries: number = 3,
    retryDelayMs: number = 2000
  ): Promise<APIResponse> {
    return this.retryWithBackoff('Image', () => this.generateImageSync(params), maxRetries, retryDelayMs)
  }

  /**
   * `generateVideoSync` with retries and linear backoff.
   *
   * @param maxRetries Total number of attempts (default 3).
   * @param retryDelayMs Base backoff in ms; attempt k waits `k * retryDelayMs` (default 5000).
   * @throws Error immediately when `params.img_url` is missing; otherwise the last failure.
   */
  async generateVideoWithRetry(
    params: VideoGenerationParams,
    maxRetries: number = 3,
    retryDelayMs: number = 5000
  ): Promise<APIResponse> {
    // A missing img_url can never succeed, so fail fast instead of sitting
    // through the whole backoff schedule.
    if (!params.img_url) {
      throw new Error('img_url is required for sync video generation')
    }
    return this.retryWithBackoff('Video', () => this.generateVideoSync(params), maxRetries, retryDelayMs)
  }

  /**
   * Poll `getTaskStatusAsync` until `data.status` is `'completed'`.
   *
   * @param pollInterval Delay between polls, in milliseconds (default 2000).
   * @param maxWaitTime Overall budget, in milliseconds (default 300000 = 5 minutes).
   * @param onProgress Receives every raw status response.
   * @throws Error with `data.error` when the task reports `'failed'`, or
   *   'Task polling timeout' once the budget is spent.
   */
  async pollTaskUntilComplete(
    taskId: string,
    pollInterval: number = 2000,
    maxWaitTime: number = 300000, // 5 minutes
    onProgress?: (status: any) => void
  ): Promise<APIResponse> {
    const deadline = Date.now() + maxWaitTime
    while (Date.now() < deadline) {
      try {
        const status = await this.getTaskStatusAsync(taskId)
        onProgress?.(status)
        const state = status.data?.status
        if (state === 'completed') {
          return status
        }
        if (state === 'failed') {
          throw new Error(status.data?.error || 'Task failed')
        }
      } catch (error) {
        console.error('Error polling task status:', error)
        throw error
      }
      // Never sleep past the deadline.
      await TextVideoAgentAPI.delay(Math.min(pollInterval, Math.max(0, deadline - Date.now())))
    }
    throw new Error('Task polling timeout')
  }

  /**
   * End-to-end flow: create a task, generate an image, then optionally a video
   * from that image.
   *
   * NOTE(review): `referenceImageUrl` is only forwarded to `createTask`; the
   * image step uses `referenceImageFile` only, and the created task id is not
   * passed to the generation calls — confirm this is intended.
   *
   * @returns the image URL, the video URL (only when `generateVideo` is set and
   *   the response carried one) and the created task id.
   * @throws Error when task creation or image generation yields no id/URL.
   */
  async generateContentEndToEnd(
    prompt: string,
    options: {
      taskType?: string
      referenceImageFile?: File
      referenceImageUrl?: string
      aspectRatio?: string
      videoDuration?: string
      generateVideo?: boolean
      onProgress?: (step: string, progress: number) => void
    } = {}
  ): Promise<{
    imageUrl?: string
    videoUrl?: string
    taskId: string
  }> {
    const { onProgress } = options
    try {
      // Step 1: create the task.
      onProgress?.('创建任务', 10)
      const taskRequest: TaskRequest = {
        task_type: options.taskType || 'vlog',
        prompt,
        img_url: options.referenceImageUrl,
        ar: options.aspectRatio || '9:16'
      }
      const taskResult = await this.createTask(taskRequest)
      const taskId = taskResult.data?.task_id
      if (!taskId) {
        throw new Error('Failed to create task')
      }

      // Step 2: generate the image.
      onProgress?.('生成图片', 30)
      const imageParams: ImageGenerationParams = {
        prompt,
        img_file: options.referenceImageFile,
        max_wait_time: 120
      }
      const imageResult = await this.generateImageWithRetry(imageParams)
      // The payload key differs between endpoints; accept either.
      const imageUrl = imageResult.data?.image_url || imageResult.data?.url
      if (!imageUrl) {
        throw new Error('Failed to generate image')
      }

      // Step 3: optionally turn the image into a video.
      let videoUrl: string | undefined
      if (options.generateVideo) {
        onProgress?.('生成视频', 60)
        const videoParams: VideoGenerationParams = {
          prompt,
          img_url: imageUrl,
          duration: options.videoDuration || '5',
          max_wait_time: 300
        }
        const videoResult = await this.generateVideoWithRetry(videoParams)
        videoUrl = videoResult.data?.video_url || videoResult.data?.url
      }

      onProgress?.('完成', 100)
      return {
        imageUrl,
        videoUrl,
        taskId
      }
    } catch (error) {
      console.error('End-to-end generation failed:', error)
      throw error
    }
  }
}
/**
 * Shared module-level client pointed at the default `API_BASE_URL`, exported
 * both by name and as the module's default export.
 */
const textVideoAgentAPI = new TextVideoAgentAPI()

export { textVideoAgentAPI, textVideoAgentAPI as default }