bw-mini-app-server/docs/ai-generation-service.md

15 KiB
Raw Blame History

AI图片/视频生成服务配置指南

1. 服务架构设计

1.1 AI生成服务流程

用户请求 → 积分校验 → 输入图片上传 → 任务创建 → 积分扣除 → 发送到队列 → 队列处理(AI模型调用 → 结果存储 → 用户通知);处理失败时退还积分

1.2 核心组件

  • 模板管理器: 管理AI生成模板的注册和执行
  • 任务管理器: 管理生成任务的生命周期
  • 积分系统: 校验和扣除用户积分
  • 文件存储: 处理输入图片和生成结果的存储
  • 异步队列: 处理耗时的AI生成任务

1.3 与模板系统集成

本服务与面向对象的模板管理系统深度集成:

  • 通过 TemplateService 查询模板信息(积分消耗、分类、名称、版本),并在队列消息中携带 templateCode
  • 使用 GenerationTask 实体记录任务状态和结果
  • 模板信息存储在任务的 metadata 字段中

2. AI生成服务实现

2.1 生成任务服务 (与模板系统集成版本)

// src/services/generation.service.ts
import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { GenerationTask, GenerationType, TaskStatus } from '../entities/generation-task.entity';
import { AIModelService } from './ai-model.service';
import { CreditService } from './credit.service';
import { FileService } from './file.service';
import { MessageProducerService } from './message-producer.service';
import { TemplateService } from './template.service';

/**
 * Orchestrates AI generation tasks: validates the template and the user's
 * credits, persists GenerationTask rows, deducts/refunds credits, and hands
 * work to the message queue for asynchronous processing.
 */
@Injectable()
export class GenerationService {
  constructor(
    @InjectRepository(GenerationTask)
    private readonly taskRepository: Repository<GenerationTask>,
    private readonly creditService: CreditService,
    private readonly templateService: TemplateService,
    private readonly fileService: FileService,
    private readonly messageProducer: MessageProducerService,
    // Performs the actual model call in processGenerationTask.
    private readonly aiModelService: AIModelService,
  ) {}

  /**
   * Creates a generation task from a template code and enqueues it.
   *
   * @param createTaskDto userId, platform, templateCode and model input parameters.
   * @returns the persisted task (status PENDING).
   * @throws NotFoundException when the template does not exist.
   * @throws BadRequestException when the user's balance is insufficient.
   * Errors from credit deduction or enqueueing are rethrown after the task is
   * marked FAILED (and, for enqueue failures, credits are refunded).
   */
  async createGenerationTaskByTemplate(createTaskDto: CreateTemplateTaskDto): Promise<GenerationTask> {
    const { userId, platform, templateCode } = createTaskDto;
    // Shallow copy so the caller's DTO is not mutated when the input image is
    // swapped for its uploaded URL below.
    const inputParameters = { ...createTaskDto.inputParameters };

    // 1. Resolve the template.
    const templateInfo = this.templateService.getTemplateInfo(templateCode);
    if (!templateInfo) {
      throw new NotFoundException(`模板 ${templateCode} 不存在`);
    }

    // 2. Pre-check the balance so we fail fast before uploading anything.
    const hasEnoughCredits = await this.creditService.checkBalance(userId, platform, templateInfo.creditCost);
    if (!hasEnoughCredits) {
      throw new BadRequestException('积分不足');
    }

    // 3. Upload the input image (if any) and replace it with its URL.
    let inputImageUrl: string | null = null;
    if (inputParameters.inputImage) {
      inputImageUrl = await this.fileService.uploadImage(inputParameters.inputImage, userId);
      inputParameters.inputImage = inputImageUrl;
    }

    // 4. Persist the task. A template category containing "视频" (video)
    //    selects video generation; everything else is an image task.
    const task = this.taskRepository.create({
      userId,
      platform,
      type: templateInfo.category.includes('视频') ? GenerationType.VIDEO : GenerationType.IMAGE,
      prompt: inputParameters.prompt || inputParameters.clothingDescription || '',
      inputImageUrl,
      creditCost: templateInfo.creditCost,
      parameters: inputParameters,
      status: TaskStatus.PENDING,
      metadata: {
        templateCode: templateInfo.code,
        templateName: templateInfo.name,
        templateVersion: templateInfo.version,
        templateCategory: templateInfo.category,
      },
    });

    const savedTask = await this.taskRepository.save(task);

    // 5. Deduct credits. The balance check in step 2 is not atomic with this
    //    call, so a concurrent request may have spent the credits in between
    //    (presumably consumeCredits rejects then — confirm in CreditService).
    //    Mark the task FAILED rather than leaving an orphaned PENDING row.
    try {
      await this.creditService.consumeCredits(userId, platform, templateInfo.creditCost, 'ai_generation', savedTask.id);
    } catch (error) {
      await this.markTaskFailed(savedTask.id, error);
      throw error;
    }

    // 6. Enqueue. If publishing fails nothing will ever process the task, so
    //    refund the credits just deducted and mark it FAILED.
    try {
      await this.messageProducer.sendGenerationTask({
        taskId: savedTask.id,
        userId,
        platform,
        type: savedTask.type,
        templateCode,
        inputParameters,
      });
    } catch (error) {
      await this.failTaskWithRefund(savedTask, error);
      throw error;
    }

    return savedTask;
  }

