# Tests for GoogleAuthUtils (auth token, Cloud Storage upload, Gemini inference) and Langfuse prompt storage.
import json
|
||
import os
|
||
import unittest
|
||
from typing import Optional, List
|
||
|
||
from loguru import logger
|
||
from google.genai import types
|
||
from pydantic import BaseModel, Field, computed_field
|
||
from langfuse import Langfuse
|
||
from jinja2 import Template
|
||
from BowongModalFunctions.utils.HTTPUtils import GoogleAuthUtils, FlatJsonSchemaGenerator
|
||
|
||
|
||
class VisualFeatureColor(BaseModel):
    # Two required free-text attributes describing how a product looks.
    # Documentation is kept in comments on purpose: pydantic copies a model
    # docstring into the generated JSON schema, which would change the
    # schema this model produces.

    # Distinctive pattern / texture (material) traits of the product.
    pattern: str = Field(
        ...,
        description="商品详细图案纹理等有辨识度的材质特征",
    )
    # Distinctive cut / style traits of the product.
    style: str = Field(
        ...,
        description="商品详细款式版型等有辨识度的风格特征",
    )
|
||
|
||
|
||
class VisualFeature(BaseModel):
    # Wrapper holding the visual description of one product.  Comments are
    # used instead of a docstring so the generated JSON schema is unchanged.

    # Required nested object; its schema description speaks of distinctive
    # colour traits, while the nested model itself carries pattern and style.
    color: VisualFeatureColor = Field(
        ...,
        description="商品详细颜色配色等有辨识度的色彩特征",
    )
|
||
|
||
|
||
class VisualRecognizeResult(BaseModel):
    # Recognition outcome for a single image.  Every field is required;
    # `matched_product` may be null.  Comments are used instead of a
    # docstring so the generated JSON schema is unchanged.

    # Position of the image in the input, counted from 0.
    image_order: int = Field(
        ...,
        description="图片的顺序, 从0开始",
    )
    # Raw text as shown on the image.
    image_name: str = Field(
        ...,
        description="图片上显示的原始文字",
    )
    # Standard product name that was matched, or null when nothing matched.
    matched_product: Optional[str] = Field(
        ...,
        description="匹配到的标准商品名称或null",
    )
    # Confidence score described as 0 to 100 (not enforced by validation).
    match_confidence: int = Field(
        ...,
        description="0到100的可信度评分",
    )
    # Visual features of every product recognized in the image.
    visual_features: List[VisualFeature] = Field(
        ...,
        description="所有识别到的商品详细颜色配色等有辨识度的色彩特征",
    )
|
||
|
||
class VisualRecognizeResults(BaseModel):
    # Top-level response for one recognition request.  Comments are used
    # instead of a docstring because pydantic copies a model docstring into
    # the generated JSON schema.

    # One entry per recognized image.  This is a list: the description says
    # "the result of each image", each entry carries its own `image_order`,
    # and `image_count` below counts several input images - a single object
    # could only ever describe one of them.
    results: List[VisualRecognizeResult] = Field(description="每一个图片识别的结果")
    # Total number of input images submitted for recognition.
    image_count: int = Field(description="输入待识别图片的总数")
    # Total number of input products to be recognized.
    product_count: int = Field(description="输入待识别商品的总数")
|
||
|
||
class PromptVariables(BaseModel):
    # Values substituted into the prompt template.  Comments are used
    # instead of a docstring so the generated JSON schema is unchanged.

    # Product names, in the order they should appear in the prompt.
    product_list: List[str] = Field(
        ...,
        description="商品列表",
    )

    @computed_field(description="xml格式排列的商品列表")
    @property
    def product_list_xml(self) -> str:
        # Render one <product> element per name, newline-separated, inside a
        # single <products> wrapper.  Names are inserted verbatim (no XML
        # escaping is applied).
        product_lines = "\n".join(
            f"<product>{name}</product>" for name in self.product_list
        )
        return "<products>\n" + product_lines + "\n</products>"
