// Source-viewer header residue (not code): 392 lines · 10 KiB · TypeScript
import { DataFlow } from '../types/canvas';
/**
 * Configuration consumed by the AI service layer.
 *
 * Only `openai.apiKey` and `openai.baseURL` are read by `AIServiceManager`
 * in this file; the image/video sections are carried along for callers.
 */
export interface AIServiceConfig {
  /** OpenAI-style endpoint settings. */
  openai: {
    /** Sent as an `Authorization: Bearer …` header by `AIServiceManager`. */
    apiKey: string;
    /** Prefix that request endpoint paths are appended to. */
    baseURL: string;
    /** Model identifier (the built-in default is `'gpt-4'`). */
    model: string;
  };

  /** Image-generation provider settings. */
  imageGeneration: {
    provider:
      | 'openai'
      | 'midjourney'
      | 'stable-diffusion';
    apiKey: string;
    /** Provider-specific defaults; the shape is intentionally unconstrained here. */
    defaultSettings: any;
  };

  /** Video-generation provider settings. */
  videoGeneration: {
    provider:
      | 'runway'
      | 'pika'
      | 'stable-video';
    apiKey: string;
    /** Provider-specific defaults; the shape is intentionally unconstrained here. */
    defaultSettings: any;
  };
}

/**
 * AI service manager.
 *
 * Thin HTTP client for the prompt / image / video endpoints under
 * `config.openai.baseURL`. Every public operation falls back to a local mock
 * result when the remote path fails, so callers always receive a string. The
 * fallback triggers on any of:
 * - a network error or non-2xx status;
 * - a task reported as `failed`;
 * - a task that does not finish within its polling budget;
 * - a response missing the expected field.
 */
export class AIServiceManager {
  private config: AIServiceConfig;
  private baseURL: string;

  constructor(config: AIServiceConfig) {
    this.config = config;
    this.baseURL = config.openai.baseURL;
  }

  /**
   * Prompt optimization via `POST /api/prompt/optimize`.
   *
   * @param prompt - Prompt text to optimize.
   * @param options - `style` (default `'detailed'`) and `language` (default `'zh'`).
   * @param onProgress - Optional callback receiving progress values in 0..100.
   * @returns The optimized prompt, or a locally mocked one if the request fails.
   */
  async optimizePrompt(
    prompt: string,
    options: {
      style?: 'creative' | 'detailed' | 'concise';
      language?: 'zh' | 'en';
    } = {},
    onProgress?: (progress: number) => void
  ): Promise<string> {
    try {
      // The endpoint reports no progress of its own; emit coarse milestones.
      onProgress?.(20);

      const response = await this.makeRequest('/api/prompt/optimize', {
        prompt,
        style: options.style || 'detailed',
        language: options.language || 'zh'
      });

      const optimized = this.requireString(response?.optimizedPrompt, 'optimizedPrompt');
      onProgress?.(100);
      return optimized;
    } catch {
      // If the API call fails, use the local mock.
      return this.mockOptimizePrompt(prompt, onProgress);
    }
  }

  /**
   * Image generation via `POST /api/image/generate`.
   *
   * If the response carries a `taskId`, the task is polled until it completes.
   *
   * @param prompt - Text prompt for the image.
   * @param options - `size` (default `'512x512'`), `style` (default `'realistic'`),
   *   `quality` (default `'standard'`).
   * @param onProgress - Optional callback receiving progress values in 0..100.
   * @returns The generated image URL, or a mock URL if any step fails.
   */
  async generateImage(
    prompt: string,
    options: {
      size?: '512x512' | '1024x1024';
      style?: 'realistic' | 'artistic' | 'anime';
      quality?: 'standard' | 'hd';
    } = {},
    onProgress?: (progress: number) => void
  ): Promise<string> {
    try {
      onProgress?.(10);

      const response = await this.makeRequest('/api/image/generate', {
        prompt,
        size: options.size || '512x512',
        style: options.style || 'realistic',
        quality: options.quality || 'standard'
      });

      // Asynchronous job: poll its status endpoint until it finishes.
      if (response?.taskId) {
        return await this.pollImageProgress(response.taskId, onProgress);
      }

      const imageUrl = this.requireString(response?.imageUrl, 'imageUrl');
      onProgress?.(100);
      return imageUrl;
    } catch {
      // If the API call (or polling) fails, use the local mock.
      return this.mockGenerateImage(prompt, onProgress);
    }
  }

