bw-mini-app-server/docs/ai-generation-service.md

520 lines
15 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# AI图片/视频生成服务配置指南
## 1. 服务架构设计
### 1.1 AI生成服务流程
```
用户请求 → 积分校验 → 任务创建 → 队列处理 → AI模型调用 → 结果存储 → 用户通知
```
### 1.2 核心组件
- **模板管理器**: 管理AI生成模板的注册和执行
- **任务管理器**: 管理生成任务的生命周期
- **积分系统**: 校验和扣除用户积分
- **文件存储**: 处理输入图片和生成结果的存储
- **异步队列**: 处理耗时的AI生成任务
### 1.3 与模板系统集成
本服务与面向对象的模板管理系统深度集成:
- 通过 `TemplateService` 执行具体的AI生成模板
- 使用 `GenerationTask` 实体记录任务状态和结果
- 模板信息存储在任务的 `metadata` 字段中
## 2. AI生成服务实现
### 2.1 生成任务服务 (与模板系统集成版本)
```typescript
// src/services/generation.service.ts
import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { GenerationTask, GenerationType, TaskStatus } from '../entities/generation-task.entity';
import { AIModelService } from './ai-model.service';
import { CreditService } from './credit.service';
import { FileService } from './file.service';
import { MessageProducerService } from './message-producer.service';
import { TemplateService } from './template.service';
/**
 * Coordinates template-based AI generation tasks: credit check, task record
 * creation, credit deduction, queue dispatch, and (on the worker side) model
 * invocation, result storage and completion notification.
 */
@Injectable()
export class GenerationService {
  constructor(
    @InjectRepository(GenerationTask)
    private readonly taskRepository: Repository<GenerationTask>,
    private readonly creditService: CreditService,
    private readonly templateService: TemplateService,
    private readonly fileService: FileService,
    private readonly messageProducer: MessageProducerService,
    // processGenerationTask calls this service; it was referenced before without being injected.
    // NOTE(review): AIModelService must be registered as a provider in the owning module — confirm.
    private readonly aiModelService: AIModelService,
  ) {}

  /**
   * Creates a generation task from a template code, deducts the template's
   * credit cost and enqueues the task for asynchronous processing.
   *
   * @param createTaskDto user, platform, template code and template input parameters
   * @returns the persisted task (status PENDING)
   * @throws Error when the template does not exist or the user lacks credits;
   *         rethrows credit/queue errors after marking the task FAILED
   */
  async createGenerationTaskByTemplate(createTaskDto: CreateTemplateTaskDto): Promise<GenerationTask> {
    const { userId, platform, templateCode } = createTaskDto;
    // Shallow copy: the image entry is swapped for its URL below and the
    // caller's DTO must not be mutated as a side effect.
    const inputParameters = { ...createTaskDto.inputParameters };

    // 1. Resolve the template.
    const templateInfo = this.templateService.getTemplateInfo(templateCode);
    if (!templateInfo) {
      throw new Error(`模板 ${templateCode} 不存在`);
    }

    // 2. Verify the user can afford the template.
    const hasEnoughCredits = await this.creditService.checkBalance(userId, platform, templateInfo.creditCost);
    if (!hasEnoughCredits) {
      throw new Error('积分不足');
    }

    // 3. Upload the input image (if any) and keep only its URL from here on.
    let inputImageUrl: string | null = null;
    if (inputParameters.inputImage) {
      inputImageUrl = await this.fileService.uploadImage(inputParameters.inputImage, userId);
      inputParameters.inputImage = inputImageUrl;
    }

    // 4. Persist the task record, including the template snapshot in metadata.
    const task = this.taskRepository.create({
      userId,
      platform,
      type: templateInfo.category.includes('视频') ? GenerationType.VIDEO : GenerationType.IMAGE,
      prompt: inputParameters.prompt || inputParameters.clothingDescription || '',
      inputImageUrl,
      creditCost: templateInfo.creditCost,
      parameters: inputParameters,
      status: TaskStatus.PENDING,
      metadata: {
        templateCode: templateInfo.code,
        templateName: templateInfo.name,
        templateVersion: templateInfo.version,
        templateCategory: templateInfo.category,
      },
    });
    const savedTask = await this.taskRepository.save(task);

    // 5. Deduct credits. If this fails, do not leave an orphaned PENDING task behind.
    try {
      await this.creditService.consumeCredits(userId, platform, templateInfo.creditCost, 'ai_generation', savedTask.id);
    } catch (error: unknown) {
      await this.markTaskFailed(savedTask.id, error);
      throw error;
    }

    // 6. Enqueue. If the queue rejects the message the task would never run,
    //    so give the credits back and mark the task FAILED before rethrowing.
    try {
      await this.messageProducer.sendGenerationTask({
        taskId: savedTask.id,
        userId,
        platform,
        type: savedTask.type,
        templateCode,
        inputParameters,
      });
    } catch (error: unknown) {
      try {
        await this.creditService.refundCredits(
          userId,
          platform,
          templateInfo.creditCost,
          'ai_generation_failed',
          savedTask.id,
        );
      } finally {
        await this.markTaskFailed(savedTask.id, error);
      }
      throw error;
    }

    return savedTask;
  }

  /**
   * Worker entry point: runs the AI model for a queued task, stores the
   * output, updates the task and sends the completion notification.
   * Generation/storage failures are recorded on the task and refunded rather
   * than rethrown; a failed completion notification is rethrown without
   * touching the (already COMPLETED) task.
   *
   * @param taskId id of the task to process
   * @throws Error when the task does not exist
   */
  async processGenerationTask(taskId: string): Promise<void> {
    const task = await this.taskRepository.findOne({ where: { id: taskId } });
    if (!task) {
      throw new Error('任务不存在');
    }

    // A redelivered queue message must not regenerate or refund a task that
    // has already reached a terminal state.
    if (task.status === TaskStatus.COMPLETED || task.status === TaskStatus.FAILED) {
      return;
    }

    let stored: { outputUrl: string; thumbnailUrl: string | undefined };
    try {
      stored = await this.generateAndStore(task);
    } catch (error: unknown) {
      // Refund first; the finally guarantees the FAILED status is still
      // written if the refund itself throws.
      try {
        await this.creditService.refundCredits(
          task.userId,
          task.platform,
          task.creditCost,
          'ai_generation_failed',
          taskId,
        );
      } finally {
        await this.markTaskFailed(taskId, error);
      }
      return;
    }

    // Sent outside the try above on purpose: a notification failure must not
    // flip a successfully completed task to FAILED or trigger a refund.
    await this.messageProducer.sendGenerationCompleted({
      taskId,
      userId: task.userId,
      platform: task.platform,
      outputUrl: stored.outputUrl,
      thumbnailUrl: stored.thumbnailUrl,
    });
  }

  /** Runs the model, uploads the output (plus a thumbnail for images) and marks the task COMPLETED. */
  private async generateAndStore(
    task: GenerationTask,
  ): Promise<{ outputUrl: string; thumbnailUrl: string | undefined }> {
    await this.updateTaskStatus(task.id, TaskStatus.PROCESSING);

    const startedAt = Date.now();
    const result = await this.aiModelService.generate({
      type: task.type,
      prompt: task.prompt,
      inputImageUrl: task.inputImageUrl,
      parameters: task.parameters,
    });
    // Whole seconds spent inside the model call.
    const processingTime = Math.floor((Date.now() - startedAt) / 1000);

    // NOTE(review): result.contentType is not forwarded here; confirm how
    // FileService labels non-PNG output such as video.
    const outputUrl = await this.fileService.uploadGeneratedContent(result.content, task.userId);

    // The thumbnailer decodes the output with an image library, so it is only
    // invoked for image tasks; video tasks complete without a thumbnail.
    const thumbnailUrl =
      task.type === GenerationType.IMAGE
        ? await this.fileService.generateThumbnail(outputUrl, task.userId)
        : undefined;

    await this.taskRepository.update(task.id, {
      status: TaskStatus.COMPLETED,
      outputUrl,
      thumbnailUrl,
      processingTime,
    });

    return { outputUrl, thumbnailUrl };
  }

  /** Records a failure on the task, tolerating non-Error throwables. */
  private async markTaskFailed(taskId: string, error: unknown): Promise<void> {
    const errorMessage = error instanceof Error ? error.message : String(error);
    await this.taskRepository.update(taskId, {
      status: TaskStatus.FAILED,
      errorMessage,
    });
  }

  /**
   * Computes a credit cost from generation type and parameters.
   * Images cost 10 (15 at high quality), videos 50 (75 at high quality);
   * a '4k' resolution multiplies the result by 1.5, rounded up.
   * Not called by the template-based flow in this class, which charges
   * `templateInfo.creditCost` instead.
   */
  private calculateCreditCost(type: GenerationType, parameters: any): number {
    const isImage = type === GenerationType.IMAGE;
    let baseCost = isImage ? 10 : 50;

    if (parameters?.quality === 'high') {
      baseCost = isImage ? 15 : 75;
    }

    // Resolution surcharge.
    if (parameters?.resolution === '4k') {
      baseCost = Math.ceil(baseCost * 1.5);
    }
    return baseCost;
  }

  /**
   * Returns one page of a user's tasks on a platform, newest first.
   *
   * @param page 1-based page number; invalid values fall back to 1
   * @param limit page size; invalid values fall back to 10
   */
  async getUserTasks(userId: string, platform: string, page: number = 1, limit: number = 10) {
    // Guard against 0/negative/NaN input, which would otherwise produce a
    // negative `skip` or a division by zero in totalPages.
    const safePage = GenerationService.toPositiveInt(page, 1);
    const safeLimit = GenerationService.toPositiveInt(limit, 10);

    const [tasks, total] = await this.taskRepository.findAndCount({
      where: { userId, platform },
      order: { createdAt: 'DESC' },
      skip: (safePage - 1) * safeLimit,
      take: safeLimit,
    });

    return {
      tasks,
      total,
      page: safePage,
      limit: safeLimit,
      totalPages: Math.ceil(total / safeLimit),
    };
  }

  /**
   * Loads a task by id, scoped to its owner.
   *
   * @throws Error when no task with this id belongs to the user
   */
  async getTaskById(taskId: string, userId: string): Promise<GenerationTask> {
    const task = await this.taskRepository.findOne({
      where: { id: taskId, userId },
    });
    if (!task) {
      throw new Error('任务不存在或无权限访问');
    }
    return task;
  }

  private async updateTaskStatus(taskId: string, status: TaskStatus): Promise<void> {
    await this.taskRepository.update(taskId, { status });
  }

  /** Coerces a value to an integer >= 1, or returns the fallback. */
  private static toPositiveInt(value: number, fallback: number): number {
    const n = Math.floor(Number(value));
    return Number.isFinite(n) && n >= 1 ? n : fallback;
  }
}
```
### 2.2 AI模型适配器服务
```typescript
// src/services/ai-model.service.ts
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import axios from 'axios';
/** Input accepted by `AIModelService.generate`. */
export interface GenerationRequest {
// Selects the generator: 'image' goes to the image path, anything else to the video path.
type: 'image' | 'video';
// Text prompt forwarded as `prompt` in the upstream payload.
prompt: string;
// Optional source image URL: downloaded and sent as `init_images` for images,
// forwarded as `image_url` for videos.
inputImageUrl?: string;
// Optional tuning values read by the generators (e.g. width, height, steps,
// cfg_scale, sampler, denoising_strength, duration, fps, resolution).
parameters?: any;
}
/** Output produced by `AIModelService.generate`. */
export interface GenerationResult {
// Raw bytes of the generated image or video.
content: Buffer;
// MIME type set by the generator ('image/png' or 'video/mp4' in this file).
contentType: string;
// Generator-specific details (seed and payload for images; task id, duration and fps for videos).
metadata?: any;
}
/**
 * Adapter around the external image and video generation HTTP APIs.
 * Endpoints and keys are read from configuration at call time.
 */
@Injectable()
export class AIModelService {
  /** Timeout for a single image generation request (2 minutes). */
  private static readonly IMAGE_TIMEOUT_MS = 120000;
  /** Timeout for each individual video create/poll/download request. */
  private static readonly REQUEST_TIMEOUT_MS = 30000;
  /** Delay between video status polls. */
  private static readonly POLL_INTERVAL_MS = 10000;
  /** 60 polls x 10 s = at most ~10 minutes of waiting. */
  private static readonly MAX_POLL_ATTEMPTS = 60;

  constructor(private readonly configService: ConfigService) {}

  /**
   * Dispatches to the image or video generator based on `request.type`.
   *
   * @returns the generated bytes with their MIME type and generator metadata
   */
  async generate(request: GenerationRequest): Promise<GenerationResult> {
    return request.type === 'image' ? this.generateImage(request) : this.generateVideo(request);
  }

  /**
   * Generates one image. Uses the text-to-image route by default and the
   * image-to-image route when `inputImageUrl` is provided.
   *
   * @throws Error when the API response contains no image
   */
  private async generateImage(request: GenerationRequest): Promise<GenerationResult> {
    const apiUrl = this.configService.get('STABLE_DIFFUSION_API_URL');
    const apiKey = this.configService.get('STABLE_DIFFUSION_API_KEY');
    const params = request.parameters;

    // Typed as an open record because img2img adds extra keys below.
    const payload: Record<string, unknown> = {
      prompt: request.prompt,
      negative_prompt: "low quality, blurry, distorted",
      // `||` is kept where 0 / '' are not usable values and should fall back.
      width: params?.width || 512,
      height: params?.height || 512,
      steps: params?.steps || 20,
      // `??` so an explicit 0 from the caller is respected.
      cfg_scale: params?.cfg_scale ?? 7,
      sampler_name: params?.sampler || "DPM++ 2M Karras",
    };

    let route = 'txt2img';
    if (request.inputImageUrl) {
      // Image-to-image: init images are only honoured by the img2img route,
      // so the request must not be posted to txt2img.
      const inputImageBuffer = await this.downloadImage(request.inputImageUrl);
      payload.init_images = [inputImageBuffer.toString('base64')];
      payload.denoising_strength = params?.denoising_strength ?? 0.7;
      route = 'img2img';
    }

    const response = await axios.post(`${apiUrl}/sdapi/v1/${route}`, payload, {
      headers: {
        'Authorization': `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
      timeout: AIModelService.IMAGE_TIMEOUT_MS,
    });

    const imageBase64: unknown = response.data?.images?.[0];
    if (typeof imageBase64 !== 'string' || imageBase64.length === 0) {
      throw new Error('图片生成失败: 接口未返回图片');
    }

    return {
      content: Buffer.from(imageBase64, 'base64'),
      contentType: 'image/png',
      metadata: {
        seed: AIModelService.extractSeed(response.data?.info),
        parameters: payload,
      },
    };
  }

  /**
   * Starts a video generation job and polls it until it completes, fails or
   * the polling budget is exhausted.
   *
   * @throws Error when the job fails, returns no task id, or times out
   */
  private async generateVideo(request: GenerationRequest): Promise<GenerationResult> {
    const apiUrl = this.configService.get('VIDEO_GENERATION_API_URL');
    const apiKey = this.configService.get('VIDEO_GENERATION_API_KEY');

    const duration = request.parameters?.duration || 5; // seconds
    const fps = request.parameters?.fps || 24;
    const resolution = request.parameters?.resolution || '720p';

    const payload: Record<string, unknown> = { prompt: request.prompt, duration, fps, resolution };
    if (request.inputImageUrl) {
      payload.image_url = request.inputImageUrl;
    }

    // Create the remote job.
    const createResponse = await axios.post(`${apiUrl}/generate`, payload, {
      headers: {
        'Authorization': `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
      timeout: AIModelService.REQUEST_TIMEOUT_MS,
    });
    const taskId = createResponse.data?.task_id;
    if (!taskId) {
      throw new Error('视频生成失败: 接口未返回任务ID');
    }

    // Poll the job. Each request carries its own timeout so one hung
    // connection cannot stall the loop past the overall budget.
    for (let attempt = 0; attempt < AIModelService.MAX_POLL_ATTEMPTS; attempt++) {
      await new Promise(resolve => setTimeout(resolve, AIModelService.POLL_INTERVAL_MS));

      const statusResponse = await axios.get(`${apiUrl}/task/${taskId}`, {
        headers: { 'Authorization': `Bearer ${apiKey}` },
        timeout: AIModelService.REQUEST_TIMEOUT_MS,
      });
      const status = statusResponse.data.status;

      if (status === 'completed') {
        const videoBuffer = await this.downloadVideo(statusResponse.data.result_url);
        return {
          content: videoBuffer,
          contentType: 'video/mp4',
          metadata: { taskId, duration, fps },
        };
      }
      if (status === 'failed') {
        throw new Error(`视频生成失败: ${statusResponse.data.error}`);
      }
    }
    throw new Error('视频生成超时');
  }

  /** Downloads an image as raw bytes. */
  private async downloadImage(url: string): Promise<Buffer> {
    return this.downloadBinary(url);
  }

  /** Downloads a video as raw bytes. */
  private async downloadVideo(url: string): Promise<Buffer> {
    return this.downloadBinary(url);
  }

  /** Shared binary download used by both helpers above. */
  private async downloadBinary(url: string): Promise<Buffer> {
    const response = await axios.get(url, {
      responseType: 'arraybuffer',
      timeout: AIModelService.REQUEST_TIMEOUT_MS,
    });
    return Buffer.from(response.data);
  }

  /**
   * Reads `seed` from the response's `info` field, which may arrive either as
   * an object or as a JSON-encoded string. Returns undefined when absent.
   */
  private static extractSeed(info: unknown): unknown {
    let parsed: unknown = info;
    if (typeof info === 'string') {
      try {
        parsed = JSON.parse(info);
      } catch {
        return undefined;
      }
    }
    if (typeof parsed === 'object' && parsed !== null && 'seed' in parsed) {
      return (parsed as { seed?: unknown }).seed;
    }
    return undefined;
  }
}
```
## 3. 文件存储服务
### 3.1 文件管理服务
```typescript
// src/services/file.service.ts
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as AWS from 'aws-sdk';
import * as sharp from 'sharp';
import { v4 as uuidv4 } from 'uuid';
/**
 * Stores user inputs, generated outputs and thumbnails in S3 and returns
 * their public URLs.
 */
@Injectable()
export class FileService {
  private s3: AWS.S3;
  private bucketName: string;

  constructor(private readonly configService: ConfigService) {
    this.s3 = new AWS.S3({
      accessKeyId: this.configService.get('AWS_ACCESS_KEY_ID'),
      secretAccessKey: this.configService.get('AWS_SECRET_ACCESS_KEY'),
      region: this.configService.get('AWS_REGION'),
    });
    this.bucketName = this.configService.get('AWS_S3_BUCKET');
  }

  /**
   * Re-encodes an input image as JPEG (quality 85, fitted inside 1024x1024
   * without upscaling) and uploads it under `inputs/<userId>/`.
   *
   * @returns the URL of the uploaded object
   */
  async uploadImage(imageBuffer: Buffer, userId: string): Promise<string> {
    const compressedImage = await sharp(imageBuffer)
      .jpeg({ quality: 85 })
      .resize(1024, 1024, { fit: 'inside', withoutEnlargement: true })
      .toBuffer();
    return this.putPublicObject(`inputs/${userId}/${uuidv4()}.jpg`, compressedImage, 'image/jpeg');
  }

  /**
   * Uploads generated content under `outputs/<userId>/`.
   *
   * @param contentType MIME type of `content`; defaults to 'image/png' so
   *        existing two-argument callers behave as before. Pass e.g.
   *        'video/mp4' so videos are not stored as `.png` / `image/png`.
   * @returns the URL of the uploaded object
   */
  async uploadGeneratedContent(
    content: Buffer,
    userId: string,
    contentType: string = 'image/png',
  ): Promise<string> {
    const extension = FileService.extensionFor(contentType);
    return this.putPublicObject(`outputs/${userId}/${uuidv4()}.${extension}`, content, contentType);
  }

  /**
   * Downloads an image, produces a 200x200 cover-cropped JPEG thumbnail and
   * uploads it under `thumbnails/<userId>/`.
   *
   * @throws Error when the source image cannot be downloaded
   */
  async generateThumbnail(imageUrl: string, userId: string): Promise<string> {
    const response = await fetch(imageUrl);
    // Without this check an error page body would be handed to the decoder.
    if (!response.ok) {
      throw new Error(`下载原图失败: HTTP ${response.status}`);
    }
    const imageBuffer = Buffer.from(await response.arrayBuffer());

    const thumbnail = await sharp(imageBuffer)
      .resize(200, 200, { fit: 'cover' })
      .jpeg({ quality: 80 })
      .toBuffer();

    return this.putPublicObject(`thumbnails/${userId}/${uuidv4()}.jpg`, thumbnail, 'image/jpeg');
  }

  /**
   * Uploads one object and returns its location.
   * NOTE(review): objects are written with ACL 'public-read', so user inputs
   * are world-readable by URL — confirm this is intended.
   */
  private async putPublicObject(key: string, body: Buffer, contentType: string): Promise<string> {
    const result = await this.s3
      .upload({
        Bucket: this.bucketName,
        Key: key,
        Body: body,
        ContentType: contentType,
        ACL: 'public-read',
      })
      .promise();
    return result.Location;
  }

  /** Maps a MIME type to a file extension; unknown types get 'bin'. */
  private static extensionFor(contentType: string): string {
    const known: Record<string, string> = {
      'image/png': 'png',
      'image/jpeg': 'jpg',
      'image/webp': 'webp',
      'video/mp4': 'mp4',
    };
    return known[contentType] ?? 'bin';
  }
}
```
## 4. 环境配置
### 4.1 AI服务配置
```env
# Stable Diffusion API(示例代码请求的是 `/sdapi/v1/...` 路径,请确保该 URL 指向兼容此路径的服务)
STABLE_DIFFUSION_API_URL=https://api.stability.ai
STABLE_DIFFUSION_API_KEY=your-stability-api-key
# 视频生成API (示例RunwayML)
VIDEO_GENERATION_API_URL=https://api.runwayml.com
VIDEO_GENERATION_API_KEY=your-runway-api-key
# 文件存储配置
AWS_ACCESS_KEY_ID=your-aws-access-key
AWS_SECRET_ACCESS_KEY=your-aws-secret-key
AWS_REGION=us-east-1
AWS_S3_BUCKET=your-s3-bucket-name
# 积分配置(注意:上文示例中的 calculateCreditCost 使用硬编码数值,并未读取这些变量)
DEFAULT_IMAGE_CREDIT_COST=10
DEFAULT_VIDEO_CREDIT_COST=50
HIGH_QUALITY_MULTIPLIER=1.5
```
## 5. API接口示例
### 5.1 创建生成任务
```typescript
/**
 * Creates a template-based generation task for the authenticated user.
 * Calls `createGenerationTaskByTemplate`, the creation method the service
 * in this guide actually exposes. `userId` and `platform` are taken from the
 * authenticated user rather than the request body, so a client cannot spend
 * another user's credits.
 */
@Post('generate')
@ApiOperation({ summary: '创建AI生成任务' })
async createTask(@Body() createTaskDto: CreateTemplateTaskDto, @CurrentUser() user: any) {
  return this.generationService.createGenerationTaskByTemplate({
    ...createTaskDto,
    userId: user.id,
    platform: user.platform,
  });
}
```
### 5.2 查询任务状态
```typescript
/**
 * Returns a single generation task. The lookup is scoped to the current
 * user's id, so tasks owned by other users are not returned.
 */
@Get('tasks/:taskId')
@ApiOperation({ summary: '查询生成任务状态' })
async getTask(@Param('taskId') taskId: string, @CurrentUser() user: any) {
  const ownerId = user.id;
  const task = this.generationService.getTaskById(taskId, ownerId);
  return task;
}
```
### 5.3 获取用户任务列表
```typescript
/**
 * Lists the current user's generation tasks, paginated.
 * Query-string values reach the handler as strings unless a transforming
 * pipe is configured, so they are coerced and bounded here before use.
 */
@Get('tasks')
@ApiOperation({ summary: '获取用户生成任务列表' })
async getUserTasks(
  @CurrentUser() user: any,
  @Query('page') page: number = 1,
  @Query('limit') limit: number = 10
) {
  const requestedPage = Number(page);
  const requestedLimit = Number(limit);
  // Fall back to the defaults on missing, non-numeric or non-positive input.
  const safePage = Number.isInteger(requestedPage) && requestedPage >= 1 ? requestedPage : 1;
  // Cap the page size so one request cannot ask for an unbounded result set.
  const safeLimit =
    Number.isInteger(requestedLimit) && requestedLimit >= 1 ? Math.min(requestedLimit, 100) : 10;
  return this.generationService.getUserTasks(user.id, user.platform, safePage, safeLimit);
}
```
这个AI生成服务提供了完整的图片/视频生成功能,包括任务管理、积分校验、文件存储和异步处理等核心功能。