|
||
|
||
|
||
class GoogleTestCase(unittest.IsolatedAsyncioTestCase):
    """Integration tests for the Google auth/storage/Gemini helpers and Langfuse prompts.

    These tests call live services.  Secrets are read from the environment
    instead of being committed to the repository:

    * ``GOOGLE_APPLICATION_CREDENTIALS`` - path to a service-account JSON key file.
    * ``LANGFUSE_SECRET_KEY`` / ``LANGFUSE_PUBLIC_KEY`` - Langfuse API keys.
    * ``LANGFUSE_HOST`` - optional; defaults to the US cloud host.

    A test whose credentials (or local fixture file) are unavailable is
    skipped rather than failed.

    SECURITY: never paste a private key or API secret into this file.  Any
    key that has ever been committed must be treated as leaked and rotated.
    """

    # Filled in by setUpClass from the key file; empty means "not configured".
    service_account_info: dict = {}
    bucket_name = "dy-media-storage"
    cloudflare_project_id = "67720b647ff2b55cf37ba3ef9e677083"
    cloudflare_gateway_id = "bowong-dev"

    _CREDENTIALS_ENV = "GOOGLE_APPLICATION_CREDENTIALS"
    _CLOUD_PLATFORM_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
    _DEFAULT_LANGFUSE_HOST = "https://us.cloud.langfuse.com"
    _PROMPT_NAME = "Gemini自动切条"
    _MODEL_ID = "gemini-2.5-flash"
    _REGIONS = ['us-central1']

    _VIDEO_URI = "gs://dy-media-storage/videos/035b3053-73f8-45b7-9bf8-428df9025608.mp4"
    _GRID_IMAGE_URIS = (
        "gs://dy-media-storage/images/grid_56fe257b-f81f-4ad0-8958-530ad557b876.jpg",
        "gs://dy-media-storage/images/grid_f74b4c4f-a305-4a96-9045-f09e0eb90a30.jpg",
        "gs://dy-media-storage/images/grid_e4c10111-e9ec-4e76-88db-abd57b3bb92e.jpg",
        "gs://dy-media-storage/images/grid_52b73a7d-f5e0-4a17-b7c2-194ef3852856.jpg",
        "gs://dy-media-storage/images/grid_f8d8c35e-1c35-48c4-b9ce-df6365b7fc55.jpg",
        "gs://dy-media-storage/images/grid_b9132a81-f4b6-45f7-8b37-c0caaee06705.jpg",
        "gs://dy-media-storage/images/grid_d7ef44fc-8767-405d-a9dc-f16e4a918efd.jpg",
        "gs://dy-media-storage/images/grid_564b9ff9-fa3a-4d34-9592-038d454f0834.jpg",
    )

    # Prompt text stored in Langfuse.  {{PRODUCT_LIST}} is the placeholder the
    # retrieval test fills in when it renders the prompt with jinja2.
    _GRID_PROMPT = (
        "\n"
        "<prompt><instruction>你是专业的商品识别专家。我上传了商品图片网格,需要你识别图片中的商品并与商品列表进行匹配。 **输入材料**: - 🖼️ **商品图片网格**:包含多个黑色边框区域,每个区域内有商品图片+商品名称文字 - 📋 **商品列表**:标准商品名称参考清单 **核心任务**: 1. **扫描黑色边框区域**:从左上角开始,按行扫描每个黑色边框区域 2. **提取文字信息**:精确提取每个区域内的所有文字信息 3. **与商品列表匹配**:将图片文字与商品列表进行高相似度匹配 4. **提取商品图片特征**:从商品图片提取详细可识别特征,包括颜色、图案、纹理、材质、版型、款式等 **严格约束**: - 🚫 只识别有黑色边框包围的商品区域 - 🚫 每个商品必须有清晰可见的文字标注 - 🚫 不得推测或添加图片中不存在的商品 - ✅ 输出商品数量不得超过图片中的黑色边框区域数量 **商品列表**: {{PRODUCT_LIST}}\n"
        "</instruction></prompt>\n"
    )

    @classmethod
    def setUpClass(cls):
        """Load the service-account key file named by the environment, if any."""
        super().setUpClass()
        key_path = os.environ.get(cls._CREDENTIALS_ENV)
        if key_path and os.path.isfile(key_path):
            with open(key_path, encoding="utf-8") as key_file:
                cls.service_account_info = json.load(key_file)

    # ------------------------------------------------------------------ helpers

    async def _fetch_credentials(self):
        """Return a token object for the cloud-platform scope, or skip the test."""
        if not self.service_account_info:
            self.skipTest(f"{self._CREDENTIALS_ENV} does not point to a service-account key file")
        return await GoogleAuthUtils.get_google_auth_jwt(
            service_account_info=self.service_account_info,
            scopes=[self._CLOUD_PLATFORM_SCOPE],
        )

    def _make_genai_client(self, access_token):
        """Build the Gemini client routed through the configured Cloudflare gateway."""
        return GoogleAuthUtils.GoogleGenaiClient(
            cloudflare_project_id=self.cloudflare_project_id,
            cloudflare_gateway_id=self.cloudflare_gateway_id,
            google_project_id=self.service_account_info.get('project_id'),
            regions=list(self._REGIONS),
            access_token=access_token,
        )

    def _make_langfuse(self):
        """Build a Langfuse client from environment keys, or skip the test."""
        secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
        public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
        if not (secret_key and public_key):
            self.skipTest("LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set")
        return Langfuse(
            host=os.environ.get("LANGFUSE_HOST", self._DEFAULT_LANGFUSE_HOST),
            secret_key=secret_key,
            public_key=public_key,
            tracing_enabled=False,
        )

    @staticmethod
    def _safety_settings():
        """Return BLOCK_NONE safety settings for the five harm categories the tests use."""
        categories = (
            types.HarmCategory.HARM_CATEGORY_HARASSMENT,
            types.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
            types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        )
        return [
            types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.BLOCK_NONE)
            for category in categories
        ]

    def _build_generation_config(self, response_schema):
        """Return the JSON-output generation config shared by the tests."""
        return types.GenerateContentConfig(
            temperature=0.1,
            top_p=0.7,
            safety_settings=self._safety_settings(),
            response_mime_type="application/json",
            response_schema=response_schema,
        )

    @staticmethod
    def _file_part(mime_type, file_uri):
        """Wrap a storage URI as a file-data content part."""
        return types.Part(file_data=types.FileData(mime_type=mime_type, file_uri=file_uri))

    # -------------------------------------------------------------------- tests

    async def test_google_new_token(self):
        """A token can be obtained from the service-account key."""
        cred = await self._fetch_credentials()
        self.assertIsNotNone(cred)

    async def test_google_cloud_storage_upload(self):
        """A local video file can be uploaded to the configured bucket."""
        filepath = "./videos/input_1.mp4"
        if not os.path.isfile(filepath):
            self.skipTest(f"fixture file {filepath} is not available")
        cred = await self._fetch_credentials()
        prefix = "test/123"
        filename = os.path.basename(filepath)
        with open(filepath, 'rb') as file:
            response = await GoogleAuthUtils.google_upload_file(
                file_stream=file,
                content_type="video/mp4",
                google_api_key=cred.access_token,
                bucket_name=self.bucket_name,
                filename=f"{prefix}/{filename}",
            )
        # Assert before logging so a missing response fails with a clear message.
        self.assertIsNotNone(response)
        logger.info(response.model_dump_json(indent=2))

    async def test_google_inference_with_sdk(self):
        """Gemini answers a question about a video stored in Cloud Storage."""
        cred = await self._fetch_credentials()
        # The credential object is deliberately not logged: it carries a bearer token.
        client = self._make_genai_client(cred.access_token)
        # NOTE(review): the schema is the single-image recognition model although
        # the prompt asks for a video summary - kept as written; confirm intent.
        config = self._build_generation_config(VisualRecognizeResult)
        result = client.generate_content(
            model_id=self._MODEL_ID,
            contents=[types.Content(
                role='user',
                parts=[
                    self._file_part("video/mp4", self._VIDEO_URI),
                    types.Part.from_text(text="帮我总结一下这个视频里有什么"),
                ],
            )],
            config=config,
        )
        self.assertIsNotNone(result)
        logger.info(result.model_dump_json(indent=2, exclude_none=True))

    async def test_google_save_prompt(self):
        """The recognition prompt and its generation config can be stored in Langfuse."""
        langfuse = self._make_langfuse()
        config = self._build_generation_config(
            VisualRecognizeResults.model_json_schema(schema_generator=FlatJsonSchemaGenerator)
        )
        logger.info(config.model_dump_json(indent=2))
        prompt = langfuse.create_prompt(
            name=self._PROMPT_NAME,
            prompt=self._GRID_PROMPT,
            type="text",
            labels=["production"],
            config=config.model_dump(exclude_none=True),
        )
        logger.info(prompt)
        self.assertIsNotNone(prompt)

    async def test_google_get_prompt(self):
        """The stored prompt can be fetched, rendered and used for grid recognition."""
        langfuse = self._make_langfuse()
        product_title_list = [
            "A美洋MEIYANG【商场同款】碧螺春墨镜 醋酸纤维素防晒太阳眼镜-周四",
            "A美洋MEIYANG【欧若风】微风背心 慵懒百搭圆领无袖上衣-周二",
            "A美洋MEIYANG【商场同款】幸运T恤 复古做旧印花圆领短袖上衣-周四",
            "A美洋MEIYANG 黑武士厚底老爹鞋 赛博末日风~厚底增高运动休闲鞋",
            "合金项链 A美洋MEIYANG 香水瓶毛衣链 个性吊坠麻花链项链-周四",
            "A美洋MEIYANG【商场同款】纱暮半裙 抗起球镂空蕾丝半身中长裙-周四",
        ]
        variables = PromptVariables(product_list=product_title_list)
        latest_prompt = langfuse.get_prompt(self._PROMPT_NAME, type="text", label="latest")
        logger.info(f"variables={latest_prompt.variables}")
        runtime_prompt = Template(latest_prompt.prompt).render(PRODUCT_LIST=variables.product_list_xml)

        cred = await self._fetch_credentials()
        # The credential object is deliberately not logged: it carries a bearer token.
        client = self._make_genai_client(cred.access_token)
        # The generation config is the one stored alongside the prompt.
        config = types.GenerateContentConfig.model_validate(latest_prompt.config)

        parts = [self._file_part("image/jpeg", uri) for uri in self._GRID_IMAGE_URIS]
        parts.append(types.Part.from_text(text=runtime_prompt))
        result = client.generate_content(
            model_id=self._MODEL_ID,
            contents=[types.Content(role='user', parts=parts)],
            config=config,
        )
        self.assertIsNotNone(result)
        logger.info(result.model_dump_json(indent=2, exclude_none=True))

        # Fail with a readable message instead of an IndexError/AttributeError
        # when the model returns no candidate or no text part.
        self.assertTrue(result.candidates, "response contains no candidates")
        json_result = result.candidates[0].content.parts[0].text
        self.assertIsNotNone(json_result, "first response part carries no text")
        result_model = VisualRecognizeResults.model_validate_json(json_result)
        logger.info(result_model.model_dump_json(indent=2, exclude_none=True))

    async def test_flat_json_schema(self):
        """The flat schema generator produces a JSON-serializable schema dict."""
        json_schema = VisualRecognizeResults.model_json_schema(schema_generator=FlatJsonSchemaGenerator)
        self.assertIsInstance(json_schema, dict)
        logger.info(json.dumps(json_schema, indent=2, ensure_ascii=False))
||
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|