  /**
   * Queue-consumer entry point: runs the model, stores the output and notifies
   * the user. Generation/storage failures refund credits and mark the task
   * FAILED instead of propagating.
   *
   * @throws NotFoundException when the task id is unknown.
   */
  async processGenerationTask(taskId: string): Promise<void> {
    const task = await this.taskRepository.findOne({ where: { id: taskId } });
    if (!task) {
      throw new NotFoundException('任务不存在');
    }

    // A redelivered queue message must not regenerate content or refund
    // credits a second time for a task that already finished.
    if (task.status === TaskStatus.COMPLETED || task.status === TaskStatus.FAILED) {
      return;
    }

    let outputUrl: string;
    let thumbnailUrl: string | undefined;
    try {
      await this.updateTaskStatus(taskId, TaskStatus.PROCESSING);

      const startTime = Date.now();
      const result = await this.aiModelService.generate({
        type: task.type,
        prompt: task.prompt,
        inputImageUrl: task.inputImageUrl,
        parameters: task.parameters,
      });
      const processingTime = Math.floor((Date.now() - startTime) / 1000);

      outputUrl = await this.fileService.uploadGeneratedContent(result.content, task.userId);
      // FileService.generateThumbnail decodes the file with sharp, which reads
      // still images, not video; attempting it for video output would fail
      // (and refund) every video task.
      thumbnailUrl =
        task.type === GenerationType.IMAGE
          ? await this.fileService.generateThumbnail(outputUrl, task.userId)
          : undefined;

      // undefined columns are skipped by TypeORM's update().
      await this.taskRepository.update(taskId, {
        status: TaskStatus.COMPLETED,
        outputUrl,
        thumbnailUrl,
        processingTime,
      });
    } catch (error) {
      await this.failTaskWithRefund(task, error);
      return;
    }

    // Notify outside the refund path: the content was produced and stored, so
    // a notification failure must not refund credits or flip the task to FAILED.
    await this.messageProducer.sendGenerationCompleted({
      taskId,
      userId: task.userId,
      platform: task.platform,
      outputUrl,
      thumbnailUrl,
    });
  }

  /**
   * Computes a credit cost from generation type and parameters
   * (kept consistent with credit-and-ad-system.md).
   * NOTE(review): not called anywhere in this class — template tasks use
   * templateInfo.creditCost instead. Image: 10, high quality 15; video: 50,
   * high quality 75; 4k resolution multiplies by 1.5 (rounded up).
   */
  private calculateCreditCost(type: GenerationType, parameters: any): number {
    let baseCost = type === GenerationType.IMAGE ? 10 : 50;

    if (type === GenerationType.IMAGE) {
      if (parameters?.quality === 'high') {
        baseCost = 15;
      }
    } else {
      if (parameters?.quality === 'high') {
        baseCost = 75;
      }
    }

    if (parameters?.resolution === '4k') {
      baseCost = Math.ceil(baseCost * 1.5);
    }

    return baseCost;
  }

