From fd590d1294acddfe3f9cbb6e0e9d3f3610f26048 Mon Sep 17 00:00:00 2001
From: Boen_Shi
Date: Thu, 29 Jan 2026 15:50:14 +0800
Subject: [PATCH] feat(model): add confidence scores to detection results; update Dockerfile to support code-persistence deployment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add a score field to the MaskInfo model to store detection confidence
- Update the YOLO detection logic to extract and pass through prediction scores
- Extend the coordinate data structure to include confidence information
- Adjust the data-processing flow so that score data is propagated correctly
- Modify the Dockerfile to support code-persistence deployment
- Update the README to document the code-persistence configuration
---
 Dockerfile                      |  8 ++++----
 README.md                       |  7 ++++---
 app/core/yolo_detect/detect.py  | 10 ++++++++--
 app/core/yolo_detect/predict.py |  2 +-
 app/routes/analyze_result.py    |  2 +-
 app/schemas/analyze_result.py   |  1 +
 app/services/worker.py          |  3 ++-
 7 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index e339ed4..d30c5d8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,11 +19,11 @@ COPY requirements.txt /code/requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \
     && pip install --no-cache-dir --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu130
 
-# Copy the application code
-COPY ./app /code/app
+# Copy the application code (uncomment for the production release)
+# COPY ./app /code/app
 
-# Remove core files to reduce image size
-RUN rm -rf /code/app/core
+# Remove core files to reduce image size (uncomment for the production release)
+# RUN rm -rf /code/app/core
 
 # Expose the port and start the application
 CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
diff --git a/README.md b/README.md
index 75aeea4..80e6e2e 100644
--- a/README.md
+++ b/README.md
@@ -8,11 +8,11 @@ docker build -t wall .
 
 #### How to run
 
-> Model files are persisted, so models can be updated later without recreating the container, and configuration files are kept in one place
+> The code is persisted, so code and models can be updated later without recreating the container, and configuration files are kept in one place
 > 
 > For later updates: if only core files have changed, just run git pull and restart the container
 > 
-> If code files have changed, the image must be rebuilt
+> If requirements.txt has changed, the image must be rebuilt
 
 ```bash
 sudo docker run -d \
@@ -20,10 +20,11 @@ sudo docker run -d \
 --name wall \
 --gpus all \
 -p [local_port]:80 \
 -v $(pwd)/uploads:/code/uploads \
- -v $(pwd)/app/core:/code/app/core \
+ -v $(pwd)/app:/code/app \
 -v $(pwd)/.env:/code/.env \
 wall
 ```
 > TIPS: some files are managed with Git LFS, so install Git LFS first
+> 
 > When cloning or pulling, it is recommended to clone the code first, stop it, and then run git lfs pull to watch the download progress of the large files
\ No newline at end of file
diff --git a/app/core/yolo_detect/detect.py b/app/core/yolo_detect/detect.py
index 2f65362..8ab5a1a 100644
--- a/app/core/yolo_detect/detect.py
+++ b/app/core/yolo_detect/detect.py
@@ -98,6 +98,7 @@ class YOLODetect(YOLO):
             predicted_class = self.class_names[int(c)]
             if predicted_class != "wall":
                 box = top_boxes[i]
+                score = top_conf[i]
                 top, left, bottom, right = box
                 top = max(0, np.floor(top).astype('int32'))
                 left = max(0, np.floor(left).astype('int32'))
@@ -126,8 +127,13 @@
                 if keep:
                     color = self.colors[int(c)]
                     mask[top:bottom, left:right] = color
-                    coords.append((self.classes.get(predicted_class),
-                                   [(int(left), int(top)), (int(right), int(top)), (int(right), int(bottom)), (int(left), int(bottom))]))
+                    coords.append(
+                        (
+                            self.classes.get(predicted_class),
+                            float(score),
+                            [(int(left), int(top)), (int(right), int(top)), (int(right), int(bottom)), (int(left), int(bottom))]
+                        )
+                    )
 
         mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
         # print("coords:", coords)
diff --git a/app/core/yolo_detect/predict.py b/app/core/yolo_detect/predict.py
index ed72ec2..f70a519 100644
--- a/app/core/yolo_detect/predict.py
+++ b/app/core/yolo_detect/predict.py
@@ -23,7 +23,7 @@ if __name__ == "__main__":
     #   'heatmap'       runs a heatmap visualization of the prediction results; see the comments below for details.
     #   'export_onnx'   exports the model to ONNX; requires PyTorch 1.7.1 or later.
     #----------------------------------------------------------------------------------------------------------#
-    mode = "dir_predict"
+    mode = "predict"
     #-------------------------------------------------------------------------#
     #   crop    specifies whether to crop the detected targets after single-image prediction
     #   count   specifies whether to count the detected targets
diff --git a/app/routes/analyze_result.py b/app/routes/analyze_result.py
index ea7bdcb..1c459e0 100644
--- a/app/routes/analyze_result.py
+++ b/app/routes/analyze_result.py
@@ -52,7 +52,7 @@ async def get_task_result(task_id: str, response: Response):
 
     # Build the mask info
     masks = [
-        MaskInfo(name=mask["name"], coords=mask["coords"])
+        MaskInfo(name=mask["name"], score=mask["score"], coords=mask["coords"])
         for mask in coords_data
     ]
 
diff --git a/app/schemas/analyze_result.py b/app/schemas/analyze_result.py
index bb64d02..73c37e0 100644
--- a/app/schemas/analyze_result.py
+++ b/app/schemas/analyze_result.py
@@ -11,6 +11,7 @@ class ImageInfo(BaseModel):
 
 
 class MaskInfo(BaseModel):
     name: str
+    score: float
     coords: List[List[int]]
 
diff --git a/app/services/worker.py b/app/services/worker.py
index 7a395e4..ec8c638 100644
--- a/app/services/worker.py
+++ b/app/services/worker.py
@@ -42,7 +42,8 @@ class Worker:
                 print(f"Processing task {task_id}, image {input_img_path}...")
 
                 img_res, coords_res = self.detection.detect(input_img_path)
-                coords_res = [{"name": name, "coords": coords} for name, coords in coords_res]
+                coords_res = [{"name": name, "score": score, "coords": coords} for name, score, coords in coords_res]
+                print(coords_res)
                 coords_json = json.dumps(coords_res, ensure_ascii=False)
                 out_img_path = os.path.join(str(output_dir), f"{idx}.jpg")
                 cv2.imwrite(out_img_path, img_res)
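
Note (not part of the patch): a minimal standalone sketch of how the new three-element tuples from detect() are expected to flow through the worker into MaskInfo. The "crack" class name and the 0.87 score are made-up sample values; the MaskInfo definition mirrors app/schemas/analyze_result.py as it looks after this change.

```python
from typing import List, Tuple

from pydantic import BaseModel


# MaskInfo as defined after this patch (mirrors app/schemas/analyze_result.py).
class MaskInfo(BaseModel):
    name: str
    score: float
    coords: List[List[int]]


# Hypothetical detect() output: (class name, confidence score, four corner points).
coords_res: List[Tuple[str, float, List[Tuple[int, int]]]] = [
    ("crack", 0.87, [(10, 20), (110, 20), (110, 80), (10, 80)]),
]

# The same reshaping the worker now performs before serializing to JSON.
coords_dicts = [{"name": name, "score": score, "coords": coords} for name, score, coords in coords_res]

# The route then builds MaskInfo objects from those dicts.
masks = [MaskInfo(name=m["name"], score=m["score"], coords=m["coords"]) for m in coords_dicts]
print(masks[0])  # roughly: name='crack' score=0.87 coords=[[10, 20], [110, 20], [110, 80], [10, 80]]
```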