From 6a2e0468849ef8178799616f5f6331af49426058 Mon Sep 17 00:00:00 2001
From: Boen_Shi
Date: Tue, 27 Jan 2026 11:59:45 +0800
Subject: [PATCH] feat(api): add image analysis feature and related routes

- Add analyze, analyze_result, analyze_status, and health routes
- Implement image upload and task submission
- Add endpoints for querying task status and fetching results
- Integrate SegFormer and YOLO models for image detection
- Implement SAM3 preprocessing to pre-screen images
- Add model-selection config supporting segformer and yolo
- Implement task-queue management and asynchronous processing
- Add a Dockerfile for containerized deployment
- Configure environment variables and gitignore rules
- Create data models defining the API response structures
---
 .env                                      |   4 +
 .gitattributes                            |   4 +
 .gitignore                                |   6 +
 Dockerfile                                |  29 +++
 README.md                                 |  56 +++++
 app/__init__.py                           |   0
 app/core/__init__.py                      |   0
 app/core/model.py                         |  16 ++
 app/core/preprocess.py                    |  12 +
 app/core/sam3/__init__.py                 |   0
 app/core/sam3/preprocess.py               | 172 +++++++++++++
 app/core/sam3/sam3.pt                     |   3 +
 app/core/segformer/__init__.py            |   0
 app/core/segformer/detect.py              |  99 ++++++++
 app/core/segformer/segformer.onnx         |   3 +
 app/core/segformer/segformer_b2.onnx.data |   3 +
 app/core/yolo/__init__.py                 |   0
 app/core/yolo/detect.py                   | 291 ++++++++++++++++++++++
 app/core/yolo/model.onnx                  |   3 +
 app/main.py                               |  35 +++
 app/routes/__init__.py                    |   0
 app/routes/analyze.py                     |  54 ++++
 app/routes/analyze_result.py              |  86 +++++++
 app/routes/analyze_status.py              |  34 +++
 app/routes/health.py                      |  15 ++
 app/schemas/__init__.py                   |   0
 app/schemas/analyze.py                    |  18 ++
 app/schemas/analyze_result.py             |  32 +++
 app/schemas/analyze_status.py             |  15 ++
 app/schemas/health.py                     |   7 +
 app/services/__init__.py                  |   0
 app/services/model.py                     |  66 +++++
 app/services/worker.py                    |  80 ++++++
 requirements.txt                          |  15 ++
 34 files changed, 1158 insertions(+)
 create mode 100644 .env
 create mode 100644 .gitattributes
 create mode 100644 .gitignore
 create mode 100644 Dockerfile
 create mode 100644 README.md
 create mode 100644 app/__init__.py
 create mode 100644 app/core/__init__.py
 create mode 100644 app/core/model.py
 create mode 100644 app/core/preprocess.py
 create mode 100644 app/core/sam3/__init__.py
 create mode 100644 app/core/sam3/preprocess.py
 create mode 100644 app/core/sam3/sam3.pt
 create mode 100644 app/core/segformer/__init__.py
 create mode 100644 app/core/segformer/detect.py
 create mode 100644 app/core/segformer/segformer.onnx
 create mode 100644 app/core/segformer/segformer_b2.onnx.data
 create mode 100644 app/core/yolo/__init__.py
 create mode 100644 app/core/yolo/detect.py
 create mode 100644 app/core/yolo/model.onnx
 create mode 100644 app/main.py
 create mode 100644 app/routes/__init__.py
 create mode 100644 app/routes/analyze.py
 create mode 100644 app/routes/analyze_result.py
 create mode 100644 app/routes/analyze_status.py
 create mode 100644 app/routes/health.py
 create mode 100644 app/schemas/__init__.py
 create mode 100644 app/schemas/analyze.py
 create mode 100644 app/schemas/analyze_result.py
 create mode 100644 app/schemas/analyze_status.py
 create mode 100644 app/schemas/health.py
 create mode 100644 app/services/__init__.py
 create mode 100644 app/services/model.py
 create mode 100644 app/services/worker.py
 create mode 100644 requirements.txt

diff --git a/.env b/.env
new file mode 100644
index 0000000..3ea3ff0
--- /dev/null
+++ b/.env
@@ -0,0 +1,4 @@
+UPLOAD_DIR=uploads
+MOCK=false
+MODEL=yolo # segformer or yolo
+PREPROCESS=sam3
\ No newline at end of file
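These four variables drive the runtime configuration. As a minimal sketch, this is how `app/main.py` (later in this patch) reads them via python-dotenv; the final `print` is illustrative only:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")
MODEL = os.getenv("MODEL", "segformer")       # "segformer" or "yolo"
PREPROCESS = os.getenv("PREPROCESS", "sam3")  # currently only "sam3"
print(UPLOAD_DIR, MODEL, PREPROCESS)
```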
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..b00bad4
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,4 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.onnx.data filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..57449f8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+.idea/
+docker-test/
+tests/
+uploads/
+Wall Docker 镜像使用教程.md
+Wall Docker 镜像使用教程.pdf
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..1207d96
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,29 @@
+FROM python:3.11
+
+# Install system dependencies and clean the apt cache
+RUN apt-get update && apt-get install -y \
+    libgl1 \
+    libglib2.0-0 \
+    libsm6 \
+    libxrender1 \
+    libxext6 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set the working directory
+WORKDIR /code
+
+# Copy the Python dependency list
+COPY requirements.txt /code/requirements.txt
+
+# Install Python dependencies (mirror for faster downloads), then the CUDA builds of torch
+RUN pip install --no-cache-dir --upgrade -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \
+    && pip install --no-cache-dir --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu130
+
+# Copy the application code
+COPY ./app /code/app
+
+# Remove model files baked into the image to save disk space; they live in
+# subdirectories such as app/core/segformer/, so match one level down as well
+RUN rm -rf /code/app/core/*.onnx /code/app/core/*.data /code/app/core/*.pt \
+    /code/app/core/*/*.onnx /code/app/core/*/*.data /code/app/core/*/*.pt
+
+# Expose the port and start the application
+EXPOSE 80
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..bc6d100
--- /dev/null
+++ b/README.md
@@ -0,0 +1,56 @@
+### Wall Docker Image Usage Guide
+
+---
+
+> Model files are persisted on the host, so models can be updated later without recreating the container, and configuration is centralized in a single file
+
+1. Load the Docker image: `docker load -i wall.tar`
+
+2. In the persistence directory, create a `.env` file with the following content
+
+   ```env
+   UPLOAD_DIR=uploads
+   MOCK=false
+   MODEL=segformer # segformer or yolo; only segformer is bundled at the moment
+   ```
+
+3. Extract the model archive into the `core` directory
+
+   ```bash
+   tar -xvf core.tar
+   ```
+
+4. Run the Docker image
+
+   ```bash
+   sudo docker run -d \
+     --name [docker_container_name] \
+     --gpus all \
+     -p [local_port]:80 \
+     -v $(pwd)/uploads:/code/uploads \
+     -v $(pwd)/core:/code/app/core \
+     -v $(pwd)/.env:/code/.env \
+     wall
+   ```
+
+5. To update the models later, overwrite the files inside `core`, adjust the `.env` file, and the container can keep running as before
+
+---
+
+> If you do not need model-file persistence, there is no need to extract the model archive
+
+1. Load the Docker image: `docker load -i wall.tar`
+
+2. Run the Docker image
+
+   ```bash
+   sudo docker run -d \
+     --name [docker_container_name] \
+     --gpus all \
+     -p [local_port]:80 \
+     -v $(pwd)/core:/code/app/core \
+     -v $(pwd)/.env:/code/.env \
+     wall
+   ```
\ No newline at end of file
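For reference, here is a minimal client sketch of the full flow the service exposes: submit images, poll the status, then fetch the results. The base URL, file names, and 2-second poll interval are assumptions; adjust them to your deployment:

```python
import time
import requests

BASE = "http://localhost:8000"  # assumption: replace with your mapped [local_port]

# 1. Submit images for analysis (the form field name is "images")
with open("wall1.jpg", "rb") as f1, open("wall2.jpg", "rb") as f2:
    resp = requests.post(f"{BASE}/analyze", files=[("images", f1), ("images", f2)])
task_id = resp.json()["data"]["taskId"]

# 2. Poll the task status until it finishes
while True:
    status = requests.get(f"{BASE}/analyze/status/{task_id}").json()["data"]
    if status["status"] in ("COMPLETED", "FAILED"):
        break
    time.sleep(2)

# 3. Fetch the results; image URLs are served from the /uploads static mount
result = requests.get(f"{BASE}/analyze/result/{task_id}").json()
print(result["data"]["results"])
```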
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/core/__init__.py b/app/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/core/model.py b/app/core/model.py
new file mode 100644
index 0000000..495de48
--- /dev/null
+++ b/app/core/model.py
@@ -0,0 +1,16 @@
+from app.core.segformer.detect import Detection as SegFormer, DetectionMock as SegFormerMock
+from app.core.yolo.detect import YOLOSeg
+
+
+class Model:
+    def __init__(self):
+        # Imported lazily to avoid a circular import with app.main
+        from app.main import MODEL, MOCK
+        if MOCK == "true":
+            print("Using the SegFormer mock model")
+            self.detection = SegFormerMock()
+        elif MODEL == "segformer":
+            print("Using the SegFormer model")
+            self.detection = SegFormer()
+        elif MODEL == "yolo":
+            print("Using the YOLO model")
+            self.detection = YOLOSeg()
+        else:
+            raise ValueError(f"Unknown MODEL: {MODEL}")
+
+    def getModel(self):
+        return self.detection
diff --git a/app/core/preprocess.py b/app/core/preprocess.py
new file mode 100644
index 0000000..632d2ef
--- /dev/null
+++ b/app/core/preprocess.py
@@ -0,0 +1,12 @@
+from app.core.sam3.preprocess import SAM3
+
+
+class Preprocess:
+    def __init__(self):
+        # Imported lazily to avoid a circular import with app.main
+        from app.main import PREPROCESS
+        if PREPROCESS == "sam3":
+            print("Using SAM3 for pre-screening")
+            self.preprocess = SAM3()
+        else:
+            raise ValueError(f"Unknown PREPROCESS: {PREPROCESS}")
+
+    def getPreprocess(self):
+        return self.preprocess
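A usage sketch for the two factories above. Note that this is hypothetical standalone usage: instantiating them imports `app.main`, which loads the `.env` configuration (and, inside the running app, the worker). The image paths are placeholders:

```python
from app.core.model import Model
from app.core.preprocess import Preprocess

detector = Model().getModel()         # SegFormer or YOLOSeg, per MODEL
gate = Preprocess().getPreprocess()   # SAM3, per PREPROCESS

paths = ["a.jpg", "b.jpg"]            # placeholder image paths
labels = gate.preprocess(paths)       # e.g. [1, 0]: 1 = run detection, 0 = skip
for path, label in zip(paths, labels):
    if label:
        mask, coords = detector.detect(path)
```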
diff --git a/app/core/sam3/__init__.py b/app/core/sam3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/core/sam3/preprocess.py b/app/core/sam3/preprocess.py
new file mode 100644
index 0000000..63e6c98
--- /dev/null
+++ b/app/core/sam3/preprocess.py
@@ -0,0 +1,172 @@
+import os
+import torch
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+from sam3.model_builder import build_sam3_image_model
+from sam3.train.data.collator import collate_fn_api as collate
+from sam3.model.utils.misc import copy_data_to_device
+from sam3.train.data.sam3_image_dataset import (
+    Datapoint, Image as SAMImage, FindQueryLoaded, InferenceMetadata
+)
+from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
+from sam3.eval.postprocessors import PostProcessImage
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"
+
+# ===== Configuration =====
+CKPT_PATH = os.path.join(os.getcwd(), "app/core/sam3", "sam3.pt")
+DEVICE = "cuda:0"
+
+BATCH_SIZE = 12   # batch size; may need to be configurable from the frontend
+NUM_WORKERS = 12  # image-loading workers; decide whether the frontend should set this
+CONF_TH = 0.5
+RATIO_TH = 0.5    # wall-area threshold: larger filters more, but too large hurts close-up shots
+_GLOBAL_ID = 1
+
+PROMPTS = [
+    "wall",
+    "building wall",
+    "building facade",
+    "building exterior wall",
+    "exterior building facade",
+]
+IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
+# =========================
+
+
+class ImgPathList(Dataset):
+    def __init__(self, img_paths: list):
+        """
+        Initialize with a list of image paths.
+
+        Args:
+            img_paths (list): a list of image file paths
+        """
+        self.paths = img_paths
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, i):
+        p = self.paths[i]
+        img = Image.open(p).convert("RGB")  # open the image and convert to RGB
+        return p, img                       # return the path together with the image
+
+
+class SAM3:
+    def __init__(self):
+        self.dev = torch.device(DEVICE)
+        self.postprocessor = PostProcessImage(
+            max_dets_per_img=-1,
+            iou_type="segm",
+            use_original_sizes_box=True,
+            use_original_sizes_mask=True,
+            convert_mask_to_rle=False,
+            detection_threshold=CONF_TH,
+            to_cpu=False,
+        )
+        self.transform = ComposeAPI(
+            transforms=[
+                RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
+                ToTensorAPI(),
+                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            ]
+        )
+        self.model = build_sam3_image_model(
+            checkpoint_path=CKPT_PATH, load_from_HF=False, device=DEVICE
+        ).to(DEVICE).eval()
+
+    def preprocess(self, image_path_list):
+        labels = []
+
+        loader = DataLoader(
+            ImgPathList(image_path_list),
+            batch_size=BATCH_SIZE,
+            shuffle=False,
+            num_workers=NUM_WORKERS,
+            pin_memory=True,
+            collate_fn=self.collate_fn,
+        )
+
+        with torch.inference_mode():
+            for names, images in loader:
+                datapoints = []
+                name2qids = {}  # name -> [qid, ...]
+                for name, img in zip(names, images):
+                    dp = self.create_empty_datapoint()
+                    self.set_image(dp, img)
+
+                    qids = [self.add_text_prompt(dp, p) for p in PROMPTS]
+                    name2qids[name] = qids
+
+                    datapoints.append(self.transform(dp))
+
+                batch = collate(datapoints, dict_key="dummy")["dummy"]
+                batch = copy_data_to_device(batch, self.dev, non_blocking=True)
+                output = self.model(batch)
+
+                processed = self.postprocessor.process_results(output, batch.find_metadatas)
+
+                for name in names:
+                    any_masks = []
+                    for qid in name2qids[name]:
+                        res = processed[qid]
+                        m = res.get("masks", None)  # expected shape: [N, H, W]
+                        if m is None:
+                            any_masks.append(torch.zeros(1, 1, device=self.dev, dtype=torch.bool).squeeze())
+                        else:
+                            if not torch.is_tensor(m):
+                                m = torch.as_tensor(m, device=self.dev)
+                            any_masks.append(m.any(0))  # [H, W]
+
+                    wall_mask = torch.stack(any_masks, 0).any(0)  # [H, W] bool
+                    ratio = wall_mask.float().mean().item()
+                    lab = 1 if ratio >= RATIO_TH else 0
+                    labels.append(lab)
+                    print(f"{name} | wall_ratio={ratio:.4f} -> {lab}")  # optional debug output
+
+        return labels
+
+    @staticmethod
+    def add_text_prompt(datapoint, text_query):
+        global _GLOBAL_ID
+        assert len(datapoint.images) == 1, "please set the image first"
+        w, h = datapoint.images[0].size
+        datapoint.find_queries.append(
+            FindQueryLoaded(
+                query_text=text_query,
+                image_id=0,
+                object_ids_output=[],
+                is_exhaustive=True,
+                query_processing_order=0,
+                inference_metadata=InferenceMetadata(
+                    coco_image_id=_GLOBAL_ID,
+                    original_image_id=_GLOBAL_ID,
+                    original_category_id=1,
+                    original_size=[w, h],
+                    object_id=0,
+                    frame_index=0,
+                ),
+            )
+        )
+        _GLOBAL_ID += 1
+        return _GLOBAL_ID - 1
+
+    @staticmethod
+    def create_empty_datapoint():
+        return Datapoint(find_queries=[], images=[])
+
+    @staticmethod
+    def set_image(datapoint, pil_image):
+        w, h = pil_image.size
+        datapoint.images = [SAMImage(data=pil_image, objects=[], size=[h, w])]  # size uses [H, W]
+
+    @staticmethod
+    def collate_fn(batch):
+        names, imgs = zip(*batch)
+        return list(names), list(imgs)
diff --git a/app/core/sam3/sam3.pt b/app/core/sam3/sam3.pt
new file mode 100644
index 0000000..5b7c2ea
--- /dev/null
+++ b/app/core/sam3/sam3.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9999e2341ceef5e136daa386eecb55cb414446a00ac2b55eb2dfd2f7c3cf8c9e
+size 3450062241
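The heart of the pre-screening decision in `SAM3.preprocess` is the union-then-threshold step: the per-prompt masks are OR-ed together and an image passes (label 1) only if the wall covers at least `RATIO_TH` of its pixels. A toy illustration with dummy tensors, shapes shrunk for readability:

```python
import torch

RATIO_TH = 0.5
h, w = 4, 4
masks_per_prompt = [
    torch.zeros(h, w, dtype=torch.bool),
    torch.ones(h, w, dtype=torch.bool),  # one prompt fires on the whole image
]
wall_mask = torch.stack(masks_per_prompt, 0).any(0)  # union over prompts
ratio = wall_mask.float().mean().item()              # fraction of wall pixels
label = 1 if ratio >= RATIO_TH else 0                # here: ratio=1.0 -> label 1
print(ratio, label)
```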
diff --git a/app/core/segformer/__init__.py b/app/core/segformer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/core/segformer/detect.py b/app/core/segformer/detect.py
new file mode 100644
index 0000000..a4a2722
--- /dev/null
+++ b/app/core/segformer/detect.py
@@ -0,0 +1,99 @@
+import os
+
+import cv2
+import numpy as np
+import albumentations as A
+import onnxruntime as ort
+
+
+class Detection:
+    def __init__(self):
+        self.CLASS_NAMES = ["background", "Hollowing", "Water seepage", "Cracking"]
+        self.PALETTE = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
+        self.tfm = A.Compose([
+            A.Resize(512, 512),
+            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+        real_path = os.path.join(os.path.dirname(__file__), "segformer.onnx")
+        self.model = ort.InferenceSession(
+            real_path,
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+        )
+        print("ONNX model loaded!")
+
+    def get_contours(self, mask_np):
+        res = []
+        for idx in range(1, len(self.CLASS_NAMES)):
+            name = self.CLASS_NAMES[idx]
+            binary = (mask_np == idx).astype(np.uint8) * 255
+            cnts, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            for c in cnts:
+                if cv2.contourArea(c) < 20:
+                    continue
+                eps = 0.002 * cv2.arcLength(c, True)
+                poly = cv2.approxPolyDP(c, eps, True)
+                if len(poly) >= 3:
+                    res.append((name, poly.reshape(-1, 2).tolist()))
+        return res
+
+    def detect(self, img_input):
+        # Read the image (file path or already-decoded BGR array)
+        if isinstance(img_input, str):
+            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
+        else:
+            img_bgr = img_input
+
+        h, w = img_bgr.shape[:2]
+        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+        # Preprocess
+        x = self.tfm(image=img_rgb)["image"]
+        # HWC -> CHW and add the batch dimension
+        x = np.transpose(x, (2, 0, 1))
+        x = np.expand_dims(x, axis=0).astype(np.float32)  # batch x C x H x W
+
+        # ONNX inference
+        inp_name = self.model.get_inputs()[0].name
+        out_name = self.model.get_outputs()[0].name
+        out = self.model.run([out_name], {inp_name: x})[0]
+
+        # Take the per-pixel argmax as the predicted class map
+        pred = out.argmax(axis=1)[0].astype(np.uint8)
+
+        mask = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
+        mask_rgb = self.PALETTE[mask]
+        coords = self.get_contours(mask)
+        return mask_rgb, coords
+
+
+class DetectionMock:
+    def __init__(self):
+        self.onnx_path = "segformer.onnx"
+
+    def detect(self, img_input):
+        if isinstance(img_input, str):
+            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
+        else:
+            img_bgr = img_input
+
+        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+        return img_rgb, [
+            ('Cracking', [[771, 928], [757, 935]]),
+            ('Cracking', [[254, 740], [251, 942]]),
+            ('Cracking', [[764, 420], [764, 424]]),
+            ('Cracking', [[257, 238], [251, 245]]),
+            ('Cracking', [[1436, 145], [1401, 145]])
+        ]
+
+
+if __name__ == "__main__":
+    img_path = "test.jpg"
+
+    detection = Detection()
+    mask_res, coords_res = detection.detect(img_path)
+
+    print("Mask Shape:", mask_res.shape)
+    print("Detections:", coords_res)
+    cv2.imwrite("res_onnx.png", cv2.cvtColor(mask_res, cv2.COLOR_RGB2BGR))
diff --git a/app/core/segformer/segformer.onnx b/app/core/segformer/segformer.onnx
new file mode 100644
index 0000000..399c663
--- /dev/null
+++ b/app/core/segformer/segformer.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3073aadff85295ed491765e60a2a6a3eaf1d6b4e1161c779260030d85271695
+size 198159486
diff --git a/app/core/segformer/segformer_b2.onnx.data b/app/core/segformer/segformer_b2.onnx.data
new file mode 100644
index 0000000..5303731
--- /dev/null
+++ b/app/core/segformer/segformer_b2.onnx.data
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:279691d9e2b4ceb4f6edc3c9b07df3656b1d52d0716c1431ce01b6e742b9e79a
+size 103153664
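`get_contours` turns each class channel of the predicted mask into simplified polygons. A self-contained demo of the same cv2 calls on a synthetic 64x64 mask:

```python
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
cv2.rectangle(mask, (10, 10), (40, 30), 1, thickness=-1)  # fake "Cracking" blob

binary = (mask == 1).astype(np.uint8) * 255
cnts, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for c in cnts:
    if cv2.contourArea(c) < 20:           # drop tiny speckles
        continue
    eps = 0.002 * cv2.arcLength(c, True)  # tolerance is 0.2% of the perimeter
    poly = cv2.approxPolyDP(c, eps, True)
    print(poly.reshape(-1, 2).tolist())   # e.g. [[10, 10], [10, 30], [40, 30], [40, 10]]
```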
@@ -0,0 +1,291 @@
+import os
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+
+CLASS_NAMES = {
+    0: "Hollowing",
+    1: "Water seepage",
+    2: "Cracking",
+}
+
+CLASS_COLORS = {
+    0: (0, 0, 255),  # Hollowing -> red (BGR)
+    1: (0, 255, 0),  # Water seepage -> green
+    2: (255, 0, 0),  # Cracking -> blue
+}
+
+IMG_SIZE = 640
+
+
+class YOLOSeg:
+    def __init__(self, onnx_path: str = "model.onnx", imgsz: int = IMG_SIZE):
+        real_path = os.path.join(os.path.dirname(__file__), onnx_path)
+        self.session = ort.InferenceSession(
+            real_path,
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+            if ort.get_device() == "GPU"
+            else ["CPUExecutionProvider"],
+        )
+        self.ndtype = (
+            np.half
+            if self.session.get_inputs()[0].type == "tensor(float16)"
+            else np.single
+        )
+        self.imgsz = imgsz
+        self.classes = CLASS_NAMES
+
+    # ---------- Preprocessing: letterbox ----------
+    def _preprocess(self, img_bgr):
+        h0, w0 = img_bgr.shape[:2]
+        new_shape = (self.imgsz, self.imgsz)
+
+        r = min(new_shape[0] / h0, new_shape[1] / w0)
+        ratio = (r, r)
+        new_unpad = (int(round(w0 * r)), int(round(h0 * r)))
+        pad_w = (new_shape[1] - new_unpad[0]) / 2
+        pad_h = (new_shape[0] - new_unpad[1]) / 2
+
+        if (w0, h0) != new_unpad:
+            img = cv2.resize(img_bgr, new_unpad, interpolation=cv2.INTER_LINEAR)
+        else:
+            img = img_bgr.copy()
+
+        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
+        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right,
+            borderType=cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )
+
+        # HWC -> CHW, BGR -> RGB, scale to [0, 1]
+        img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype)
+        img = img / 255.0
+        if img.ndim == 3:
+            img = img[None]  # (1, 3, H, W)
+        return img, ratio, (pad_w, pad_h)
+
+    # ---------- Masks -> polygons ----------
+    @staticmethod
+    def _masks2segments(masks):
+        """masks: (N, H, W) -> polygon coordinates for each instance"""
+        segments = []
+        for x in masks.astype("uint8"):
+            cs = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0]
+            if cs:
+                # keep the contour with the most points
+                c = np.array(cs[np.argmax([len(i) for i in cs])]).reshape(-1, 2)
+            else:
+                c = np.zeros((0, 2))
+            segments.append(c.astype("float32"))
+        return segments
+
+    @staticmethod
+    def _crop_mask(masks, boxes):
+        """masks: (N, H, W), boxes: (N, 4) xyxy; zero out mask pixels outside each box"""
+        n, h, w = masks.shape
+        x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
+        r = np.arange(w, dtype=x1.dtype)[None, None, :]
+        c = np.arange(h, dtype=x1.dtype)[None, :, None]
+        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
+
+    @staticmethod
+    def _scale_mask(masks, im0_shape, ratio_pad=None):
+        """Rescale masks from the feature-map size back to the original image size"""
+        im1_shape = masks.shape[:2]
+        if ratio_pad is None:
+            gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])
+            pad = (
+                (im1_shape[1] - im0_shape[1] * gain) / 2,
+                (im1_shape[0] - im0_shape[0] * gain) / 2,
+            )
+        else:
+            pad = ratio_pad[1]
+
+        top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))
+        bottom = int(round(im1_shape[0] - pad[1] + 0.1))
+        right = int(round(im1_shape[1] - pad[0] + 0.1))
+
+        if masks.ndim < 2:
+            raise ValueError("masks ndim should be 2 or 3")
+
+        masks = masks[top:bottom, left:right]
+        masks = cv2.resize(
+            masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
+        )
+        if masks.ndim == 2:
+            masks = masks[:, :, None]
+        return masks
+    def _process_mask(self, protos, masks_in, bboxes, im0_shape):
+        """
+        protos: (C, Hm, Wm)
+        masks_in: (N, C)
+        bboxes: (N, 4) xyxy
+        Returns: (N, H, W) bool
+        """
+        c, mh, mw = protos.shape
+        masks = (
+            np.matmul(masks_in, protos.reshape(c, -1))
+            .reshape(-1, mh, mw)
+            .transpose(1, 2, 0)
+        )  # HWN
+        masks = np.ascontiguousarray(masks)
+        masks = self._scale_mask(masks, im0_shape)  # HWC
+        masks = np.einsum("HWN->NHW", masks)        # NHW
+        masks = self._crop_mask(masks, bboxes)
+        return masks > 0.5
+
+    @staticmethod
+    def _get_cid(name):
+        for k, v in CLASS_NAMES.items():
+            if v == name:
+                return k
+
+    @staticmethod
+    def _make_color_mask(img_bgr, masks, coords):
+        """
+        Build a color mask image:
+        - black background
+        - each instance region colored by class (not overlaid onto the original image)
+        Returns: color_mask (H, W, 3) BGR uint8
+        """
+        h, w = img_bgr.shape[:2]
+        color_mask = np.zeros((h, w, 3), dtype=np.uint8)
+
+        N = masks.shape[0]
+        for i in range(N):
+            m = masks[i]  # (H, W) bool
+            inst = coords[i]
+            cid = YOLOSeg._get_cid(inst[0])
+            color = CLASS_COLORS.get(cid, (0, 255, 255))  # fallback color for unmapped classes
+
+            # color only the mask region
+            color_mask[m] = color
+
+        return color_mask
+
+    # ---------- Main inference entry point ----------
+    def detect(self, img_input):
+        """
+        Input: original BGR image (or a file path)
+        Returns:
+            color_mask: (H, W, 3) color mask image (black background, colored defect regions)
+            coords: list of (class_name, polygon_points) tuples, one per instance
+        """
+        conf_thres = 0.1
+        iou_thres = 0.1
+        if isinstance(img_input, str):
+            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
+        else:
+            img_bgr = img_input
+
+        if img_bgr is None:
+            raise ValueError("img_bgr is None; check that the image was read successfully")
+
+        im0 = img_bgr.copy()
+        im, ratio, (pad_w, pad_h) = self._preprocess(im0)
+
+        # ONNX inference
+        input_name = self.session.get_inputs()[0].name
+        preds = self.session.run(None, {input_name: im})
+        x, protos = preds[0], preds[1]  # x: (1, C, N), protos: (1, 32, Hm, Wm)
+
+        # (1, C, N) -> (N, C)
+        x = np.einsum("bcn->bnc", x)[0]
+
+        # Infer the mask-channel count dynamically from protos
+        nm = int(protos.shape[1])  # usually 32
+        C = x.shape[1]
+        nc = C - 4 - nm  # number of classes
+
+        # Class scores occupy columns [4:4+nc]
+        cls_scores = x[:, 4:4 + nc]
+        cls_max = np.max(cls_scores, axis=-1)
+        keep = cls_max > conf_thres
+        x = x[keep]
+        cls_scores = cls_scores[keep]
+
+        h0, w0 = im0.shape[:2]
+
+        if x.size == 0:
+            # No detections: return an empty color mask and no coords
+            # (match the (color_mask, coords) shape of the normal return)
+            empty_color_mask = np.zeros((h0, w0, 3), dtype=np.uint8)
+            return empty_color_mask, []
+
+        conf = cls_max[keep]
+        cls_id = np.argmax(cls_scores, axis=-1)
+        # Concatenate into [cx, cy, w, h, conf, cls_id, mask_coeffs...]
+        x = np.c_[x[:, :4], conf, cls_id, x[:, -nm:]]
+
+        # ===== NMS: cv2.dnn.NMSBoxes expects [x, y, w, h] with the top-left corner =====
+        # x[:, :4] is currently [cx, cy, w, h]; convert to [x, y, w, h] first
+        bboxes_xywh = x[:, :4].copy()
+        bboxes_xywh[:, 0] = bboxes_xywh[:, 0] - bboxes_xywh[:, 2] / 2  # x = cx - w/2
+        bboxes_xywh[:, 1] = bboxes_xywh[:, 1] - bboxes_xywh[:, 3] / 2  # y = cy - h/2
+
+        indices = cv2.dnn.NMSBoxes(
+            bboxes_xywh.tolist(), x[:, 4].tolist(), conf_thres, iou_thres
+        )
+
+        # Depending on the OpenCV version, indices may be [], [0, 1], [[0], [1]], or np.array([...])
+        if indices is None or len(indices) == 0:
+            empty_color_mask = np.zeros((h0, w0, 3), dtype=np.uint8)
+            return empty_color_mask, []
+
+        # Normalize to a flat integer index array
+        indices = np.array(indices).reshape(-1)
+        x = x[indices]
+
+        # cxcywh -> xyxy (using the post-NMS x[:, :4])
+        x[:, 0:2] -= x[:, 2:4] / 2
+        x[:, 2:4] += x[:, 0:2]
+
+        # Remove the padding and rescale to the original image
+        x[:, [0, 2]] -= pad_w
+        x[:, [1, 3]] -= pad_h
+        x[:, :4] /= min(ratio)
+
+        # Clip to the image bounds
+        x[:, [0, 2]] = x[:, [0, 2]].clip(0, w0)
+        x[:, [1, 3]] = x[:, [1, 3]].clip(0, h0)
+
+        # Decode the masks
+        protos = protos[0]  # (32, Hm, Wm)
+        bboxes_xyxy = x[:, :4]
+        mask_coeffs = x[:, 6:]
+        masks = self._process_mask(protos, mask_coeffs, bboxes_xyxy, im0.shape)
+
+        # Masks -> polygons
+        segments = self._masks2segments(masks)
+
+        # Package the coordinate results
+        coords = []
+        for (x1, y1, x2, y2, conf_i, cls_i), seg in zip(x[:, :6], segments):
+            cid = int(cls_i)
+            coords.append(
+                (
+                    self.classes.get(cid, str(cid)),
+                    seg.tolist()
+                )
+            )
+
+        color_mask = self._make_color_mask(im0, masks, coords)
+
+        return color_mask, coords
+
+
+if __name__ == "__main__":
+    img_path = r"D:\Projects\Python\wall\app\core\yolo\test.jpg"
+    # ====== Load the model ======
+    model = YOLOSeg()
+    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
+    color_mask, coords = model.detect(img)
+    print(color_mask.shape)
+    print(coords)
diff --git a/app/core/yolo/model.onnx b/app/core/yolo/model.onnx
new file mode 100644
index 0000000..a3a1037
--- /dev/null
+++ b/app/core/yolo/model.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9581e0a80fa950fb8ac6c6cfbaa928f3d3c7cfb428367b3c1ab55fe4e99bca05
+size 258267500
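The letterbox preprocessing and its inverse in `detect()` are easiest to follow with concrete numbers. A worked example for a hypothetical 1920x1080 input at imgsz=640:

```python
# Same math as YOLOSeg._preprocess, on plain numbers.
h0, w0 = 1080, 1920
imgsz = 640

r = min(imgsz / h0, imgsz / w0)             # min(0.5926, 0.3333) = 0.3333
new_unpad = (round(w0 * r), round(h0 * r))  # (640, 360)
pad_w = (imgsz - new_unpad[0]) / 2          # 0.0
pad_h = (imgsz - new_unpad[1]) / 2          # 140.0, i.e. 140 px top and bottom
print(r, new_unpad, pad_w, pad_h)

# detect() later undoes this: boxes -= (pad_w, pad_h), then boxes /= r
```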
diff --git a/app/main.py b/app/main.py
new file mode 100644
index 0000000..d84274b
--- /dev/null
+++ b/app/main.py
@@ -0,0 +1,35 @@
+import os
+from dotenv import load_dotenv
+from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
+from app.services.worker import Worker
+from app.routes.health import router as health
+from app.routes.analyze import router as submit_analyze
+from app.routes.analyze_status import router as get_task_status
+from app.routes.analyze_result import router as get_task_result
+
+load_dotenv()
+UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")
+MOCK = os.getenv("MOCK", "false").lower()
+MODEL = os.getenv("MODEL", "segformer")
+PREPROCESS = os.getenv("PREPROCESS", "sam3")
+WORKER = Worker()
+
+os.makedirs(UPLOAD_DIR, exist_ok=True)
+
+app = FastAPI()
+app.mount("/" + UPLOAD_DIR, StaticFiles(directory=UPLOAD_DIR), name=UPLOAD_DIR)
+app.include_router(health)
+app.include_router(submit_analyze)
+app.include_router(get_task_status)
+app.include_router(get_task_result)
diff --git a/app/routes/__init__.py b/app/routes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/routes/analyze.py b/app/routes/analyze.py
new file mode 100644
index 0000000..24ce994
--- /dev/null
+++ b/app/routes/analyze.py
@@ -0,0 +1,54 @@
+import os
+import time
+import uuid
+from typing import Optional, List
+from fastapi import APIRouter, UploadFile, File, Response
+from app.schemas.analyze import Analyze, AnalyzeData
+from app.services.model import TaskStatus, TaskStore
+
+router = APIRouter()
+
+
+@router.post("/analyze")
+async def submit_analyze(response: Response, images: Optional[List[UploadFile]] = File(default=[])):
+    from app.main import UPLOAD_DIR, WORKER
+    if not images:
+        response.status_code = 400
+        return Analyze(
+            success=False,
+            data=AnalyzeData(
+                taskId="",
+                status=TaskStatus.FAILED.name,
+                message="Please upload at least one image",
+                estimatedTime="",
+                filesReceived=0
+            )
+        )
+
+    task_id = f"task_{int(time.time() * 1000)}_{uuid.uuid4().hex[:10]}"
+
+    task_dir = os.path.join(UPLOAD_DIR, task_id, "inputs")
+    os.makedirs(task_dir, exist_ok=True)
+
+    saved_paths = []
+    for idx, img in enumerate(images):
+        ext = os.path.splitext(img.filename)[1]
+        file_path = os.path.join(task_dir, f"{idx}{ext}")
+        with open(file_path, "wb") as f:
+            f.write(await img.read())
+        saved_paths.append(file_path)
+
+    WORKER.task_store[task_id] = TaskStore(images=saved_paths)
+    WORKER.task_queue.put(task_id)
+
+    return Analyze(
+        success=True,
+        data=AnalyzeData(
+            taskId=task_id,
+            status=TaskStatus.QUEUED.name,
+            message=f"Submitted {len(saved_paths)} images; analysis in progress",
+            estimatedTime="",
+            filesReceived=len(saved_paths)
+        )
+    )
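The StaticFiles mount in `app/main.py` is what makes the saved task artifacts downloadable: a file written under `UPLOAD_DIR` is served at the same path relative to the mount. A small sketch of that mapping (the task id is a made-up example):

```python
import os

UPLOAD_DIR = "uploads"
task_id = "task_1700000000000_abcdef1234"  # hypothetical id from /analyze

disk_path = os.path.join(UPLOAD_DIR, task_id, "inputs", "0.jpg")
url_path = f"/{UPLOAD_DIR}/{task_id}/inputs/0.jpg"
# app.mount("/uploads", StaticFiles(directory="uploads")) maps one onto the other
print(disk_path, "->", url_path)
```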
diff --git a/app/routes/analyze_result.py b/app/routes/analyze_result.py
new file mode 100644
index 0000000..ea7bdcb
--- /dev/null
+++ b/app/routes/analyze_result.py
@@ -0,0 +1,86 @@
+import json
+import os
+
+from fastapi import APIRouter, Response
+from app.schemas.analyze_result import AnalyzeResult, AnalyzeResultData, ImageInfo, MaskInfo, ResultItem
+from app.services.model import TaskStatus
+
+router = APIRouter()
+
+
+@router.get("/analyze/result/{task_id}")
+async def get_task_result(task_id: str, response: Response):
+    from app.main import UPLOAD_DIR, WORKER
+    task = WORKER.task_store.get(task_id)
+
+    if not task:
+        response.status_code = 404
+        return AnalyzeResult(
+            success=False,
+            data=AnalyzeResultData(
+                taskId=task_id,
+                status=TaskStatus.NOT_FOUND.name,
+                completedAt=None,
+                results=None
+            )
+        )
+
+    if task.status == TaskStatus.COMPLETED.name:
+        # Build the result payload for a completed task
+        result_items = []
+
+        for idx, result_data in enumerate(task.result):
+            # Parse the coordinate data
+            coords_data = json.loads(result_data.get("coords", "[]"))
+
+            # Image paths on disk
+            input_img_path = result_data.get("input_img_path", "")
+            output_img_path = result_data.get("output_img_path", "")
+
+            # Build the URL paths served by the static mount (use UPLOAD_DIR
+            # rather than a hardcoded "uploads" so the config stays in one place)
+            input_filename = os.path.basename(input_img_path)
+            output_filename = os.path.basename(output_img_path)
+
+            image_info = ImageInfo(
+                origin=f"/{UPLOAD_DIR}/{task_id}/inputs/{input_filename}",
+                image=f"/{UPLOAD_DIR}/{task_id}/outputs/{output_filename}" if output_img_path != "" else "",
+            )
+
+            # Build the mask info
+            masks = [
+                MaskInfo(name=mask["name"], coords=mask["coords"])
+                for mask in coords_data
+            ]
+
+            result_item = ResultItem(
+                id=str(idx),
+                images=image_info,
+                masks=masks
+            )
+            result_items.append(result_item)
+
+        response.status_code = 200
+        return AnalyzeResult(
+            success=True,
+            data=AnalyzeResultData(
+                taskId=task_id,
+                status=task.status,
+                completedAt=task.completedAt.isoformat() if task.completedAt else "",
+                results=result_items
+            )
+        )
+    else:
+        # Other statuses (still processing, failed, etc.)
+        return AnalyzeResult(
+            success=True,
+            data=AnalyzeResultData(
+                taskId=task_id,
+                status=task.status,
+                completedAt=None,
+                results=None
+            )
+        )
diff --git a/app/routes/analyze_status.py b/app/routes/analyze_status.py
new file mode 100644
index 0000000..ec4f26b
--- /dev/null
+++ b/app/routes/analyze_status.py
@@ -0,0 +1,34 @@
+from fastapi import APIRouter, Response
+from app.schemas.analyze_status import AnalyzeStatus, AnalyzeStatusData
+from app.services.model import TaskStatus
+
+router = APIRouter()
+
+
+@router.get("/analyze/status/{task_id}")
+async def get_task_status(task_id: str, response: Response):
+    from app.main import WORKER
+    task = WORKER.task_store.get(task_id)
+
+    if task:
+        response.status_code = 200
+        return AnalyzeStatus(
+            success=True,
+            data=AnalyzeStatusData(
+                taskId=task_id,
+                status=task.status,
+                progress=task.progress,
+                message=task.message
+            )
+        )
+    else:
+        response.status_code = 404
+        return AnalyzeStatus(
+            success=False,
+            data=AnalyzeStatusData(
+                taskId=task_id,
+                status=TaskStatus.NOT_FOUND.name,
+                progress=0,
+                message=TaskStatus.NOT_FOUND.value
+            )
+        )
diff --git a/app/routes/health.py b/app/routes/health.py
new file mode 100644
index 0000000..1462721
--- /dev/null
+++ b/app/routes/health.py
@@ -0,0 +1,15 @@
+from datetime import datetime
+from fastapi import APIRouter
+from app.schemas.health import Health
+from app.services.model import TaskStatus
+
+router = APIRouter()
+
+
+@router.get("/health", response_model=Health)
+async def health():
+    return Health(
+        status=TaskStatus.OK.name,
+        message="Analysis server is running",
+        timestamp=str(datetime.now())
+    )
diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/schemas/analyze.py b/app/schemas/analyze.py
new file mode 100644
index 0000000..a5413f1
--- /dev/null
+++ b/app/schemas/analyze.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+"""Response schema for task submission"""
+
+
+class AnalyzeData(BaseModel):
+    taskId: Optional[str] = None
+    status: str
+    message: str
+    estimatedTime: Optional[str] = None
+    filesReceived: Optional[int] = None
+
+
+class Analyze(BaseModel):
+    success: bool
+    data: AnalyzeData
diff --git a/app/schemas/analyze_result.py b/app/schemas/analyze_result.py
new file mode 100644
index 0000000..bb64d02
--- /dev/null
+++ b/app/schemas/analyze_result.py
@@ -0,0 +1,32 @@
+from typing import List, Optional
+from pydantic import BaseModel
+
+"""Response schema for fetching task results"""
+
+
+class ImageInfo(BaseModel):
+    origin: str
+    image: str
+
+
+class MaskInfo(BaseModel):
+    name: str
+    coords: List[List[int]]
+
+
+class ResultItem(BaseModel):
+    id: str
+    images: ImageInfo
+    masks: List[MaskInfo]
+
+
+class AnalyzeResultData(BaseModel):
+    taskId: str
+    status: str
+    completedAt: Optional[str] = None
+    results: Optional[List[ResultItem]] = None
+
+
+class AnalyzeResult(BaseModel):
+    success: bool
+    data: AnalyzeResultData
diff --git a/app/schemas/analyze_status.py b/app/schemas/analyze_status.py
new file mode 100644
index 0000000..3fa0976
--- /dev/null
+++ b/app/schemas/analyze_status.py
@@ -0,0 +1,15 @@
+from pydantic import BaseModel
+
+"""Response schema for task status queries"""
+
+
+class AnalyzeStatusData(BaseModel):
+    taskId: str
+    status: str
+    progress: int
+    message: str
+
+
+class AnalyzeStatus(BaseModel):
+    success: bool
+    data: AnalyzeStatusData
diff --git a/app/schemas/health.py b/app/schemas/health.py
new file mode 100644
index 0000000..18b30a0
--- /dev/null
+++ b/app/schemas/health.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class Health(BaseModel):
+    status: str
+    message: str
+    timestamp: str
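Putting the schemas together, here is a hypothetical `/analyze/result` payload for a completed single-image task; all field values are illustrative:

```python
from app.schemas.analyze_result import (
    AnalyzeResult, AnalyzeResultData, ImageInfo, MaskInfo, ResultItem)

resp = AnalyzeResult(
    success=True,
    data=AnalyzeResultData(
        taskId="task_1700000000000_abcdef1234",
        status="COMPLETED",
        completedAt="2026-01-27T12:00:00",
        results=[ResultItem(
            id="0",
            images=ImageInfo(
                origin="/uploads/task_1700000000000_abcdef1234/inputs/0.jpg",
                image="/uploads/task_1700000000000_abcdef1234/outputs/0.jpg"),
            masks=[MaskInfo(name="Cracking", coords=[[771, 928], [757, 935]])],
        )],
    ),
)
print(resp.model_dump_json(indent=2))  # what the route serializes to the client
```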
diff --git a/app/services/__init__.py b/app/services/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/services/model.py b/app/services/model.py
new file mode 100644
index 0000000..7a2882f
--- /dev/null
+++ b/app/services/model.py
@@ -0,0 +1,66 @@
+from datetime import datetime
+from enum import Enum
+from typing import List, Dict, Optional
+
+
+class TaskStatus(Enum):
+    OK = "Service running"
+    QUEUED = "Queued"
+    PROCESSING = "Processing"
+    COMPLETED = "Completed"
+    FAILED = "Processing failed"
+    NOT_FOUND = "Task not found"
+
+
+class TaskStore:
+    def __init__(self, images: List[str]):
+        self._status: str = TaskStatus.QUEUED.name
+        self._progress: int = 0
+        self._images: List[str] = images
+        self._result: List[Dict[str, str]] = []
+        self._message: str = ""
+        self._completedAt: Optional[datetime] = None
+
+    @property
+    def images(self):
+        return self._images
+
+    @property
+    def status(self):
+        return self._status
+
+    @status.setter
+    def status(self, status: str):
+        self._status = status
+
+    @property
+    def progress(self):
+        return self._progress
+
+    @progress.setter
+    def progress(self, progress: int):
+        self._progress = progress
+
+    @property
+    def result(self):
+        return self._result
+
+    @result.setter
+    def result(self, result: List[Dict[str, str]]):
+        self._result = result
+
+    @property
+    def message(self):
+        return self._message
+
+    @message.setter
+    def message(self, message: str):
+        self._message = message
+
+    @property
+    def completedAt(self):
+        return self._completedAt
+
+    @completedAt.setter
+    def completedAt(self, completedAt: datetime):
+        self._completedAt = completedAt
\ No newline at end of file
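The `Worker` service in the next file pairs this `TaskStore` with a `Queue` in a classic producer/consumer pattern: routes enqueue task ids, and a daemon thread drains the queue and mutates the shared store that the status/result routes read. A stripped-down sketch of that pattern, with plain dicts standing in for `TaskStore`:

```python
import threading
from queue import Queue

task_queue, store = Queue(), {}

def worker():
    while True:
        task_id = task_queue.get()
        store[task_id]["status"] = "PROCESSING"
        # ... run pre-screening and detection here ...
        store[task_id]["status"] = "COMPLETED"
        task_queue.task_done()

threading.Thread(target=worker, daemon=True).start()
store["t1"] = {"status": "QUEUED"}  # what POST /analyze does
task_queue.put("t1")
task_queue.join()                   # wait for completion (demo only)
print(store["t1"]["status"])        # COMPLETED
```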
diff --git a/app/services/worker.py b/app/services/worker.py
new file mode 100644
index 0000000..d058086
--- /dev/null
+++ b/app/services/worker.py
@@ -0,0 +1,80 @@
+import json
+import os
+import cv2
+import threading
+from datetime import datetime
+from queue import Queue
+from typing import Dict
+
+from app.core.model import Model
+from app.core.preprocess import Preprocess
+from app.services.model import TaskStatus, TaskStore
+
+
+class Worker:
+    def __init__(self):
+        self.detection = Model().getModel()
+        self.preprocess = Preprocess().getPreprocess()
+        self.task_queue = Queue()
+        self.task_store: Dict[str, TaskStore] = {}
+
+        threading.Thread(target=self.worker, daemon=True).start()
+
+    def worker(self):
+        from app.main import UPLOAD_DIR
+        while True:
+            task_id = self.task_queue.get()
+            if task_id is None:
+                break
+
+            task = self.task_store.get(task_id)
+            if not task:
+                continue
+
+            try:
+                task.status = TaskStatus.PROCESSING.name
+                task.progress = 0
+                print(f"Started processing task {task_id}...")
+
+                # Create the output directory
+                output_dir = os.path.join(UPLOAD_DIR, task_id, "outputs")
+                os.makedirs(output_dir, exist_ok=True)
+
+                # Pre-screen the images: returns a list of 0/1 labels, 0 = skip, 1 = run detection
+                image_labels = self.preprocess.preprocess(task.images)
+
+                for idx, (input_img_path, label) in enumerate(zip(task.images, image_labels)):
+                    print(f"Task {task_id}: processing image {input_img_path}...")
+
+                    if label == 0:
+                        # Label 0: skip detection and leave the output path and coords empty
+                        task.result.append(
+                            {"input_img_path": input_img_path, "output_img_path": "", "coords": "[]"}
+                        )
+                    else:
+                        # Run model detection
+                        img_res, coords_res = self.detection.detect(input_img_path)
+
+                        coords_res = [{"name": name, "coords": coords} for name, coords in coords_res]
+                        coords_json = json.dumps(coords_res, ensure_ascii=False)
+
+                        out_img_path = os.path.join(output_dir, f"{idx}.jpg")
+                        cv2.imwrite(out_img_path, img_res)
+
+                        task.result.append(
+                            {"input_img_path": input_img_path, "output_img_path": out_img_path, "coords": coords_json}
+                        )
+
+                    task.progress = int((idx + 1) / len(task.images) * 100)
+
+                task.status = TaskStatus.COMPLETED.name
+                task.completedAt = datetime.now()
+                task.message = "Completed"
+
+                print(f"Task {task_id} completed")
+
+            except Exception as e:
+                task.status = TaskStatus.FAILED.name
+                task.message = str(e)
+                print(f"Task {task_id} failed: {e}")
+            finally:
+                self.task_queue.task_done()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..947a3d0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+fastapi[all]==0.122.0
+albumentations==2.0.8
+albucore==0.0.24
+numpy==2.2.6
+onnxruntime-gpu==1.23.2
+opencv-python==4.12.0.88
+PyYAML>=6.0
+scipy>=1.10.0
+pydantic>=2.10.0
+typing-extensions>=4.12.0
+python-dotenv~=1.2.1
+matplotlib
+pandas
+tqdm==4.67.1
\ No newline at end of file