  /**
   * Lists a user's tasks on one platform, newest first, with pagination.
   * `page`/`limit` are coerced to positive integers because controllers may
   * forward raw query strings; invalid values fall back to 1 / 10.
   */
  async getUserTasks(userId: string, platform: string, page: number = 1, limit: number = 10) {
    const safePage = GenerationService.toPositiveInt(page, 1);
    const safeLimit = GenerationService.toPositiveInt(limit, 10);

    const [tasks, total] = await this.taskRepository.findAndCount({
      where: { userId, platform },
      order: { createdAt: 'DESC' },
      skip: (safePage - 1) * safeLimit,
      take: safeLimit,
    });

    return {
      tasks,
      total,
      page: safePage,
      limit: safeLimit,
      totalPages: Math.ceil(total / safeLimit),
    };
  }

  /**
   * Fetches a task owned by `userId`.
   *
   * @throws NotFoundException when the task does not exist or belongs to another user.
   */
  async getTaskById(taskId: string, userId: string): Promise<GenerationTask> {
    const task = await this.taskRepository.findOne({
      where: { id: taskId, userId },
    });

    if (!task) {
      throw new NotFoundException('任务不存在或无权限访问');
    }

    return task;
  }

  private async updateTaskStatus(taskId: string, status: TaskStatus): Promise<void> {
    await this.taskRepository.update(taskId, { status });
  }

  /**
   * Refunds the task's credits and marks it FAILED. The status update runs in
   * `finally` so a refund error cannot leave the task stuck in PROCESSING or
   * PENDING; the refund error (if any) still propagates.
   */
  private async failTaskWithRefund(task: GenerationTask, error: unknown): Promise<void> {
    try {
      await this.creditService.refundCredits(
        task.userId,
        task.platform,
        task.creditCost,
        'ai_generation_failed',
        task.id,
      );
    } finally {
      await this.markTaskFailed(task.id, error);
    }
  }

  /** Sets status FAILED and records a readable error message. */
  private async markTaskFailed(taskId: string, error: unknown): Promise<void> {
    await this.taskRepository.update(taskId, {
      status: TaskStatus.FAILED,
      errorMessage: error instanceof Error ? error.message : String(error),
    });
  }

  /** Parses `value` as an integer >= 1, returning `fallback` otherwise. */
  private static toPositiveInt(value: unknown, fallback: number): number {
    const parsed = Math.trunc(Number(value));
    return Number.isFinite(parsed) && parsed >= 1 ? parsed : fallback;
  }
}

2.2 AI模型适配器服务

// src/services/ai-model.service.ts
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import axios from 'axios';

/** Input accepted by AIModelService.generate. */
export interface GenerationRequest {
  /** Text prompt forwarded to the model. */
  prompt: string;
  /** Selects the image or the video generator. */
  type: 'image' | 'video';
  /** Optional model-specific knobs (e.g. width/steps for images, duration/fps for video). */
  parameters?: any;
  /** Optional source image; when present the request is conditioned on it. */
  inputImageUrl?: string;
}

/** Raw output of a generation call, ready to be uploaded. */
export interface GenerationResult {
  /** MIME type of `content`, e.g. 'image/png' or 'video/mp4'. */
  contentType: string;
  /** The generated file bytes. */
  content: Buffer;
  /** Provider-specific extras (seed, task id, effective parameters, ...). */
  metadata?: any;
}

/**
 * Adapter over the external image/video generation HTTP APIs.
 * Endpoints and credentials are read from ConfigService on every call.
 */
@Injectable()
export class AIModelService {
  /** Timeout for one synchronous image generation call (2 minutes). */
  private static readonly IMAGE_TIMEOUT_MS = 120000;
  /** Timeout for short control-plane calls (create video task / poll status). */
  private static readonly REQUEST_TIMEOUT_MS = 30000;
  /** Timeout for downloading input images and generated videos. */
  private static readonly DOWNLOAD_TIMEOUT_MS = 120000;
  /** Delay before each video status poll. */
  private static readonly VIDEO_POLL_INTERVAL_MS = 10000;
  /** 60 polls x 10 s: give up after roughly 10 minutes. */
  private static readonly VIDEO_MAX_POLL_ATTEMPTS = 60;

  constructor(private readonly configService: ConfigService) {}