  /**
   * Image editing via `POST /api/image/edit`.
   *
   * @param imageUrl - URL of the source image.
   * @param instruction - Edit instruction text.
   * @param options - `editType` (default `'enhance'`).
   * @param onProgress - Optional callback receiving progress values in 0..100.
   * @returns The edited image URL, or an edit mock URL if any step fails.
   */
  async editImage(
    imageUrl: string,
    instruction: string,
    options: {
      editType?: 'style-transfer' | 'background-remove' | 'enhance' | 'resize';
    } = {},
    onProgress?: (progress: number) => void
  ): Promise<string> {
    try {
      onProgress?.(15);

      const response = await this.makeRequest('/api/image/edit', {
        imageUrl,
        instruction,
        editType: options.editType || 'enhance'
      });

      if (response?.taskId) {
        return await this.pollImageProgress(response.taskId, onProgress);
      }

      const editedUrl = this.requireString(response?.editedImageUrl, 'editedImageUrl');
      onProgress?.(100);
      return editedUrl;
    } catch {
      // Polling failures also land here, so the fallback is the *edit* mock.
      return this.mockEditImage(imageUrl, onProgress);
    }
  }

  /**
   * Video generation via `POST /api/video/generate`.
   *
   * @param imageUrl - URL of the source image.
   * @param motionPrompt - Description of the desired motion.
   * @param options - `duration` (default 3), `fps` (default 24), `quality` (default `'medium'`).
   * @param onProgress - Optional callback receiving progress values in 0..100.
   * @returns The generated video URL, or a mock URL if any step fails.
   */
  async generateVideo(
    imageUrl: string,
    motionPrompt: string,
    options: {
      duration?: number;
      fps?: number;
      quality?: 'low' | 'medium' | 'high';
    } = {},
    onProgress?: (progress: number) => void
  ): Promise<string> {
    try {
      onProgress?.(5);

      const response = await this.makeRequest('/api/video/generate', {
        imageUrl,
        motionPrompt,
        duration: options.duration || 3,
        fps: options.fps || 24,
        quality: options.quality || 'medium'
      });

      if (response?.taskId) {
        return await this.pollVideoProgress(response.taskId, onProgress);
      }

      const videoUrl = this.requireString(response?.videoUrl, 'videoUrl');
      onProgress?.(100);
      return videoUrl;
    } catch {
      return this.mockGenerateVideo(imageUrl, onProgress);
    }
  }

  /**
   * Batch processing: run one operation per item, sequentially.
   *
   * Items for `prompt-optimize` / `image-generate` may be a string or an object
   * with `prompt`. Items for `image-edit` need `imageUrl` + `instruction`, and
   * items for `video-generate` need `imageUrl` + `motionPrompt`. A failing item
   * yields `{ error: message }` in its slot instead of aborting the batch.
   *
   * @param onProgress - Called after every item with (completed, total).
   * @returns One result per input item, in input order.
   */
  async processBatch(
    items: any[],
    processType: 'prompt-optimize' | 'image-generate' | 'image-edit' | 'video-generate',
    onProgress?: (completed: number, total: number) => void
  ): Promise<any[]> {
    const results: any[] = [];
    const total = items.length;

    for (let i = 0; i < total; i++) {
      try {
        let result;
        const item = items[i];

        switch (processType) {
          case 'prompt-optimize':
            result = await this.optimizePrompt(item.prompt || item);
            break;
          case 'image-generate':
            result = await this.generateImage(item.prompt || item);
            break;
          case 'image-edit':
            result = await this.editImage(item.imageUrl, item.instruction);
            break;
          case 'video-generate':
            result = await this.generateVideo(item.imageUrl, item.motionPrompt);
            break;
          default:
            throw new Error(`Unknown process type: ${processType}`);
        }

        results.push(result);
      } catch (error) {
        results.push({ error: error instanceof Error ? error.message : 'Processing failed' });
      }
      onProgress?.(i + 1, total);
    }

    return results;
  }

  /**
   * Poll the image task status endpoint (every 2 s).
   *
   * @param maxAttempts - Upper bound on status requests before giving up.
   * @throws Error when the task fails, times out, or a status request fails.
   */
  private async pollImageProgress(
    taskId: string,
    onProgress?: (progress: number) => void,
    maxAttempts = 150
  ): Promise<string> {
    return this.pollTask(
      `/api/image/progress/${encodeURIComponent(taskId)}`,
      'imageUrl',
      'Image generation failed',
      2000,
      maxAttempts,
      onProgress
    );
  }

  /**
   * Poll the video task status endpoint (every 3 s).
   *
   * @param maxAttempts - Upper bound on status requests before giving up.
   * @throws Error when the task fails, times out, or a status request fails.
   */
  private async pollVideoProgress(
    taskId: string,
    onProgress?: (progress: number) => void,
    maxAttempts = 200
  ): Promise<string> {
    return this.pollTask(
      `/api/video/progress/${encodeURIComponent(taskId)}`,
      'videoUrl',
      'Video generation failed',
      3000,
      maxAttempts,
      onProgress
    );
  }

