diff --git a/__init__.py b/__init__.py index 5ad79cf..f9b0f90 100644 --- a/__init__.py +++ b/__init__.py @@ -54,7 +54,7 @@ class FaceDetect: config = {"start": start, "end": end} else: raise RuntimeError("未找到符合要求的视频片段") - return (image, image_selected, cls, prob, nums, str(period), json.dumps(config), start, end - start) + return (image, image_selected, cls, prob, nums, str(period), json.dumps(config), start, end - start+1) class FaceExtract: diff --git a/test_single_image.py b/test_single_image.py index cb18d47..b533de0 100644 --- a/test_single_image.py +++ b/test_single_image.py @@ -96,7 +96,13 @@ def test_node(image:torch.Tensor,length=10,thres=95,model_name="convnext_tiny"): end = nums[idx] if end - start + 1 >= length: period.append([start, end]) - return (image.permute(0,2,3,1), image.permute(0,2,3,1)[nums,:,:,:], str(preds), str(probs), str(nums), period) + temp_period = [] + for i in period: + a = i[0] + while a+length-1 <= i[1]: + temp_period.append([a,a+length-1]) + a = a+length + return (image.permute(0,2,3,1), image.permute(0,2,3,1)[nums,:,:,:], str(preds), str(probs), str(nums), temp_period)