  /** Dispatches to image or video generation based on `request.type`. */
  async generate(request: GenerationRequest): Promise<GenerationResult> {
    if (request.type === 'image') {
      return this.generateImage(request);
    }
    return this.generateVideo(request);
  }

  /**
   * Calls a Stable Diffusion server exposing the AUTOMATIC1111-style
   * `/sdapi/v1` API: img2img when `inputImageUrl` is set, otherwise txt2img.
   *
   * @throws Error when the API response contains no image.
   */
  private async generateImage(request: GenerationRequest): Promise<GenerationResult> {
    const apiUrl = this.configService.get('STABLE_DIFFUSION_API_URL');
    const apiKey = this.configService.get('STABLE_DIFFUSION_API_KEY');
    const params = request.parameters ?? {};

    const payload: {
      prompt: string;
      negative_prompt: string;
      width: number;
      height: number;
      steps: number;
      cfg_scale: number;
      sampler_name: string;
      init_images?: string[];
      denoising_strength?: number;
    } = {
      prompt: request.prompt,
      negative_prompt: 'low quality, blurry, distorted',
      width: params.width || 512,
      height: params.height || 512,
      steps: params.steps || 20,
      cfg_scale: params.cfg_scale || 7,
      sampler_name: params.sampler || 'DPM++ 2M Karras',
    };

    if (request.inputImageUrl) {
      const inputImageBuffer = await this.downloadBuffer(request.inputImageUrl);
      payload.init_images = [inputImageBuffer.toString('base64')];
      // `??` rather than `||` so an explicit strength of 0 is respected.
      payload.denoising_strength = params.denoising_strength ?? 0.7;
    }

    // init_images / denoising_strength are only read by img2img; posting them
    // to txt2img silently ignores the user's input image.
    const endpoint = payload.init_images ? 'img2img' : 'txt2img';
    const response = await axios.post(`${apiUrl}/sdapi/v1/${endpoint}`, payload, {
      headers: {
        Authorization: `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
      timeout: AIModelService.IMAGE_TIMEOUT_MS,
    });

    const data = response.data;
    const imageBase64: unknown = data?.images?.[0];
    if (typeof imageBase64 !== 'string' || imageBase64.length === 0) {
      throw new Error('图片生成失败: 接口未返回图片');
    }

    return {
      content: Buffer.from(imageBase64, 'base64'),
      contentType: 'image/png',
      metadata: {
        seed: AIModelService.extractSeed(data?.info),
        parameters: payload,
      },
    };
  }

  /**
   * Submits a video job, then polls its status until completed, failed, or
   * the poll budget is exhausted; downloads the finished video.
   *
   * @throws Error when the job reports failure, returns no task id, or times out.
   */
  private async generateVideo(request: GenerationRequest): Promise<GenerationResult> {
    const apiUrl = this.configService.get('VIDEO_GENERATION_API_URL');
    const apiKey = this.configService.get('VIDEO_GENERATION_API_KEY');
    const params = request.parameters ?? {};

    const payload: {
      prompt: string;
      duration: number;
      fps: number;
      resolution: string;
      image_url?: string;
    } = {
      prompt: request.prompt,
      duration: params.duration || 5, // 5-second clip by default
      fps: params.fps || 24,
      resolution: params.resolution || '720p',
    };

    if (request.inputImageUrl) {
      payload.image_url = request.inputImageUrl;
    }

    const createResponse = await axios.post(`${apiUrl}/generate`, payload, {
      headers: {
        Authorization: `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
      timeout: AIModelService.REQUEST_TIMEOUT_MS,
    });

    const taskId = createResponse.data?.task_id;
    if (!taskId) {
      // Without an id every poll would hit /task/undefined until the timeout.
      throw new Error('视频生成失败: 接口未返回任务ID');
    }

    for (let attempt = 0; attempt < AIModelService.VIDEO_MAX_POLL_ATTEMPTS; attempt++) {
      await new Promise((resolve) => setTimeout(resolve, AIModelService.VIDEO_POLL_INTERVAL_MS));

      const statusResponse = await axios.get(`${apiUrl}/task/${taskId}`, {
        headers: { Authorization: `Bearer ${apiKey}` },
        timeout: AIModelService.REQUEST_TIMEOUT_MS,
      });
      const status = statusResponse.data.status;

      if (status === 'completed') {
        const videoBuffer = await this.downloadBuffer(statusResponse.data.result_url);
        return {
          content: videoBuffer,
          contentType: 'video/mp4',
          metadata: {
            taskId,
            duration: payload.duration,
            fps: payload.fps,
          },
        };
      }
      if (status === 'failed') {
        throw new Error(`视频生成失败: ${statusResponse.data.error}`);
      }
    }

    throw new Error('视频生成超时');
  }

  /** Downloads a URL into memory as a Buffer. */
  private async downloadBuffer(url: string): Promise<Buffer> {
    const response = await axios.get(url, {
      responseType: 'arraybuffer',
      timeout: AIModelService.DOWNLOAD_TIMEOUT_MS,
    });
    return Buffer.from(response.data);
  }

  /**
   * Reads `seed` from the response `info` field. AUTOMATIC1111 returns `info`
   * as a JSON-encoded string; an already-parsed object is accepted too.
   */
  private static extractSeed(info: unknown): unknown {
    let parsed: unknown = info;
    if (typeof info === 'string') {
      try {
        parsed = JSON.parse(info);
      } catch {
        return undefined;
      }
    }
    if (typeof parsed === 'object' && parsed !== null && 'seed' in parsed) {
      return (parsed as { seed: unknown }).seed;
    }
    return undefined;
  }
}

3. 文件存储服务

3.1 文件管理服务

// src/services/file.service.ts
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import * as AWS from 'aws-sdk';
import * as sharp from 'sharp';
import { v4 as uuidv4 } from 'uuid';

/**
 * Stores user inputs, generated outputs and thumbnails in S3 and returns
 * their public URLs.
 */
@Injectable()
export class FileService {
  private readonly s3: AWS.S3;
  private readonly bucketName: string;

