- 在MaskInfo模型中添加score字段用于存储检测置信度 - 修改YOLO检测逻辑以提取和传递预测分数 - 更新坐标数据结构以包含置信度信息 - 调整数据处理流程以正确传输分数数据 - 修改Dockerfile以支持代码持久化部署 - 更新README文档说明代码持久化配置方式
87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
import json
|
|
import os
|
|
|
|
from fastapi import APIRouter, Response
|
|
from app.schemas.analyze_result import AnalyzeResult, AnalyzeResultData, ImageInfo, MaskInfo, ResultItem
|
|
from app.services.model import TaskStatus
|
|
|
|
# Router on which this module's endpoints (e.g. GET /analyze/result/{task_id}) are registered.
router = APIRouter()
|
|
|
|
|
|
def _upload_url(task_id: str, subdir: str, img_path: str) -> str:
    """Return the public ``/uploads/<task_id>/<subdir>/<filename>`` URL for ``img_path``."""
    return f"/uploads/{task_id}/{subdir}/{os.path.basename(img_path)}"


def _build_result_item(task_id: str, idx: int, result_data: dict) -> ResultItem:
    """Convert one raw entry of ``task.result`` into a ``ResultItem``.

    Args:
        task_id: Task identifier, used to build the image URLs.
        idx: Position of the entry in ``task.result``; used as the item id.
        result_data: Raw result dict with optional ``coords``,
            ``input_img_path`` and ``output_img_path`` keys.

    Returns:
        ResultItem with image URLs and detection masks.
    """
    # `coords` is a JSON string encoding a sequence of mask dicts
    # (name/score/coords). A missing, None or empty value means "no masks".
    coords_data = json.loads(result_data.get("coords") or "[]")

    # Normalise missing/None paths to "" so os.path.basename never sees None.
    input_img_path = result_data.get("input_img_path") or ""
    output_img_path = result_data.get("output_img_path") or ""

    image_info = ImageInfo(
        origin=_upload_url(task_id, "inputs", input_img_path),
        # Only expose an output URL when an output image path was recorded.
        image=_upload_url(task_id, "outputs", output_img_path) if output_img_path else "",
    )

    # NOTE(review): mask["score"] raises KeyError for entries without a score
    # (e.g. results produced before the score field was added) — confirm
    # whether a default is needed for such entries.
    masks = [
        MaskInfo(name=mask["name"], score=mask["score"], coords=mask["coords"])
        for mask in coords_data
    ]

    return ResultItem(id=str(idx), images=image_info, masks=masks)


@router.get("/analyze/result/{task_id}")
async def get_task_result(task_id: str, response: Response):
    """Return the analysis result for ``task_id``.

    Looks the task up in ``WORKER.task_store`` and responds with:

    * unknown task -> HTTP 404, ``success=False``, status ``NOT_FOUND``;
    * completed task -> HTTP 200 with one ``ResultItem`` per entry of
      ``task.result`` (input/output image URLs plus detection masks);
    * any other status (processing, failed, ...) -> ``success=True`` with
      the current status and no results.

    Args:
        task_id: Task identifier taken from the URL path.
        response: Response object used to set the HTTP status code.

    Returns:
        AnalyzeResult describing the task.
    """
    # Imported inside the handler, presumably to avoid a circular import
    # with app.main — confirm before moving it to module level.
    from app.main import WORKER

    task = WORKER.task_store.get(task_id)

    if not task:
        response.status_code = 404
        return AnalyzeResult(
            success=False,
            data=AnalyzeResultData(
                taskId=task_id,
                status=TaskStatus.NOT_FOUND.name,
                completedAt=None,
                results=None,
            ),
        )

    if task.status != TaskStatus.COMPLETED.name:
        # Other states (processing, failed, ...): report the status only.
        return AnalyzeResult(
            success=True,
            data=AnalyzeResultData(
                taskId=task_id,
                status=task.status,
                completedAt=None,
                results=None,
            ),
        )

    # Completed: build one result item per raw result entry. A None result
    # list is treated as empty rather than raising TypeError.
    result_items = [
        _build_result_item(task_id, idx, result_data)
        for idx, result_data in enumerate(task.result or [])
    ]

    response.status_code = 200
    return AnalyzeResult(
        success=True,
        data=AnalyzeResultData(
            taskId=task_id,
            status=task.status,
            completedAt=task.completedAt.isoformat() if task.completedAt else "",
            results=result_items,
        ),
    )
|