/**
 * Text Video Agent API 工具库 (Text Video Agent API client library)
 *
 * 基于 https://bowongai-dev--text-video-agent-fastapi-app.modal.run 的API封装
 * (typed wrapper around the HTTP API served at that URL)
 */

/**
 * Root URL of the deployed Text Video Agent FastAPI service.
 * Used as the default base URL when a client is constructed without one.
 */
const API_BASE_URL =
  'https://bowongai-dev--text-video-agent-fastapi-app.modal.run'

/**
 * Common response envelope used by the service endpoints.
 *
 * @typeParam T - Shape of the `data` payload. Defaults to `any` because most
 *   endpoints in this client are not individually typed.
 */
export interface APIResponse<T = any> {
  /** Success flag; the retry helpers treat a falsy value as a failed attempt. */
  status: boolean
  /** Message text returned by the server (used as the error message on failure). */
  msg: string
  /** Endpoint-specific payload; may be absent. */
  data?: T
}

/**
 * Response of the file-upload endpoint.
 */
export interface FileUploadResponse {
  /** Success flag reported by the server. */
  status: boolean
  /** Message text returned by the server. */
  msg: string
  /** URL of the uploaded file, when the upload succeeded. */
  data?: string
}

/**
 * Request body for creating a generation task.
 */
export interface TaskRequest {
  /** Task category, e.g. `tea`, `chop`, `lady`, `vlog`. */
  task_type?: string
  /** Prompt used for image generation. */
  prompt: string
  /** Optional reference image URL. */
  img_url?: string
  /** Aspect ratio for generated images/videos, e.g. `9:16` (documented default). */
  ar?: string
}

/**
 * Status record for an asynchronous task.
 */
export interface TaskStatus {
  /** Identifier of the task being reported on. */
  task_id: string
  /** Lifecycle state of the task. */
  status: 'pending' | 'running' | 'completed' | 'failed'
  /** Progress indicator; presumably a percentage — confirm with the server. */
  progress?: number
  /** Task output once available; shape depends on the task type. */
  result?: any
  /** Error description when the task failed. */
  error?: string
}

/**
 * Parameters for the Midjourney image-generation calls.
 */
export interface ImageGenerationParams {
  /** Text prompt describing the image. */
  prompt: string
  /** Optional reference image uploaded alongside the prompt. */
  img_file?: File
  /** Maximum wait in seconds (documented default: 120). */
  max_wait_time?: number
  /** Polling interval in seconds (documented default: 2). */
  poll_interval?: number
}

/**
 * Parameters for the video-generation calls.
 */
export interface VideoGenerationParams {
  /** Text prompt describing the video. */
  prompt: string
  /** Source image URL (required by the sync endpoint). */
  img_url?: string
  /** Source image as a file (used by the async endpoint). */
  img_file?: File
  /** Clip duration in seconds, as a string (default `'5'`). */
  duration?: string
  /** Maximum wait in seconds (default 300). */
  max_wait_time?: number
  /** Polling interval in seconds (default 5). */
  poll_interval?: number
}

/**
 * Parameters for the image-description calls; supply either a URL or a file.
 */
export interface ImageDescribeParams {
  /** Image to describe, by URL. */
  image_url?: string
  /** Image to describe, as an uploaded file. */
  img_file?: File
  /** Maximum wait in seconds (documented default: 120). */
  max_wait_time?: number
  /** Polling interval in seconds (documented default: 2). */
  poll_interval?: number
}

/**
 * Thin fetch-based HTTP helper bound to a single base URL.
 *
 * Every method resolves with the parsed JSON body. Non-2xx responses and
 * network/parse failures are logged together with the request URL and then
 * rethrown to the caller.
 */
class HTTPClient {
  private baseURL: string

  /**
   * @param baseURL Service root such as `https://host`. Trailing slashes are
   *   stripped so that joining with an endpoint like `/health` never yields `//`.
   */
  constructor(baseURL: string) {
    this.baseURL = baseURL.replace(/\/+$/, '')
  }