  constructor(private readonly configService: ConfigService) {
    this.s3 = new AWS.S3({
      accessKeyId: this.configService.get('AWS_ACCESS_KEY_ID'),
      secretAccessKey: this.configService.get('AWS_SECRET_ACCESS_KEY'),
      region: this.configService.get('AWS_REGION'),
    });
    this.bucketName = this.configService.get('AWS_S3_BUCKET');
  }

  /**
   * Re-encodes an input image as JPEG (quality 85, fit inside 1024x1024,
   * never enlarged) and uploads it under `inputs/<userId>/`.
   */
  async uploadImage(imageBuffer: Buffer, userId: string): Promise<string> {
    const key = `inputs/${userId}/${uuidv4()}.jpg`;

    // Re-encoding drops EXIF metadata, including the orientation tag, so bake
    // the orientation into the pixels first; otherwise phone photos come out
    // rotated. rotate() must precede resize() because sharp applies them in order.
    const compressedImage = await sharp(imageBuffer)
      .rotate()
      .resize(1024, 1024, { fit: 'inside', withoutEnlargement: true })
      .jpeg({ quality: 85 })
      .toBuffer();

    return this.putPublicObject(key, compressedImage, 'image/jpeg');
  }

  /**
   * Uploads generated content under `outputs/<userId>/`.
   *
   * @param contentType MIME type of `content`; defaults to 'image/png' for
   *   backward compatibility. Pass 'video/mp4' for video so the object gets
   *   the right extension and Content-Type.
   */
  async uploadGeneratedContent(
    content: Buffer,
    userId: string,
    contentType: string = 'image/png',
  ): Promise<string> {
    const key = `outputs/${userId}/${uuidv4()}.${FileService.extensionFor(contentType)}`;
    return this.putPublicObject(key, content, contentType);
  }

  /**
   * Downloads an image, renders a 200x200 cover-cropped JPEG thumbnail and
   * uploads it under `thumbnails/<userId>/`. Expects a still image: sharp
   * cannot decode video files.
   *
   * @throws Error when the source URL responds with a non-2xx status.
   */
  async generateThumbnail(imageUrl: string, userId: string): Promise<string> {
    const response = await fetch(imageUrl);
    if (!response.ok) {
      throw new Error(`下载图片失败: HTTP ${response.status}`);
    }
    const imageBuffer = Buffer.from(await response.arrayBuffer());

    const thumbnail = await sharp(imageBuffer)
      .resize(200, 200, { fit: 'cover' })
      .jpeg({ quality: 80 })
      .toBuffer();

    const key = `thumbnails/${userId}/${uuidv4()}.jpg`;
    return this.putPublicObject(key, thumbnail, 'image/jpeg');
  }

