AI图片/视频生成服务配置指南
1. 服务架构设计
1.1 AI生成服务流程
用户请求 → 积分校验 → 任务创建 → 积分扣除 → 队列处理 → AI模型调用 → 结果存储 → 用户通知(生成失败时退还积分)
1.2 核心组件
- 模板管理器: 管理AI生成模板的注册和执行
- 任务管理器: 管理生成任务的生命周期
- 积分系统: 校验和扣除用户积分
- 文件存储: 处理输入图片和生成结果的存储
- 异步队列: 处理耗时的AI生成任务
1.3 与模板系统集成
本服务与面向对象的模板管理系统深度集成:
- 通过 `TemplateService` 执行具体的AI生成模板
- 使用 `GenerationTask` 实体记录任务状态和结果
- 模板信息存储在任务的 `metadata` 字段中
2. AI生成服务实现
2.1 生成任务服务 (与模板系统集成版本)
// src/services/generation.service.ts
import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { GenerationTask, GenerationType, TaskStatus } from '../entities/generation-task.entity';
import { AIModelService } from './ai-model.service';
import { CreditService } from './credit.service';
import { FileService } from './file.service';
import { MessageProducerService } from './message-producer.service';
import { TemplateService } from './template.service';
/**
 * Orchestrates AI generation tasks: template lookup, credit charging,
 * queueing, model invocation, result storage and failure compensation.
 */
@Injectable()
export class GenerationService {
  constructor(
    @InjectRepository(GenerationTask)
    private readonly taskRepository: Repository<GenerationTask>,
    private readonly creditService: CreditService,
    private readonly templateService: TemplateService,
    private readonly fileService: FileService,
    private readonly messageProducer: MessageProducerService,
    // processGenerationTask calls this.aiModelService.generate(), so it must be injected.
    // NOTE(review): AIModelService must be registered as a provider in the owning module.
    private readonly aiModelService: AIModelService,
  ) {}

  /**
   * Creates a generation task from a registered template, charges the user and enqueues it.
   *
   * Steps: resolve template -> check balance -> upload optional input image ->
   * persist a PENDING task -> consume credits -> enqueue for async processing.
   * If charging fails the task is marked FAILED; if enqueueing fails the credits are
   * refunded and the task is marked FAILED. In both cases the error is rethrown, so a
   * failed request never leaves an orphaned PENDING task behind.
   *
   * @param createTaskDto - user, platform, template code and template input parameters.
   * @returns The persisted task (status PENDING).
   * @throws Error when the template does not exist, the balance is insufficient,
   *   or any downstream step fails.
   */
  async createGenerationTaskByTemplate(createTaskDto: CreateTemplateTaskDto): Promise<GenerationTask> {
    const { userId, platform, templateCode } = createTaskDto;
    // Shallow copy so replacing inputImage with its URL does not mutate the caller's DTO.
    const inputParameters = { ...createTaskDto.inputParameters };

    const templateInfo = this.templateService.getTemplateInfo(templateCode);
    if (!templateInfo) {
      throw new Error(`模板 ${templateCode} 不存在`);
    }

    // NOTE(review): checkBalance and consumeCredits are separate calls, so concurrent
    // requests can both pass this check; whether consumeCredits rejects an overdraft
    // is not visible here - TODO confirm.
    const hasEnoughCredits = await this.creditService.checkBalance(userId, platform, templateInfo.creditCost);
    if (!hasEnoughCredits) {
      throw new Error('积分不足');
    }

    let inputImageUrl: string | null = null;
    if (inputParameters.inputImage) {
      inputImageUrl = await this.fileService.uploadImage(inputParameters.inputImage, userId);
      inputParameters.inputImage = inputImageUrl; // downstream consumers receive the URL, not raw data
    }

    const task = this.taskRepository.create({
      userId,
      platform,
      type: templateInfo.category.includes('视频') ? GenerationType.VIDEO : GenerationType.IMAGE,
      prompt: inputParameters.prompt || inputParameters.clothingDescription || '',
      inputImageUrl,
      creditCost: templateInfo.creditCost,
      parameters: inputParameters,
      status: TaskStatus.PENDING,
      metadata: {
        templateCode: templateInfo.code,
        templateName: templateInfo.name,
        templateVersion: templateInfo.version,
        templateCategory: templateInfo.category,
      },
    });
    const savedTask = await this.taskRepository.save(task);

    // The task id is the reference for the credit transaction, so charge after saving.
    try {
      await this.creditService.consumeCredits(userId, platform, templateInfo.creditCost, 'ai_generation', savedTask.id);
    } catch (error) {
      await this.markTaskFailed(savedTask.id, error);
      throw error;
    }

    try {
      await this.messageProducer.sendGenerationTask({
        taskId: savedTask.id,
        userId,
        platform,
        type: savedTask.type,
        templateCode,
        inputParameters,
      });
    } catch (error) {
      // The task will never be processed, so give the credits back.
      await this.refundAndFail(savedTask, error);
      throw error;
    }

    return savedTask;
  }

  /**
   * Runs one queued task: calls the AI model, stores the output and notifies the user.
   *
   * Tasks already COMPLETED or FAILED are skipped, so a redelivered queue message neither
   * regenerates nor double-refunds. On a generation/storage failure the credits are refunded
   * and the task is marked FAILED (the status update runs even if the refund throws).
   * A failure of the completion notification propagates, but it no longer refunds or
   * downgrades a task whose output was stored successfully.
   *
   * @throws Error when the task does not exist, or when the notification/refund fails.
   */
  async processGenerationTask(taskId: string): Promise<void> {
    const task = await this.taskRepository.findOne({ where: { id: taskId } });
    if (!task) {
      throw new Error('任务不存在');
    }
    if (task.status === TaskStatus.COMPLETED || task.status === TaskStatus.FAILED) {
      return;
    }

    let outputUrl: string;
    let thumbnailUrl: string | undefined;
    try {
      await this.updateTaskStatus(taskId, TaskStatus.PROCESSING);
      const startTime = Date.now();

      const result = await this.aiModelService.generate({
        type: task.type,
        prompt: task.prompt,
        inputImageUrl: task.inputImageUrl,
        parameters: task.parameters,
      });
      const processingTime = Math.floor((Date.now() - startTime) / 1000); // seconds

      // TODO(review): uploadGeneratedContent is called without result.contentType;
      // confirm it stores non-PNG output (e.g. video/mp4) with the right type/extension.
      outputUrl = await this.fileService.uploadGeneratedContent(result.content, task.userId);
      // generateThumbnail decodes the stored file as an image, so only call it for image output.
      if (result.contentType.startsWith('image/')) {
        thumbnailUrl = await this.fileService.generateThumbnail(outputUrl, task.userId);
      }

      await this.taskRepository.update(taskId, {
        status: TaskStatus.COMPLETED,
        outputUrl,
        thumbnailUrl,
        processingTime,
      });
    } catch (error) {
      await this.refundAndFail(task, error);
      return;
    }

    await this.messageProducer.sendGenerationCompleted({
      taskId,
      userId: task.userId,
      platform: task.platform,
      outputUrl,
      thumbnailUrl,
    });
  }

  /**
   * Computes a credit cost from the generation type and quality/resolution parameters:
   * image 10 (high quality 15), video 50 (high quality 75), then x1.5 rounded up for '4k'.
   *
   * Must stay in sync with credit-and-ad-system.md.
   * NOTE(review): not called anywhere in this class; createGenerationTaskByTemplate charges
   * templateInfo.creditCost instead.
   */
  private calculateCreditCost(type: GenerationType, parameters: any): number {
    let baseCost = type === GenerationType.IMAGE ? 10 : 50;
    if (type === GenerationType.IMAGE) {
      if (parameters?.quality === 'high') {
        baseCost = 15;
      }
    } else if (parameters?.quality === 'high') {
      baseCost = 75;
    }
    if (parameters?.resolution === '4k') {
      baseCost = Math.ceil(baseCost * 1.5);
    }
    return baseCost;
  }

  /**
   * Returns one page of a user's tasks on a platform, newest first.
   *
   * `page` and `limit` are coerced to positive integers (falling back to 1 and 10),
   * because values forwarded from query strings may be strings or invalid.
   */
  async getUserTasks(userId: string, platform: string, page: number = 1, limit: number = 10) {
    const safePage = GenerationService.toPositiveInt(page, 1);
    const safeLimit = GenerationService.toPositiveInt(limit, 10);

    const [tasks, total] = await this.taskRepository.findAndCount({
      where: { userId, platform },
      order: { createdAt: 'DESC' },
      skip: (safePage - 1) * safeLimit,
      take: safeLimit,
    });

    return {
      tasks,
      total,
      page: safePage,
      limit: safeLimit,
      totalPages: Math.ceil(total / safeLimit),
    };
  }

  /**
   * Loads a task that belongs to the given user.
   *
   * @throws Error when no task with this id exists for this user.
   */
  async getTaskById(taskId: string, userId: string): Promise<GenerationTask> {
    const task = await this.taskRepository.findOne({
      where: { id: taskId, userId },
    });
    if (!task) {
      throw new Error('任务不存在或无权限访问');
    }
    return task;
  }

  /** Persists a new status for a task. */
  private async updateTaskStatus(taskId: string, status: TaskStatus): Promise<void> {
    await this.taskRepository.update(taskId, { status });
  }

  /** Marks a task FAILED, storing a readable message for any thrown value. */
  private async markTaskFailed(taskId: string, error: unknown): Promise<void> {
    await this.taskRepository.update(taskId, {
      status: TaskStatus.FAILED,
      errorMessage: error instanceof Error ? error.message : String(error),
    });
  }

  /** Refunds a task's credits, then marks it FAILED; the status update runs even if the refund throws. */
  private async refundAndFail(task: GenerationTask, error: unknown): Promise<void> {
    try {
      await this.creditService.refundCredits(
        task.userId,
        task.platform,
        task.creditCost,
        'ai_generation_failed',
        task.id,
      );
    } finally {
      await this.markTaskFailed(task.id, error);
    }
  }

  /** Parses a value as a positive integer, returning `fallback` for NaN, non-finite or < 1. */
  private static toPositiveInt(value: unknown, fallback: number): number {
    const parsed = Math.floor(Number(value));
    return Number.isFinite(parsed) && parsed >= 1 ? parsed : fallback;
  }
}
2.2 AI模型适配器服务
// src/services/ai-model.service.ts
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import axios from 'axios';
/**
 * Input accepted by AIModelService.generate().
 */
export interface GenerationRequest {
  /** 'image' is routed to Stable Diffusion; 'video' to the video generation API. */
  type: 'image' | 'video';
  /** Text prompt forwarded to the model. */
  prompt: string;
  /** Backend-specific options such as width, steps, duration or fps. */
  parameters?: any;
  /** Optional source media URL for image/video-from-image generation. */
  inputImageUrl?: string;
}
/**
 * Output of AIModelService.generate(): raw bytes plus their MIME type.
 */
export interface GenerationResult {
  /** MIME type of `content`, e.g. 'image/png' or 'video/mp4'. */
  contentType: string;
  /** Raw generated bytes. */
  content: Buffer;
  /** Backend-specific extras (seed, job id, request parameters, ...). */
  metadata?: any;
}
/**
 * Adapter over the external image (Stable Diffusion) and video generation HTTP APIs.
 */
@Injectable()
export class AIModelService {
  /** Budget for one image generation call (2 minutes). */
  private static readonly IMAGE_TIMEOUT_MS = 120000;
  /** Budget for short control calls: video job creation and status polling. */
  private static readonly CONTROL_TIMEOUT_MS = 30000;
  /** Budget for downloading input images and generated videos. */
  private static readonly DOWNLOAD_TIMEOUT_MS = 120000;
  /** Video polling: 60 polls x 10 s, i.e. up to ~10 minutes. */
  private static readonly VIDEO_POLL_INTERVAL_MS = 10000;
  private static readonly VIDEO_MAX_POLLS = 60;

  constructor(private readonly configService: ConfigService) {}

  /**
   * Generates an image or a video.
   *
   * @param request - `type === 'image'` goes to Stable Diffusion; any other value goes to the video API.
   * @returns Generated bytes with their MIME type and backend metadata.
   * @throws Error (or an axios error) when a backend call fails, returns no content, or the video job times out.
   */
  async generate(request: GenerationRequest): Promise<GenerationResult> {
    if (request.type === 'image') {
      return this.generateImage(request);
    }
    return this.generateVideo(request);
  }

  /**
   * Calls a Stable Diffusion server exposing the `/sdapi/v1` API.
   * Text-only requests use `txt2img`. Requests with `inputImageUrl` use `img2img`,
   * because `txt2img` does not take `init_images` and would silently ignore the input image.
   */
  private async generateImage(request: GenerationRequest): Promise<GenerationResult> {
    const apiUrl = this.configService.get('STABLE_DIFFUSION_API_URL');
    const apiKey = this.configService.get('STABLE_DIFFUSION_API_KEY');
    const options = request.parameters;

    const payload: Record<string, unknown> = {
      prompt: request.prompt,
      negative_prompt: "low quality, blurry, distorted",
      width: options?.width || 512,
      height: options?.height || 512,
      steps: options?.steps || 20,
      cfg_scale: options?.cfg_scale || 7,
      sampler_name: options?.sampler || "DPM++ 2M Karras",
    };

    let endpoint = 'txt2img';
    if (request.inputImageUrl) {
      const inputImage = await this.downloadBinary(request.inputImageUrl);
      payload.init_images = [inputImage.toString('base64')];
      // `??` keeps an explicit 0, which `||` would have replaced with 0.7.
      payload.denoising_strength = options?.denoising_strength ?? 0.7;
      endpoint = 'img2img';
    }

    const response = await axios.post(`${apiUrl}/sdapi/v1/${endpoint}`, payload, {
      headers: {
        'Authorization': `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
      timeout: AIModelService.IMAGE_TIMEOUT_MS,
    });

    const imageBase64 = response.data?.images?.[0];
    if (typeof imageBase64 !== 'string' || imageBase64.length === 0) {
      throw new Error('图片生成失败: 响应中没有图片数据');
    }

    return {
      content: Buffer.from(imageBase64, 'base64'),
      contentType: 'image/png',
      metadata: {
        seed: AIModelService.extractSeed(response.data?.info),
        parameters: payload,
      },
    };
  }

  /**
   * Reads `seed` from the response's `info` field. The `/sdapi/v1` (AUTOMATIC1111) API returns
   * `info` as a JSON-encoded string, so reading `.seed` on it directly always yields undefined.
   * An already-parsed object is accepted too. Returns undefined when no seed can be read.
   */
  private static extractSeed(info: unknown): unknown {
    let parsed: unknown = info;
    if (typeof info === 'string') {
      try {
        parsed = JSON.parse(info);
      } catch {
        return undefined;
      }
    }
    if (typeof parsed === 'object' && parsed !== null && 'seed' in parsed) {
      return (parsed as { seed: unknown }).seed;
    }
    return undefined;
  }

  /**
   * Creates a job on the video API, polls it until it completes or fails, then downloads the MP4.
   *
   * @throws Error when no job id is returned, the job reports 'failed', or polling exceeds ~10 minutes.
   */
  private async generateVideo(request: GenerationRequest): Promise<GenerationResult> {
    const apiUrl = this.configService.get('VIDEO_GENERATION_API_URL');
    const apiKey = this.configService.get('VIDEO_GENERATION_API_KEY');
    const authHeader = { 'Authorization': `Bearer ${apiKey}` };

    const duration = request.parameters?.duration || 5; // seconds
    const fps = request.parameters?.fps || 24;
    const payload: Record<string, unknown> = {
      prompt: request.prompt,
      duration,
      fps,
      resolution: request.parameters?.resolution || '720p',
    };
    if (request.inputImageUrl) {
      payload.image_url = request.inputImageUrl;
    }

    const createResponse = await axios.post(`${apiUrl}/generate`, payload, {
      headers: { ...authHeader, 'Content-Type': 'application/json' },
      timeout: AIModelService.CONTROL_TIMEOUT_MS,
    });
    const taskId = createResponse.data?.task_id;
    if (!taskId) {
      throw new Error('视频生成失败: 未返回任务ID');
    }

    for (let poll = 0; poll < AIModelService.VIDEO_MAX_POLLS; poll++) {
      await new Promise(resolve => setTimeout(resolve, AIModelService.VIDEO_POLL_INTERVAL_MS));

      const statusResponse = await axios.get(`${apiUrl}/task/${encodeURIComponent(String(taskId))}`, {
        headers: authHeader,
        timeout: AIModelService.CONTROL_TIMEOUT_MS,
      });
      const status = statusResponse.data?.status;

      if (status === 'completed') {
        const videoBuffer = await this.downloadBinary(statusResponse.data.result_url);
        return {
          content: videoBuffer,
          contentType: 'video/mp4',
          metadata: { taskId, duration, fps },
        };
      }
      if (status === 'failed') {
        throw new Error(`视频生成失败: ${statusResponse.data.error}`);
      }
    }

    throw new Error('视频生成超时');
  }

  /** Downloads a URL into a Buffer; used for input images and generated videos. */
  private async downloadBinary(url: string): Promise<Buffer> {
    const response = await axios.get(url, {
      responseType: 'arraybuffer',
      timeout: AIModelService.DOWNLOAD_TIMEOUT_MS,
    });
    return Buffer.from(response.data);
  }
}
3. 文件存储服务
3.1 文件管理服务
// src/services/file.service.ts
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as AWS from 'aws-sdk';
import * as sharp from 'sharp';
import { v4 as uuidv4 } from 'uuid';
/**
 * Stores user inputs, generated outputs and thumbnails in S3 and returns their public URLs.
 */
@Injectable()
export class FileService {
  private readonly s3: AWS.S3;
  private readonly bucketName: string;

  /** File extensions for the MIME types the generators are known to produce. */
  private static readonly EXTENSIONS: Readonly<Record<string, string>> = {
    'image/png': 'png',
    'image/jpeg': 'jpg',
    'image/webp': 'webp',
    'video/mp4': 'mp4',
  };

  constructor(private readonly configService: ConfigService) {
    this.s3 = new AWS.S3({
      accessKeyId: this.configService.get('AWS_ACCESS_KEY_ID'),
      secretAccessKey: this.configService.get('AWS_SECRET_ACCESS_KEY'),
      region: this.configService.get('AWS_REGION'),
    });
    this.bucketName = this.configService.get('AWS_S3_BUCKET');
  }

  /**
   * Normalizes a user-supplied image (upright orientation, max 1024x1024, JPEG q85)
   * and stores it under `inputs/<userId>/`.
   *
   * @returns The public URL of the stored JPEG.
   */
  async uploadImage(imageBuffer: Buffer, userId: string): Promise<string> {
    const normalized = await sharp(imageBuffer)
      // sharp drops EXIF metadata on output, so apply the EXIF orientation to the pixels first;
      // otherwise portrait phone photos would be stored sideways.
      .rotate()
      .resize(1024, 1024, { fit: 'inside', withoutEnlargement: true })
      .jpeg({ quality: 85 })
      .toBuffer();
    return this.putPublicObject(`inputs/${userId}/${uuidv4()}.jpg`, normalized, 'image/jpeg');
  }

  /**
   * Stores generated content under `outputs/<userId>/`.
   *
   * @param contentType - MIME type of `content`; defaults to 'image/png' for existing callers.
   *   The file extension is derived from it ('bin' for unknown types), so videos are no
   *   longer stored as `.png` with an image content type.
   * @returns The public URL of the stored object.
   */
  async uploadGeneratedContent(content: Buffer, userId: string, contentType: string = 'image/png'): Promise<string> {
    const extension = FileService.EXTENSIONS[contentType] ?? 'bin';
    return this.putPublicObject(`outputs/${userId}/${uuidv4()}.${extension}`, content, contentType);
  }

  /**
   * Downloads an image by URL and stores a 200x200 cover-cropped JPEG thumbnail under `thumbnails/<userId>/`.
   *
   * @throws Error when the download responds with a non-2xx status, or sharp cannot decode the content
   *   (e.g. a video).
   */
  async generateThumbnail(imageUrl: string, userId: string): Promise<string> {
    const response = await fetch(imageUrl);
    // Without this check an error page (403/404 body) would be handed to sharp and fail obscurely.
    if (!response.ok) {
      throw new Error(`缩略图生成失败: 下载原图返回 HTTP ${response.status}`);
    }
    const source = Buffer.from(await response.arrayBuffer());

    const thumbnail = await sharp(source)
      .resize(200, 200, { fit: 'cover' })
      .jpeg({ quality: 80 })
      .toBuffer();
    return this.putPublicObject(`thumbnails/${userId}/${uuidv4()}.jpg`, thumbnail, 'image/jpeg');
  }

  /**
   * Uploads a buffer with a public-read ACL and returns the S3 `Location` URL.
   * NOTE(review): buckets with Object Ownership set to "bucket owner enforced" reject ACLs;
   * confirm the bucket configuration or switch to a bucket policy / CDN.
   */
  private async putPublicObject(key: string, body: Buffer, contentType: string): Promise<string> {
    const result = await this.s3
      .upload({
        Bucket: this.bucketName,
        Key: key,
        Body: body,
        ContentType: contentType,
        ACL: 'public-read',
      })
      .promise();
    return result.Location;
  }
}
4. 环境配置
4.1 AI服务配置
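# Stable Diffusion API
# 注意: 代码调用的是 /sdapi/v1/* 路径(AUTOMATIC1111 WebUI API 格式),该地址须指向兼容此接口的服务;
# Stability AI 官方 API 的路径与之不同,不能直接替换使用
STABLE_DIFFUSION_API_URL=https://api.stability.ai
STABLE_DIFFUSION_API_KEY=your-stability-api-key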
# 视频生成API (示例:RunwayML)
VIDEO_GENERATION_API_URL=https://api.runwayml.com
VIDEO_GENERATION_API_KEY=your-runway-api-key
# 文件存储配置
AWS_ACCESS_KEY_ID=your-aws-access-key
AWS_SECRET_ACCESS_KEY=your-aws-secret-key
AWS_REGION=us-east-1
AWS_S3_BUCKET=your-s3-bucket-name
# 积分配置
# 注: 文中示例代码并未读取以下变量,实际扣除的积分取自模板的 creditCost
DEFAULT_IMAGE_CREDIT_COST=10
DEFAULT_VIDEO_CREDIT_COST=50
HIGH_QUALITY_MULTIPLIER=1.5
5. API接口示例
5.1 创建生成任务
/**
 * Creates a template-based AI generation task for the authenticated user.
 *
 * userId and platform are always taken from the authenticated user, never from the
 * request body, so a client cannot spend another user's credits.
 * NOTE(review): if CreateTemplateTaskDto validates userId/platform as required body
 * fields, relax that validation - TODO confirm.
 */
@Post('generate')
@ApiOperation({ summary: '创建AI生成任务' })
async createTask(@Body() createTaskDto: CreateTemplateTaskDto, @CurrentUser() user: any) {
  return this.generationService.createGenerationTaskByTemplate({
    ...createTaskDto,
    userId: user.id,
    platform: user.platform,
  });
}
5.2 查询任务状态
/**
 * Returns one generation task by id. The lookup is scoped to the current user,
 * so a task owned by someone else is reported as not found.
 */
@Get('tasks/:taskId')
@ApiOperation({ summary: '查询生成任务状态' })
async getTask(@Param('taskId') taskId: string, @CurrentUser() user: any) {
  const ownerId = user.id;
  return this.generationService.getTaskById(taskId, ownerId);
}
5.3 获取用户任务列表
/**
 * Lists the current user's generation tasks on their platform, newest first.
 *
 * Unless a global ValidationPipe with `transform: true` is configured, query values arrive
 * as strings (or junk) despite the `number` annotations. They are coerced here to positive
 * integers; `limit` is capped so one request cannot pull an unbounded page.
 */
@Get('tasks')
@ApiOperation({ summary: '获取用户生成任务列表' })
async getUserTasks(
  @CurrentUser() user: any,
  @Query('page') page: number = 1,
  @Query('limit') limit: number = 10
) {
  const MAX_PAGE_SIZE = 100;
  const parsedPage = Math.floor(Number(page));
  const parsedLimit = Math.floor(Number(limit));
  const safePage = Number.isFinite(parsedPage) && parsedPage >= 1 ? parsedPage : 1;
  const safeLimit = Number.isFinite(parsedLimit) && parsedLimit >= 1 ? Math.min(parsedLimit, MAX_PAGE_SIZE) : 10;
  return this.generationService.getUserTasks(user.id, user.platform, safePage, safeLimit);
}
这个AI生成服务提供了完整的图片/视频生成功能,包括任务管理、积分校验、文件存储和异步处理等核心功能。