  /**
   * Shared task poller: GET `statusEndpoint` until `status` is `completed` or `failed`.
   *
   * Errors are thrown rather than swallowed so each public caller can fall back
   * to the mock appropriate for its own operation.
   *
   * @returns The string stored under `urlField` in the completed status.
   * @throws Error on `failed` status, on a completed status without `urlField`,
   *   after `maxAttempts` non-terminal polls, or when a request fails.
   */
  private async pollTask(
    statusEndpoint: string,
    urlField: string,
    failureMessage: string,
    intervalMs: number,
    maxAttempts: number,
    onProgress?: (progress: number) => void
  ): Promise<string> {
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      const status = await this.makeRequest(statusEndpoint, null, 'GET');

      onProgress?.(status?.progress || 0);

      if (status?.status === 'completed') {
        return this.requireString(status[urlField], urlField);
      }
      if (status?.status === 'failed') {
        throw new Error(status.error || failureMessage);
      }

      // No point waiting after the final allowed attempt.
      if (attempt < maxAttempts) {
        await this.sleep(intervalMs);
      }
    }
    throw new Error(`Task did not finish after ${maxAttempts} status checks: ${statusEndpoint}`);
  }

  /**
   * Issue a JSON request against `baseURL + endpoint`.
   *
   * @throws Error `API Error: <status> <statusText>` on a non-2xx response.
   */
  private async makeRequest(endpoint: string, data?: any, method: 'GET' | 'POST' = 'POST'): Promise<any> {
    const response = await fetch(`${this.baseURL}${endpoint}`, {
      method,
      headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${this.config.openai.apiKey}`
      },
      body: data ? JSON.stringify(data) : undefined
    });

    if (!response.ok) {
      throw new Error(`API Error: ${response.status} ${response.statusText}`);
    }

    return response.json();
  }

  /** Return `value` if it is a non-empty string; otherwise throw so the caller falls back. */
  private requireString(value: unknown, field: string): string {
    if (typeof value !== 'string' || value === '') {
      throw new Error(`API response is missing "${field}"`);
    }
    return value;
  }

  /** Resolve after `ms` milliseconds. */
  private sleep(ms: number): Promise<void> {
    return new Promise(resolve => setTimeout(resolve, ms));
  }

  /**
   * Report simulated progress 0, step, 2·step, … (< 100) with `delayMs` pauses,
   * then always finish with 100 — even when `step` does not divide 100.
   */
  private async simulateProgress(
    step: number,
    delayMs: number,
    onProgress?: (progress: number) => void
  ): Promise<void> {
    for (let progress = 0; progress < 100; progress += step) {
      onProgress?.(progress);
      await this.sleep(delayMs);
    }
    onProgress?.(100);
  }

  // Mock methods - used when the API is unavailable.
  private async mockOptimizePrompt(prompt: string, onProgress?: (progress: number) => void): Promise<string> {
    await this.simulateProgress(20, 200, onProgress);
    return `优化后的提示词: ${prompt},添加了更多细节描述,包括光线、构图、风格等元素,使其更适合AI图像生成。`;
  }

  private async mockGenerateImage(prompt: string, onProgress?: (progress: number) => void): Promise<string> {
    await this.simulateProgress(10, 300, onProgress);
    return `https://picsum.photos/512/512?random=${Date.now()}`;
  }

  private async mockEditImage(imageUrl: string, onProgress?: (progress: number) => void): Promise<string> {
    await this.simulateProgress(15, 250, onProgress);
    return `https://picsum.photos/512/512?random=${Date.now()}&edit=true`;
  }

  private async mockGenerateVideo(imageUrl: string, onProgress?: (progress: number) => void): Promise<string> {
    await this.simulateProgress(5, 500, onProgress);
    return `https://sample-videos.com/zip/10/mp4/SampleVideo_360x240_1mb.mp4?t=${Date.now()}`;
  }
}

/**
 * Error-handling helpers for AI service calls.
 */
export class ErrorHandler {
  /**
   * Run `operation`, retrying with exponential backoff when it rejects.
   *
   * @param operation - Async operation to run.
   * @param maxRetries - Total number of attempts. Values below 1 (or NaN) are
   *   treated as 1 so the operation always runs at least once; fractions round up.
   * @param delay - Base delay in ms; the wait after the k-th failure is `delay * 2^(k-1)`.
   * @returns The first successful result.
   * @throws The error from the final attempt.
   */
  static async withRetry<T>(
    operation: () => Promise<T>,
    maxRetries = 3,
    delay = 1000
  ): Promise<T> {
    const attempts = maxRetries >= 1 ? Math.ceil(maxRetries) : 1;

    for (let attempt = 1; ; attempt++) {
      try {
        return await operation();
      } catch (error) {
        if (attempt >= attempts) throw error;
        await new Promise(resolve => setTimeout(resolve, delay * 2 ** (attempt - 1)));
      }
    }
  }

  /**
   * Map an API error to a user-facing (Chinese) message.
   *
   * The HTTP status is taken from `error.response.status` when it is a number.
   * Otherwise it is parsed from an `"API Error: <status> …"` message, the format
   * thrown by `AIServiceManager.makeRequest`. Null, undefined and other
   * non-object inputs are tolerated.
   */
  static handleAPIError(error: unknown): string {
    const status = ErrorHandler.statusOf(error);

    if (status === 429) {
      return '请求过于频繁,请稍后重试';
    }
    if (status === 401) {
      return 'API密钥无效,请检查配置';
    }
    if (status !== undefined && status >= 500) {
      return '服务器错误,请稍后重试';
    }

    const message =
      typeof error === 'object' && error !== null
        ? (error as { message?: unknown }).message
        : undefined;
    return typeof message === 'string' && message ? message : '未知错误';
  }

  /** Extract an HTTP status code from `error`, or `undefined` if none is found. */
  private static statusOf(error: unknown): number | undefined {
    if (typeof error !== 'object' || error === null) return undefined;

    const response = (error as { response?: { status?: unknown } }).response;
    if (typeof response?.status === 'number') return response.status;

    const message = (error as { message?: unknown }).message;
    const match = typeof message === 'string' ? /^API Error: (\d{3})\b/.exec(message) : null;
    return match ? Number(match[1]) : undefined;
  }
}

/**
 * Load the AI service configuration.
 *
 * Reads JSON saved under the `aiConfig` key of `localStorage` (when that global
 * exists) and overlays it, section by section, onto the built-in defaults:
 * - missing sections or fields keep their default values;
 * - a missing, unreadable, unparsable or non-object saved value yields the
 *   defaults unchanged.
 * Saved field *values* are not type-checked: a field of the wrong type is
 * passed through as-is.
 *
 * @returns A complete configuration object.
 */
export const loadAIConfig = (): AIServiceConfig => {
  // Default configuration
  const defaults: AIServiceConfig = {
    openai: {
      apiKey: '',
      baseURL: 'https://api.openai.com/v1',
      model: 'gpt-4'
    },
    imageGeneration: {
      provider: 'openai',
      apiKey: '',
      defaultSettings: {
        size: '512x512',
        quality: 'standard'
      }
    },
    videoGeneration: {
      provider: 'runway',
      apiKey: '',
      defaultSettings: {
        duration: 3,
        fps: 24,
        quality: 'medium'
      }
    }
  };

  const isPlainObject = (value: unknown): value is Record<string, unknown> =>
    typeof value === 'object' && value !== null && !Array.isArray(value);

  // Shallow copy of `base` with `patch`'s keys applied when `patch` is an object.
  // The cast trusts saved JSON to hold the right field types (not validated).
  const overlay = <T extends object>(base: T, patch: unknown): T =>
    (isPlainObject(patch) ? { ...base, ...patch } : { ...base }) as T;

  const child = (parent: unknown, key: string): unknown =>
    isPlainObject(parent) ? parent[key] : undefined;

  let raw: string | null = null;
  try {
    // `localStorage` is absent outside browsers (SSR, Node), and getItem can
    // throw when storage access is blocked; both cases fall back to defaults.
    if (typeof localStorage !== 'undefined') {
      raw = localStorage.getItem('aiConfig');
    }
  } catch (error) {
    console.warn('Failed to read saved AI config:', error);
  }
  if (!raw) return defaults;

  let saved: unknown;
  try {
    saved = JSON.parse(raw);
  } catch (error) {
    console.warn('Failed to parse saved AI config:', error);
    return defaults;
  }
  if (!isPlainObject(saved)) return defaults;

  return {
    openai: overlay(defaults.openai, saved.openai),
    imageGeneration: {
      ...overlay(defaults.imageGeneration, saved.imageGeneration),
      defaultSettings: overlay(
        defaults.imageGeneration.defaultSettings,
        child(saved.imageGeneration, 'defaultSettings')
      )
    },
    videoGeneration: {
      ...overlay(defaults.videoGeneration, saved.videoGeneration),
      defaultSettings: overlay(
        defaults.videoGeneration.defaultSettings,
        child(saved.videoGeneration, 'defaultSettings')
      )
    }
  };
};

// Module-wide default AI service instance. It is built at import time from
// loadAIConfig(), so importing this module triggers that config load.
const initialAIConfig = loadAIConfig();
export const aiService = new AIServiceManager(initialAIConfig);