From 6307b216d64fba82b21de09c1c77f7a6af80cd98 Mon Sep 17 00:00:00 2001 From: iHeyTang Date: Thu, 4 Sep 2025 23:23:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E9=99=90=E5=88=B6=E6=A3=80=E6=9F=A5=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 executeTemplateByCode 方法中添加用户当前任务数量限制检查,确保用户在执行新任务前不超过最大并发任务数。 - 新增 checkUserTaskLimit 方法,查询用户正在进行的任务并判断是否满足执行条件。 - 更新 Swagger API 文档,添加任务数量限制的响应示例。 --- src/controllers/template.controller.ts | 131 +++++++++++++++++++------ 1 file changed, 100 insertions(+), 31 deletions(-) diff --git a/src/controllers/template.controller.ts b/src/controllers/template.controller.ts index aa21a00..cf93cf2 100644 --- a/src/controllers/template.controller.ts +++ b/src/controllers/template.controller.ts @@ -34,7 +34,11 @@ import { TemplateListDto, BatchExecuteDto, } from '../dto/template.dto'; -import { TemplateExecutionEntity, ExecutionStatus, ExecutionType } from '../entities/template-execution.entity'; +import { + TemplateExecutionEntity, + ExecutionStatus, + ExecutionType, +} from '../entities/template-execution.entity'; import { InjectRepository } from '@nestjs/typeorm'; import { Repository } from 'typeorm'; import { ApiCommonResponses } from '../decorators/api-common-responses.decorator'; @@ -51,7 +55,7 @@ export class TemplateController { private readonly templateRepository: Repository, @InjectRepository(TemplateExecutionEntity) private readonly executionRepository: Repository, - ) { } + ) {} @Post(':templateId/execute') @ApiOperation({ @@ -127,10 +131,22 @@ export class TemplateController { description: '执行成功', type: TemplateExecuteResponseDto, }) + @SwaggerApiResponse({ + status: 429, + description: '任务数量限制', + schema: { + example: { + code: 429, + message: + '当前账号有3个任务正在进行中,且距开始时间不足5分钟,请稍后再试', + data: null, + }, + }, + }) async executeTemplateByCode( @Param('code') code: string, @Body() body: { imageUrl: string }, - @Request() req + @Request() req, ): Promise> { try { const { imageUrl } = body; @@ -139,23 +155,30 @@ export class TemplateController { throw new HttpException('imageUrl is required', HttpStatus.BAD_REQUEST); } + const userId = req.user.userId; + + // 检查用户当前的任务限制 + await this.checkUserTaskLimit(userId); + // 首先获取模板配置以确定模板类型 const templateConfig = await this.templateFactory.getTemplateByCode(code); if (!templateConfig) { throw new HttpException('Template not found', HttpStatus.NOT_FOUND); } - + // 通过模板代码创建实例并执行 const template = await this.templateFactory.createTemplateByCode(code); const taskId = await template.execute(imageUrl); - + // 将任务保存到 TemplateExecutionEntity - const userId = req.user.userId; const execution = this.executionRepository.create({ templateId: templateConfig.id, userId, platform: req.user.platform, - type: templateConfig.templateType === TemplateType.VIDEO ? ExecutionType.VIDEO : ExecutionType.IMAGE, + type: + templateConfig.templateType === TemplateType.VIDEO + ? ExecutionType.VIDEO + : ExecutionType.IMAGE, prompt: '', // 可以从请求参数中获取,如果有的话 inputImageUrl: imageUrl, taskId: taskId, // 保存外部系统返回的任务ID,用于回调匹配 @@ -164,13 +187,13 @@ export class TemplateController { creditCost: templateConfig.creditCost, startedAt: new Date(), }); - + const savedExecution = await this.executionRepository.save(execution); // 返回任务id (执行记录的ID) return ResponseUtil.success(savedExecution.id, '模板执行已启动'); } catch (error) { - console.error(error) + console.error(error); throw new HttpException( error.message || 'Template execution failed', HttpStatus.INTERNAL_SERVER_ERROR, @@ -178,6 +201,47 @@ export class TemplateController { } } + /** + * 检查用户当前任务限制 + * @param userId 用户ID + * @throws HttpException 如果超出任务限制 + */ + private async checkUserTaskLimit(userId: string): Promise { + // 查询用户当前进行中的任务 + const processingTasks = await this.executionRepository.find({ + where: { + userId, + status: ExecutionStatus.PROCESSING, + }, + order: { + startedAt: 'ASC', + }, + }); + + // 如果当前没有进行中的任务,允许执行 + if (processingTasks.length === 0) { + return; + } + + // 如果当前进行中任务数量已达到3个,需要检查时间限制 + if (processingTasks.length >= 3) { + const now = new Date(); + const fiveMinutesAgo = new Date(now.getTime() - 5 * 60 * 1000); + + // 检查是否有任务在5分钟内开始 + const recentTasks = processingTasks.filter( + (task) => task.startedAt && task.startedAt > fiveMinutesAgo, + ); + + if (recentTasks.length > 0) { + throw new HttpException( + `有${processingTasks.length}个任务正在进行中,请稍后再试`, + HttpStatus.TOO_MANY_REQUESTS, + ); + } + } + } + @Post(':templateId/batch-execute') @ApiOperation({ summary: '批量执行模板', @@ -432,7 +496,9 @@ export class TemplateController { * @returns 执行进度信息 */ @Get('execution/:taskId/progress') - async getExecutionProgress(@Param('taskId', ParseIntPipe) taskId: number): Promise> { + async getExecutionProgress( + @Param('taskId', ParseIntPipe) taskId: number, + ): Promise> { try { const execution = await this.executionRepository.findOne({ where: { id: taskId }, @@ -446,27 +512,30 @@ export class TemplateController { ); } - return ResponseUtil.success({ - taskId: execution.id, - templateId: execution.templateId, - templateName: execution.template?.name, - userId: execution.userId, - platform: execution.platform, - type: execution.type, - status: execution.status, - progress: execution.progress, - inputImageUrl: execution.inputImageUrl, - outputUrl: execution.outputUrl, - thumbnailUrl: execution.thumbnailUrl, - errorMessage: execution.errorMessage, - creditCost: execution.creditCost, - startedAt: execution.startedAt, - completedAt: execution.completedAt, - executionDuration: execution.executionDuration, - createdAt: execution.createdAt, - updatedAt: execution.updatedAt, - executionResult: execution.executionResult - }, '获取执行进度成功'); + return ResponseUtil.success( + { + taskId: execution.id, + templateId: execution.templateId, + templateName: execution.template?.name, + userId: execution.userId, + platform: execution.platform, + type: execution.type, + status: execution.status, + progress: execution.progress, + inputImageUrl: execution.inputImageUrl, + outputUrl: execution.outputUrl, + thumbnailUrl: execution.thumbnailUrl, + errorMessage: execution.errorMessage, + creditCost: execution.creditCost, + startedAt: execution.startedAt, + completedAt: execution.completedAt, + executionDuration: execution.executionDuration, + createdAt: execution.createdAt, + updatedAt: execution.updatedAt, + executionResult: execution.executionResult, + }, + '获取执行进度成功', + ); } catch (error) { if (error instanceof HttpException) { throw error;