  /**
   * Send a request to `baseURL + endpoint` and parse the JSON response.
   *
   * @param endpoint Path (optionally with query string) starting with `/`.
   * @param options  Forwarded to `fetch` unchanged.
   * @returns The parsed JSON body.
   * @throws Error `HTTP error! status: <code>` for non-2xx responses; errors
   *   from `fetch` or `Response.json` propagate unchanged.
   */
  private async request<T>(
    endpoint: string,
    options: RequestInit = {}
  ): Promise<T> {
    const url = `${this.baseURL}${endpoint}`

    try {
      // Forward `options` as-is: re-spreading `options.headers` into an object
      // literal silently drops every entry when the caller passes a `Headers`
      // instance or a `[name, value][]` array.
      const response = await fetch(url, options)

      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`)
      }

      const data: T = await response.json()
      return data
    } catch (error) {
      console.error(`API request failed: ${url}`, error)
      throw error
    }
  }

  /**
   * GET `endpoint`, appending each non-null/undefined entry of `params` to the
   * query string (values are converted with `String`).
   */
  async get<T>(endpoint: string, params?: Record<string, any>): Promise<T> {
    const url = new URL(endpoint, this.baseURL)
    if (params) {
      for (const [key, value] of Object.entries(params)) {
        if (value !== undefined && value !== null) {
          url.searchParams.append(key, String(value))
        }
      }
    }

    // Keep only path + query; `request` prefixes `baseURL` itself.
    return this.request<T>(url.pathname + url.search)
  }

  /**
   * POST `data` as the raw request body. `options` may add headers or
   * override other `RequestInit` fields; the method defaults to POST.
   */
  async post<T>(
    endpoint: string,
    data?: any,
    options: RequestInit = {}
  ): Promise<T> {
    return this.request<T>(endpoint, {
      method: 'POST',
      ...options,
      body: data,
    })
  }

  /** POST `data` serialized with `JSON.stringify` and an `application/json` content type. */
  async postJSON<T>(endpoint: string, data?: any): Promise<T> {
    return this.post<T>(endpoint, JSON.stringify(data), {
      headers: {
        'Content-Type': 'application/json',
      },
    })
  }

  /**
   * POST a multipart `FormData` body. No content type is set here, so
   * `fetch` can add the multipart boundary itself.
   */
  async postFormData<T>(endpoint: string, formData: FormData): Promise<T> {
    return this.post<T>(endpoint, formData)
  }

  /**
   * POST `data` as `application/x-www-form-urlencoded`, skipping null/undefined
   * values and converting the rest with `String`.
   */
  async postFormUrlEncoded<T>(
    endpoint: string,
    data: Record<string, any>
  ): Promise<T> {
    const formData = new URLSearchParams()
    for (const [key, value] of Object.entries(data)) {
      if (value !== undefined && value !== null) {
        formData.append(key, String(value))
      }
    }

    return this.post<T>(endpoint, formData, {
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
      },
    })
  }
}

/**
 * High-level client for the Text Video Agent API.
 *
 * Wraps each service endpoint in a method and adds convenience flows:
 * retrying generation, polling a task until it finishes, and an end-to-end
 * image → video pipeline.
 */
export class TextVideoAgentAPI {
  private client: HTTPClient

  /**
   * @param baseURL Root URL of the service; defaults to `API_BASE_URL`.
   */
  constructor(baseURL: string = API_BASE_URL) {
    this.client = new HTTPClient(baseURL)
  }

  // ==================== Basics ====================

  /** Service health check (`GET /health`). */
  async healthCheck(): Promise<APIResponse> {
    return this.client.get('/health')
  }

  /**
   * Fetch a sample prompt (`GET /api/prompt/default`).
   *
   * @param taskType Optional task type; left out of the query when undefined.
   */
  async getSamplePrompt(taskType?: string): Promise<APIResponse> {
    return this.client.get('/api/prompt/default', { task_type: taskType })
  }

  // ==================== Files ====================

  /**
   * Upload a file (`POST /api/file/upload`, multipart field `file`).
   * Per the original note the file is stored in COS; on success `data` holds its URL.
   */
  async uploadFile(file: File): Promise<FileUploadResponse> {
    const formData = new FormData()
    formData.append('file', file)
    return this.client.postFormData('/api/file/upload', formData)
  }

  // ==================== Midjourney image generation ====================

  /** Midjourney health check (`GET /api/mj/health`). */
  async mjHealthCheck(): Promise<APIResponse> {
    return this.client.get('/api/mj/health')
  }

  /**
   * Generate an image and wait for the server to return (`POST /api/mj/sync/image`).
   *
   * `prompt`, `max_wait_time` and `poll_interval` are sent in the query string
   * (null/undefined values omitted); the optional reference image goes in the
   * multipart field `img_file`.
   */
  async generateImageSync(params: ImageGenerationParams): Promise<APIResponse> {
    const formData = new FormData()
    if (params.img_file) {
      formData.append('img_file', params.img_file)
    }

    // Build the query locally instead of reading the HTTP client's private
    // `baseURL` through bracket access.
    const query = new URLSearchParams()
    const queryParams = {
      prompt: params.prompt,
      max_wait_time: params.max_wait_time,
      poll_interval: params.poll_interval,
    }
    for (const [key, value] of Object.entries(queryParams)) {
      if (value !== undefined && value !== null) {
        query.append(key, String(value))
      }
    }

    const search = query.toString()
    const endpoint = search ? `/api/mj/sync/image?${search}` : '/api/mj/sync/image'
    return this.client.postFormData(endpoint, formData)
  }

  /**
   * Generate an image via `POST /api/mj/generate-image`, with every field sent
   * as multipart form data. The original note recommends the sync endpoint instead.
   */
  async generateImage(params: ImageGenerationParams): Promise<APIResponse> {
    const formData = new FormData()
    formData.append('prompt', params.prompt)
    if (params.img_file) {
      formData.append('img_file', params.img_file)
    }
    if (params.max_wait_time) {
      formData.append('max_wait_time', String(params.max_wait_time))
    }
    if (params.poll_interval) {
      formData.append('poll_interval', String(params.poll_interval))
    }

    return this.client.postFormData('/api/mj/generate-image', formData)
  }

  /**
   * Submit an image-generation task without waiting for it
   * (`POST /api/mj/async/generate/image`). Poll with `queryImageTaskStatus`.
   */
  async generateImageAsync(prompt: string, imgFile?: File): Promise<APIResponse> {
    const formData = new FormData()
    if (imgFile) {
      formData.append('img_file', imgFile)
    }

    return this.client.postFormData(
      `/api/mj/async/generate/image?prompt=${encodeURIComponent(prompt)}`,
      formData
    )
  }

  /** Query an async image task (`GET /api/mj/async/query/status?task_id=...`). */
  async queryImageTaskStatus(taskId: string): Promise<APIResponse> {
    return this.client.get('/api/mj/async/query/status', { task_id: taskId })
  }

  /**
   * Describe an image given its URL (`POST /api/mj/sync/img/describe`, form-urlencoded).
   *
   * @throws Error `image_url is required` when `params.image_url` is empty.
   */
  async describeImageByUrl(params: ImageDescribeParams): Promise<APIResponse> {
    if (!params.image_url) {
      throw new Error('image_url is required')
    }

    return this.client.postFormUrlEncoded('/api/mj/sync/img/describe', {
      image_url: params.image_url,
      max_wait_time: params.max_wait_time,
      poll_interval: params.poll_interval,
    })
  }

  /**
   * Describe an uploaded image (`POST /api/mj/sync/file/img/describe`, multipart).
   *
   * @throws Error `img_file is required` when `params.img_file` is missing.
   */
  async describeImageByFile(params: ImageDescribeParams): Promise<APIResponse> {
    if (!params.img_file) {
      throw new Error('img_file is required')
    }

    const formData = new FormData()
    formData.append('img_file', params.img_file)
    if (params.max_wait_time) {
      formData.append('max_wait_time', String(params.max_wait_time))
    }
    if (params.poll_interval) {
      formData.append('poll_interval', String(params.poll_interval))
    }

    return this.client.postFormData('/api/mj/sync/file/img/describe', formData)
  }

  // ==================== Jimeng (极梦) video generation ====================

  /** Video service health check (`GET /api/jm/health`). */
  async jmHealthCheck(): Promise<APIResponse> {
    return this.client.get('/api/jm/health')
  }

  /**
   * Generate a video and wait for the server to return (`POST /api/jm/generate-video`).
   * Falsy `duration` / `max_wait_time` / `poll_interval` fall back to `'5'` / 300 / 5.
   *
   * @throws Error when `params.img_url` is missing.
   */
  async generateVideoSync(params: VideoGenerationParams): Promise<APIResponse> {
    if (!params.img_url) {
      throw new Error('img_url is required for sync video generation')
    }

    return this.client.postFormUrlEncoded('/api/jm/generate-video', {
      prompt: params.prompt,
      img_url: params.img_url,
      duration: params.duration || '5',
      max_wait_time: params.max_wait_time || 300,
      poll_interval: params.poll_interval || 5,
    })
  }

  /**
   * Submit a video-generation task without waiting
   * (`POST /api/jm/async/generate/video`, multipart). Poll with `queryVideoTaskStatus`.
   */
  async generateVideoAsync(params: VideoGenerationParams): Promise<APIResponse> {
    const formData = new FormData()
    formData.append('prompt', params.prompt)
    if (params.img_url) {
      formData.append('img_url', params.img_url)
    }
    if (params.img_file) {
      formData.append('img_file', params.img_file)
    }
    formData.append('duration', params.duration || '5')

    return this.client.postFormData('/api/jm/async/generate/video', formData)
  }

  /** Query an async video task (`GET /api/jm/async/query/status?task_id=...`). */
  async queryVideoTaskStatus(taskId: string): Promise<APIResponse> {
    return this.client.get('/api/jm/async/query/status', { task_id: taskId })
  }

  // ==================== Task management ====================

  /** Create a task without waiting for it (`POST /api/task/create/task`, JSON body). */
  async createTask(request: TaskRequest): Promise<APIResponse> {
    return this.client.postJSON('/api/task/create/task', request)
  }

  /**
   * Fetch a task's result (`GET /api/task/{taskId}`); the original note says
   * the server blocks until the result is ready.
   * The id is URI-encoded so it always stays a single path segment.
   */
  async getTaskResultSync(taskId: string): Promise<APIResponse> {
    return this.client.get(`/api/task/${encodeURIComponent(taskId)}`)
  }

  /**
   * Fetch a task's current status without waiting (`GET /api/task/status/{taskId}`).
   * The id is URI-encoded so it always stays a single path segment.
   */
  async getTaskStatusAsync(taskId: string): Promise<APIResponse> {
    return this.client.get(`/api/task/status/${encodeURIComponent(taskId)}`)
  }

  // ==================== Higher-level flows ====================

  /** Resolve after `ms` milliseconds. */
  private sleep(ms: number): Promise<void> {
    return new Promise(resolve => setTimeout(resolve, ms))
  }

  /**
   * Call `attempt` up to `maxRetries` times until it returns `status: true`.
   *
   * Both failure modes are logged and followed by a linear back-off of
   * `baseDelayMs * attemptNumber`, skipped after the final attempt. The two
   * failure modes are a thrown error and a response whose `status` is falsy.
   * Waiting in both cases keeps a server that answers `status: false` from
   * being hit back-to-back.
   *
   * @throws The last failure, or `All retry attempts failed` when `maxRetries <= 0`.
   */
  private async withRetry(
    label: string,
    attempt: () => Promise<APIResponse>,
    maxRetries: number,
    baseDelayMs: number
  ): Promise<APIResponse> {
    let lastError: Error | null = null

    for (let i = 0; i < maxRetries; i++) {
      try {
        const result = await attempt()
        if (result.status) {
          return result
        }
        lastError = new Error(result.msg || 'Generation failed')
      } catch (error) {
        // Non-Error throwables are wrapped so callers always receive an Error.
        lastError = error instanceof Error ? error : new Error(String(error))
      }

      console.warn(`${label} generation attempt ${i + 1} failed:`, lastError)

      if (i < maxRetries - 1) {
        await this.sleep(baseDelayMs * (i + 1))
      }
    }

    throw lastError || new Error('All retry attempts failed')
  }

  /**
   * Run `generateImageSync` with retries.
   *
   * @param maxRetries  Total attempts (default 3).
   * @param baseDelayMs Back-off unit; attempt n waits `n * baseDelayMs` (default 2000).
   */
  async generateImageWithRetry(
    params: ImageGenerationParams,
    maxRetries: number = 3,
    baseDelayMs: number = 2000
  ): Promise<APIResponse> {
    return this.withRetry('Image', () => this.generateImageSync(params), maxRetries, baseDelayMs)
  }

  /**
   * Run `generateVideoSync` with retries.
   *
   * @param maxRetries  Total attempts (default 3).
   * @param baseDelayMs Back-off unit; attempt n waits `n * baseDelayMs` (default 5000).
   */
  async generateVideoWithRetry(
    params: VideoGenerationParams,
    maxRetries: number = 3,
    baseDelayMs: number = 5000
  ): Promise<APIResponse> {
    return this.withRetry('Video', () => this.generateVideoSync(params), maxRetries, baseDelayMs)
  }

  /**
   * Poll `getTaskStatusAsync` until `data.status` is `completed` or `failed`.
   *
   * @param pollInterval Milliseconds between polls (default 2000).
   * @param maxWaitTime  Overall budget in milliseconds (default 300000 = 5 min).
   * @param onProgress   Called with every status response.
   * @returns The response whose `data.status` is `completed`.
   * @throws Error with `data.error` (or `Task failed`) on failure,
   *   `Task polling timeout` when the budget is exhausted, or any request error.
   */
  async pollTaskUntilComplete(
    taskId: string,
    pollInterval: number = 2000,
    maxWaitTime: number = 300000,
    onProgress?: (status: any) => void
  ): Promise<APIResponse> {
    const deadline = Date.now() + maxWaitTime

    while (Date.now() < deadline) {
      try {
        const status = await this.getTaskStatusAsync(taskId)
        onProgress?.(status)

        if (status.data?.status === 'completed') {
          return status
        }
        if (status.data?.status === 'failed') {
          throw new Error(status.data?.error || 'Task failed')
        }

        // Never sleep past the deadline, so the timeout is honoured promptly.
        await this.sleep(Math.min(pollInterval, Math.max(0, deadline - Date.now())))
      } catch (error) {
        console.error('Error polling task status:', error)
        throw error
      }
    }

    throw new Error('Task polling timeout')
  }

  /**
   * End-to-end flow: create a task, generate an image, then optionally a video.
   *
   * `onProgress` receives (step label, percentage) at 10 / 30 / 60 / 100.
   * NOTE(review): `referenceImageUrl` and `aspectRatio` are only sent to
   * `createTask`, and the created task is never polled. Image generation uses
   * only `prompt` + `referenceImageFile`. A missing video URL is returned as
   * `undefined` rather than thrown. Confirm these are intended.
   *
   * @throws Error when task creation or image generation yields no id/URL,
   *   or when any underlying request fails.
   */
  async generateContentEndToEnd(
    prompt: string,
    options: {
      taskType?: string
      referenceImageFile?: File
      referenceImageUrl?: string
      aspectRatio?: string
      videoDuration?: string
      generateVideo?: boolean
      onProgress?: (step: string, progress: number) => void
    } = {}
  ): Promise<{
    imageUrl?: string
    videoUrl?: string
    taskId: string
  }> {
    const { onProgress } = options

    try {
      // Step 1: create the task.
      onProgress?.('创建任务', 10)
      const taskRequest: TaskRequest = {
        task_type: options.taskType || 'vlog',
        prompt,
        img_url: options.referenceImageUrl,
        ar: options.aspectRatio || '9:16',
      }
      const taskResult = await this.createTask(taskRequest)
      const taskId = taskResult.data?.task_id
      if (!taskId) {
        throw new Error('Failed to create task')
      }

      // Step 2: generate the image.
      onProgress?.('生成图片', 30)
      const imageParams: ImageGenerationParams = {
        prompt,
        img_file: options.referenceImageFile,
        max_wait_time: 120,
      }
      const imageResult = await this.generateImageWithRetry(imageParams)
      // The image URL may come back under either key.
      const imageUrl = imageResult.data?.image_url || imageResult.data?.url
      if (!imageUrl) {
        throw new Error('Failed to generate image')
      }

      // Step 3: optionally turn the image into a video.
      let videoUrl: string | undefined
      if (options.generateVideo) {
        onProgress?.('生成视频', 60)
        const videoParams: VideoGenerationParams = {
          prompt,
          img_url: imageUrl,
          duration: options.videoDuration || '5',
          max_wait_time: 300,
        }
        const videoResult = await this.generateVideoWithRetry(videoParams)
        videoUrl = videoResult.data?.video_url || videoResult.data?.url
      }

      onProgress?.('完成', 100)

      return {
        imageUrl,
        videoUrl,
        taskId,
      }
    } catch (error) {
      console.error('End-to-end generation failed:', error)
      throw error
    }
  }
}

/** Shared client instance bound to the default API base URL. */
const textVideoAgentAPI = new TextVideoAgentAPI()

// Exported both by name and as the module's default export.
export { textVideoAgentAPI, textVideoAgentAPI as default }