  /** Uploads one object to the configured bucket and returns its URL. */
  private async putPublicObject(key: string, body: Buffer, contentType: string): Promise<string> {
    // NOTE(review): 'public-read' makes every user upload readable by anyone
    // with the URL; consider private objects plus pre-signed URLs.
    const result = await this.s3
      .upload({
        Bucket: this.bucketName,
        Key: key,
        Body: body,
        ContentType: contentType,
        ACL: 'public-read',
      })
      .promise();
    return result.Location;
  }

  /** Maps a MIME type to a file extension; unknown types get 'bin'. */
  private static extensionFor(contentType: string): string {
    const extensions: Record<string, string> = {
      'image/png': 'png',
      'image/jpeg': 'jpg',
      'image/webp': 'webp',
      'video/mp4': 'mp4',
    };
    return extensions[contentType] ?? 'bin';
  }
}

4. 环境配置

4.1 AI服务配置

# Stable Diffusion API
# 注意: 代码调用的是 /sdapi/v1/* 接口(AUTOMATIC1111 WebUI 风格),需指向兼容该接口的服务;下方 URL 仅为占位示例
STABLE_DIFFUSION_API_URL=https://api.stability.ai
STABLE_DIFFUSION_API_KEY=your-stability-api-key

# 视频生成API (示例RunwayML)
VIDEO_GENERATION_API_URL=https://api.runwayml.com
VIDEO_GENERATION_API_KEY=your-runway-api-key

# 文件存储配置
AWS_ACCESS_KEY_ID=your-aws-access-key
AWS_SECRET_ACCESS_KEY=your-aws-secret-key
AWS_REGION=us-east-1
AWS_S3_BUCKET=your-s3-bucket-name

# 积分配置(注意: 上文示例代码并未读取以下变量,模板任务的实际积分取自模板的 creditCost,如需使用请自行接入)
DEFAULT_IMAGE_CREDIT_COST=10
DEFAULT_VIDEO_CREDIT_COST=50
HIGH_QUALITY_MULTIPLIER=1.5

5. API接口示例

5.1 创建生成任务

@Post('generate')
@ApiOperation({ summary: '创建AI生成任务' })
async createTask(@Body() createTaskDto: CreateTemplateTaskDto, @CurrentUser() user: any) {
  // GenerationService exposes createGenerationTaskByTemplate (there is no
  // createGenerationTask). userId/platform come from the authenticated user,
  // never from the request body, so a client cannot create tasks — and spend
  // credits — on behalf of another user.
  return this.generationService.createGenerationTaskByTemplate({
    ...createTaskDto,
    userId: user.id,
    platform: user.platform,
  });
}

5.2 查询任务状态

@Get('tasks/:taskId')
@ApiOperation({ summary: '查询生成任务状态' })
async getTask(@Param('taskId') taskId: string, @CurrentUser() user: any) {
  // The owner id is passed through so getTaskById only matches the caller's own task.
  const ownerId = user.id;
  const task = await this.generationService.getTaskById(taskId, ownerId);
  return task;
}

5.3 获取用户任务列表

@Get('tasks')
@ApiOperation({ summary: '获取用户生成任务列表' })
async getUserTasks(
  @CurrentUser() user: any,
  @Query('page') page: number = 1,
  @Query('limit') limit: number = 10
) {
  // @Query values arrive as strings unless a transforming ValidationPipe is
  // configured, so the `number` annotation is not enforced at runtime.
  // Coerce to positive integers before they reach skip/take.
  const toPositiveInt = (value: unknown, fallback: number): number => {
    const parsed = Math.trunc(Number(value));
    return Number.isFinite(parsed) && parsed >= 1 ? parsed : fallback;
  };
  const pageNumber = toPositiveInt(page, 1);
  const pageSize = toPositiveInt(limit, 10);
  return this.generationService.getUserTasks(user.id, user.platform, pageNumber, pageSize);
}

这个AI生成服务提供了完整的图片/视频生成功能,包括任务管理、积分校验、文件存储和异步处理等核心功能。