diff --git a/nodes/llm_nodes.py b/nodes/llm_nodes.py index 9e0105f..adc458b 100644 --- a/nodes/llm_nodes.py +++ b/nodes/llm_nodes.py @@ -18,6 +18,7 @@ from jinja2 import Template, StrictUndefined from loguru import logger from retry import retry +from ..utils.http_utils import send_request from ..utils.image_utils import tensor_to_image_bytes, base64_to_tensor @@ -306,21 +307,21 @@ class ModalClothesMask: try: timeout = 60 logger.info("获取token") - api_key = requests.get(f"https://{endpoint}/google/access-token", + api_key = send_request("get", f"https://{endpoint}/google/access-token", headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[ "access_token"] format = "PNG" logger.info("请求图像编辑") - job_resp = requests.post(f"https://{endpoint}/google/image/clothes_mark", - headers={'x-google-api-key': api_key}, - data={ - "mark_clothes_type": clothes_type, - "mark_color": mask_color, - }, - files={"origin_image": ( - 'image.' + format.lower(), tensor_to_image_bytes(image, format), - f'image/{format.lower()}')}, - timeout=timeout) + job_resp = send_request("post", f"https://{endpoint}/google/image/clothes_mark", + headers={'x-google-api-key': api_key}, + data={ + "mark_clothes_type": clothes_type, + "mark_color": mask_color, + }, + files={"origin_image": ( + 'image.' + format.lower(), tensor_to_image_bytes(image, format), + f'image/{format.lower()}')}, + timeout=timeout) job_resp.raise_for_status() job_resp = job_resp.json() if not job_resp["success"]: @@ -332,7 +333,7 @@ class ModalClothesMask: logger.info("开始轮询任务状态") for _ in range(0, wait_time, interval): logger.info("查询任务状态") - result = requests.get(f"https://{endpoint}/google/{job_id}", timeout=timeout) + result = send_request("get", f"https://{endpoint}/google/{job_id}", timeout=timeout) if result.status_code == 200: result = result.json() if result["status"] == "success": @@ -369,20 +370,20 @@ class ModalEditCustom: try: timeout = 60 logger.info("获取token") - api_key = requests.get(f"https://{endpoint}/google/access-token", + api_key = send_request("get", f"https://{endpoint}/google/access-token", headers={'Authorization': 'Bearer bowong7777'}, timeout=timeout).json()[ "access_token"] format = "PNG" logger.info("请求图像编辑") - job_resp = requests.post(f"https://{endpoint}/google/image/edit_custom", - headers={'x-google-api-key': api_key}, - data={ - "prompt": prompt - }, - files={"origin_image": ( - 'image.' + format.lower(), tensor_to_image_bytes(image, format), - f'image/{format.lower()}')}, - timeout=timeout) + job_resp = send_request("post", f"https://{endpoint}/google/image/edit_custom", + headers={'x-google-api-key': api_key}, + data={ + "prompt": prompt + }, + files={"origin_image": ( + 'image.' + format.lower(), tensor_to_image_bytes(image, format), + f'image/{format.lower()}')}, + timeout=timeout) job_resp.raise_for_status() job_resp = job_resp.json() if not job_resp["success"]: @@ -394,7 +395,7 @@ class ModalEditCustom: logger.info("开始轮询任务状态") for _ in range(0, wait_time, interval): logger.info("查询任务状态") - result = requests.get(f"https://{endpoint}/google/{job_id}", timeout=timeout) + result = send_request("get", f"https://{endpoint}/google/{job_id}", timeout=timeout) if result.status_code == 200: result = result.json() if result["status"] == "success": diff --git a/utils/http_utils.py b/utils/http_utils.py new file mode 100644 index 0000000..254b4f6 --- /dev/null +++ b/utils/http_utils.py @@ -0,0 +1,86 @@ +import requests +from requests.exceptions import RequestException +import time +from typing import Optional, Dict, Any, Tuple, Union + + +def send_request( + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + data: Optional[Union[Dict[str, Any], str]] = None, + files: Optional[Dict[str, Any]] = None, + timeout: int = 30, + retries: int = 3, + backoff_factor: float = 0.5, + retry_on_status: Optional[Tuple[int, ...]] = None +) -> requests.Response: + """ + 发送HTTP请求,支持重试机制 + + 参数: + method: 请求方法,支持"GET"或"POST" + url: 请求URL + headers: 请求头字典 + data: 请求数据,可以是字典或字符串 + files: 上传文件字典,格式为{'file': open('path', 'rb')} + timeout: 请求超时时间(秒) + retries: 失败重试次数 + backoff_factor: 重试间隔退避因子(秒) + retry_on_status: 需要重试的HTTP状态码元组,如(500, 502, 503) + + 返回: + requests.Response: 响应对象 + + 抛出: + RequestException: 所有重试后仍然失败时抛出 + """ + # 规范化方法名称 + method = method.upper() + + # 验证方法是否支持 + if method not in ["GET", "POST"]: + raise ValueError(f"Unsupported HTTP method: {method}") + + # 初始化重试计数器 + retry_attempts = 0 + + # 主循环:处理重试逻辑 + while True: + try: + # 根据请求方法调用对应的requests函数 + if method == "GET": + response = requests.get( + url=url, + headers=headers, + timeout=timeout + ) + else: # POST + response = requests.post( + url=url, + headers=headers, + data=data, + files=files, + timeout=timeout + ) + + # 检查状态码是否需要重试 + if retry_on_status and response.status_code in retry_on_status: + raise RequestException(f"Server returned retryable status code: {response.status_code}") + + # 状态码正常,返回响应 + return response + + except RequestException as e: + # 重试次数耗尽,抛出异常 + if retry_attempts >= retries: + raise RequestException(f"All {retries} retry attempts failed: {str(e)}") from e + + # 计算退避时间:backoff_factor * (2 ** (retry_attempts)) + wait_time = backoff_factor * (2 ** retry_attempts) + print(f"Request failed (attempt {retry_attempts + 1}/{retries}): {str(e)}") + print(f"Retrying in {wait_time:.2f} seconds...") + + # 等待后重试 + time.sleep(wait_time) + retry_attempts += 1