feat(api): add image analysis feature and related API routes

- Add analyze, analyze_result, analyze_status, and health routes
- Implement image upload and task submission
- Add endpoints for querying task status and fetching results
- Integrate the SegFormer and YOLO models for image detection
- Implement SAM3 preprocessing to decide whether an image needs detection
- Add model selection configuration supporting segformer and yolo
- Implement task queue management and asynchronous processing
- Add a Dockerfile for containerized deployment
- Configure environment variables and .gitignore rules
- Create data models defining the API response structures
This commit is contained in: commit 6a2e046884
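The routes listed above form a simple submit/poll/fetch workflow. A minimal client-side sketch, assuming the service listens on http://localhost:8000 and that the `requests` package is available (both are assumptions, not part of this commit):

```python
import time
import requests  # assumed available; any HTTP client works

BASE = "http://localhost:8000"  # adjust to your deployment

# Submit two images to /analyze (the multipart field name is "images")
with open("a.jpg", "rb") as f1, open("b.jpg", "rb") as f2:
    resp = requests.post(f"{BASE}/analyze",
                         files=[("images", ("a.jpg", f1, "image/jpeg")),
                                ("images", ("b.jpg", f2, "image/jpeg"))])
task_id = resp.json()["data"]["taskId"]

# Poll /analyze/status/{task_id} until the task finishes
while True:
    status = requests.get(f"{BASE}/analyze/status/{task_id}").json()["data"]
    if status["status"] in ("COMPLETED", "FAILED"):
        break
    time.sleep(1)

# Fetch the final result (mask image URLs and polygon coordinates)
result = requests.get(f"{BASE}/analyze/result/{task_id}").json()
print(result["data"]["results"])
```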
.env (new file, 4 lines)

UPLOAD_DIR=uploads
MOCK=false
MODEL=yolo  # segformer, yolo
PREPROCESS=sam3
.gitattributes (new file, 4 lines)

*.pt filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.onnx.data filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore (new file, 6 lines)

.idea/
docker-test/
tests/
uploads/
Wall Docker 镜像使用教程.md
Wall Docker 镜像使用教程.pdf
Dockerfile (new file, 29 lines)

FROM python:3.11

# Install system dependencies and clean the apt cache
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6 \
    && rm -rf /var/lib/apt/lists/*

# Set the working directory
WORKDIR /code

# Copy the Python dependency list
COPY requirements.txt /code/requirements.txt

# Install Python dependencies (using mirror indexes for speed)
RUN pip install --no-cache-dir --upgrade -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \
    && pip install --no-cache-dir --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu130

# Copy the application code
COPY ./app /code/app

# Remove unneeded model files to avoid wasting disk space
RUN rm -rf /code/app/core/*.onnx /code/app/core/*.data /code/app/core/*.pt

# Expose the port and start the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
README.md (new file, 56 lines)

### Wall Docker Image Usage Guide

---

> The model files are persisted on the host, so models can be updated later without recreating the container, and the configuration file is managed in one place.

1. Import the Docker image: `docker load -i wall.tar`

2. In the persistence directory, create a `.env` file with the following content:

```env
UPLOAD_DIR=uploads
MOCK=false
MODEL=segformer  # segformer or yolo; only segformer is bundled in the current image
```

3. Extract the model archive into the `core` folder:

```bash
tar -xvf core.tar
```

4. Run the Docker image:

```bash
sudo docker run -d \
  --name [docker_container_name] \
  --gpus all \
  -p [local_port]:80 \
  -v $(pwd)/uploads:/code/uploads \
  -v $(pwd)/core:/code/app/core \
  -v $(pwd)/.env:/code/.env \
  wall
```

5. To update the model later, simply overwrite the files inside `core`, adjust the `.env` configuration, and keep running the same container.

---

> If you do not need model file persistence, there is no need to extract the model files.

1. Import the Docker image: `docker load -i wall.tar`

2. Run the Docker image:

```bash
sudo docker run -d \
  --name [docker_container_name] \
  --gpus all \
  -p [local_port]:80 \
  -v $(pwd)/core:/code/app/core \
  -v $(pwd)/.env:/code/.env \
  wall
```
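As a quick sanity check after starting the container, the health route can be queried from Python; this sketch assumes port 80 was mapped to local port 8000 and that `requests` is installed:

```python
import requests  # assumed available

# /health returns {"status": "OK", "message": ..., "timestamp": ...}
resp = requests.get("http://localhost:8000/health", timeout=5)
resp.raise_for_status()
print(resp.json())
```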
app/__init__.py (new empty file)
app/core/__init__.py (new empty file)
app/core/model.py (new file, 16 lines)

from app.core.segformer.detect import Detection as SegFormer, DetectionMock as SegFormerMock
from app.core.yolo.detect import YOLOSeg


class Model:
    def __init__(self):
        from app.main import MODEL
        if MODEL == "segformer":
            print("使用 SegFormer 模型")
            self.detection = SegFormer()
        elif MODEL == "yolo":
            print("使用 YOLO 模型")
            self.detection = YOLOSeg()

    def getModel(self):
        return self.detection
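A short usage sketch of this factory, assuming MODEL is set in `.env` and a test image exists at the hypothetical path `test.jpg`; both detectors return an (image, coordinates) pair:

```python
from app.core.model import Model

detector = Model().getModel()                    # SegFormer or YOLOSeg, depending on MODEL
mask_img, coords = detector.detect("test.jpg")   # colored mask image plus defect polygons
print(len(coords), "defect regions found")
```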
app/core/preprocess.py (new file, 12 lines)

from app.core.sam3.preprocess import SAM3


class Preprocess:
    def __init__(self):
        from app.main import PREPROCESS
        if PREPROCESS == "sam3":
            print("使用 SAM3 进行预处理判断")
            self.preprocess = SAM3()

    def getPreprocess(self):
        return self.preprocess
app/core/sam3/__init__.py (new empty file)
app/core/sam3/preprocess.py (new file, 172 lines)

import os
import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sam3.model_builder import build_sam3_image_model
from sam3.train.data.collator import collate_fn_api as collate
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.data.sam3_image_dataset import (
    Datapoint, Image as SAMImage, FindQueryLoaded, InferenceMetadata
)
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
from sam3.eval.postprocessors import PostProcessImage

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"

# ===== Configuration =====
CKPT_PATH = os.path.join(os.getcwd(), "app/core/sam3", "sam3.pt")
DEVICE = "cuda:0"

BATCH_SIZE = 12  # batch size; should be configurable from the frontend
NUM_WORKERS = 12  # number of image-loading workers; optionally exposed to the frontend
CONF_TH = 0.5
RATIO_TH = 0.5  # threshold; larger values filter more images, but too large hurts close-up shots
_GLOBAL_ID = 1

PROMPTS = [
    "wall",
    "building wall",
    "building facade",
    "building exterior wall",
    "exterior building facade",
]
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


# ============


class ImgPathList(Dataset):
    def __init__(self, img_paths: list):
        """
        Initialize ImgPathList with a list of image paths.

        Args:
            img_paths (list): a list of image file paths
        """
        self.paths = img_paths  # use the provided list of paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        p = self.paths[i]  # path taken directly from the list
        img = Image.open(p).convert("RGB")  # open the image and convert it to RGB
        return p, img  # return the path together with the image


class SAM3:
    def __init__(self):
        self.dev = torch.device(DEVICE)
        self.postprocessor = PostProcessImage(
            max_dets_per_img=-1,
            iou_type="segm",
            use_original_sizes_box=True,
            use_original_sizes_mask=True,
            convert_mask_to_rle=False,
            detection_threshold=CONF_TH,
            to_cpu=False,
        )
        self.transform = ComposeAPI(
            transforms=[
                RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
                ToTensorAPI(),
                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.model = build_sam3_image_model(
            checkpoint_path=CKPT_PATH, load_from_HF=False, device=DEVICE
        ).to(DEVICE).eval()

    def preprocess(self, image_path_list):
        labels = []

        loader = DataLoader(
            ImgPathList(image_path_list),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

        with torch.inference_mode():
            for names, images in loader:
                datapoints = []
                name2qids = {}  # name -> [qid,...]
                for name, img in zip(names, images):
                    dp = self.create_empty_datapoint()
                    self.set_image(dp, img)

                    qids = [self.add_text_prompt(dp, p) for p in PROMPTS]
                    name2qids[name] = qids

                    datapoints.append(self.transform(dp))

                batch = collate(datapoints, dict_key="dummy")["dummy"]
                batch = copy_data_to_device(batch, self.dev, non_blocking=True)
                output = self.model(batch)

                processed = self.postprocessor.process_results(output, batch.find_metadatas)

                for name in names:
                    any_masks = []
                    for qid in name2qids[name]:
                        res = processed[qid]
                        m = res.get("masks", None)  # expected shape: [N,H,W]
                        if m is None:
                            any_masks.append(torch.zeros(1, 1, device=self.dev, dtype=torch.bool).squeeze())
                        else:
                            if not torch.is_tensor(m):
                                m = torch.as_tensor(m, device=self.dev)
                            any_masks.append(m.any(0))  # [H,W]

                    wall_mask = torch.stack(any_masks, 0).any(0)  # [H,W] bool
                    ratio = wall_mask.float().mean().item()
                    lab = 1 if ratio >= RATIO_TH else 0
                    labels.append(lab)
                    print(f"{name} | wall_ratio={ratio:.4f} -> {lab}")  # optional debug output

        return labels

    @staticmethod
    def add_text_prompt(datapoint, text_query):
        global _GLOBAL_ID
        assert len(datapoint.images) == 1, "please set the image first"
        w, h = datapoint.images[0].size
        datapoint.find_queries.append(
            FindQueryLoaded(
                query_text=text_query,
                image_id=0,
                object_ids_output=[],
                is_exhaustive=True,
                query_processing_order=0,
                inference_metadata=InferenceMetadata(
                    coco_image_id=_GLOBAL_ID,
                    original_image_id=_GLOBAL_ID,
                    original_category_id=1,
                    original_size=[w, h],
                    object_id=0,
                    frame_index=0,
                ),
            )
        )
        _GLOBAL_ID += 1
        return _GLOBAL_ID - 1

    @staticmethod
    def create_empty_datapoint():
        return Datapoint(find_queries=[], images=[])

    @staticmethod
    def set_image(datapoint, pil_image):
        w, h = pil_image.size
        datapoint.images = [SAMImage(data=pil_image, objects=[], size=[h, w])]  # size is [H, W]

    @staticmethod
    def collate_fn(batch):
        names, imgs = zip(*batch)
        return list(names), list(imgs)
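A minimal sketch of how this preprocessing filter is meant to be called; the image paths are illustrative and the SAM3 checkpoint must already be present at CKPT_PATH:

```python
from app.core.sam3.preprocess import SAM3

sam3 = SAM3()
# Returns one label per image: 1 = enough wall area, run detection; 0 = skip the image
labels = sam3.preprocess(["uploads/demo/inputs/0.jpg", "uploads/demo/inputs/1.jpg"])
print(labels)  # e.g. [1, 0]
```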
app/core/sam3/sam3.pt (new binary file, stored with Git LFS, not shown)
app/core/segformer/__init__.py (new empty file)
app/core/segformer/detect.py (new file, 99 lines)

import os

import cv2
import numpy as np
import albumentations as A
import onnxruntime as ort


class Detection:
    def __init__(self):
        self.CLASS_NAMES = ["background", "Hollowing", "Water seepage", "Cracking"]
        self.PALETTE = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
        self.tfm = A.Compose([
            A.Resize(512, 512),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        real_path = os.path.join(os.path.dirname(__file__), "segformer.onnx")
        self.model = ort.InferenceSession(
            real_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        print("ONNX 模型加载完成!")

    def get_contours(self, mask_np):
        res = []
        for idx in range(1, len(self.CLASS_NAMES)):
            name = self.CLASS_NAMES[idx]
            binary = (mask_np == idx).astype(np.uint8) * 255
            cnts, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for c in cnts:
                if cv2.contourArea(c) < 20:
                    continue
                eps = 0.002 * cv2.arcLength(c, True)
                poly = cv2.approxPolyDP(c, eps, True)
                if len(poly) >= 3:
                    res.append((name, poly.reshape(-1, 2).tolist()))
        return res

    def detect(self, img_input):
        # Read the image
        if isinstance(img_input, str):
            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
        else:
            img_bgr = img_input

        h, w = img_bgr.shape[:2]
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Preprocess the data
        x = self.tfm(image=img_rgb)["image"]
        # HWC -> CHW and add a batch dimension
        x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
        x = np.expand_dims(x, axis=0).astype(np.float32)  # batch x C x H x W

        # ONNX inference
        inp_name = self.model.get_inputs()[0].name
        out_name = self.model.get_outputs()[0].name
        out = self.model.run([out_name], {inp_name: x})[0]

        # Take the predicted class per pixel
        pred = out.argmax(axis=1)[0].astype(np.uint8)

        mask = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
        mask_rgb = self.PALETTE[mask]
        coords = self.get_contours(mask)
        # print(coords)
        return mask_rgb, coords


class DetectionMock:
    def __init__(self):
        self.onnx_path = "segformer.onnx"

    def detect(self, img_input):
        if isinstance(img_input, str):
            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
        else:
            img_bgr = img_input

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        return img_rgb, [
            ('Cracking', [[771, 928], [757, 935]]),
            ('Cracking', [[254, 740], [251, 942]]),
            ('Cracking', [[764, 420], [764, 424]]),
            ('Cracking', [[257, 238], [251, 245]]),
            ('Cracking', [[1436, 145], [1401, 145]])
        ]


if __name__ == "__main__":
    img_path = "test.jpg"

    detection = Detection()
    mask_res, coords_res = detection.detect(img_path)

    print("Mask Shape:", mask_res.shape)
    print("Detections:", coords_res)
    cv2.imwrite("res_onnx.png", cv2.cvtColor(mask_res, cv2.COLOR_RGB2BGR))
app/core/segformer/segformer.onnx (new binary file, stored with Git LFS, not shown)
app/core/segformer/segformer_b2.onnx.data (new binary file, stored with Git LFS, not shown)
app/core/yolo/__init__.py (new empty file)
app/core/yolo/detect.py (new file, 291 lines)

import os

import cv2
import numpy as np
import onnxruntime as ort

CLASS_NAMES = {
    0: "Hollowing",
    1: "Water seepage",
    2: "Cracking",
}

CLASS_COLORS = {
    0: (0, 0, 255),  # Hollowing -> red
    1: (0, 255, 0),  # Water seepage -> green
    2: (255, 0, 0),  # Cracking -> blue
}

IMG_SIZE = 640


class YOLOSeg:
    def __init__(self, onnx_path: str = "model.onnx", imgsz: int = IMG_SIZE):
        real_path = os.path.join(os.path.dirname(__file__), onnx_path)
        self.session = ort.InferenceSession(
            real_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            if ort.get_device() == "GPU"
            else ["CPUExecutionProvider"],
        )
        self.ndtype = (
            np.half
            if self.session.get_inputs()[0].type == "tensor(float16)"
            else np.single
        )
        self.imgsz = imgsz
        self.classes = CLASS_NAMES

    # ---------- Preprocessing: letterbox ----------
    def _preprocess(self, img_bgr):
        h0, w0 = img_bgr.shape[:2]
        new_shape = (self.imgsz, self.imgsz)

        r = min(new_shape[0] / h0, new_shape[1] / w0)
        ratio = (r, r)
        new_unpad = (int(round(w0 * r)), int(round(h0 * r)))
        pad_w = (new_shape[1] - new_unpad[0]) / 2
        pad_h = (new_shape[0] - new_unpad[1]) / 2

        if (w0, h0) != new_unpad:
            img = cv2.resize(img_bgr, new_unpad, interpolation=cv2.INTER_LINEAR)
        else:
            img = img_bgr.copy()

        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right,
            borderType=cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )

        # HWC -> CHW, BGR->RGB, /255
        img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype)
        img = img / 255.0
        if img.ndim == 3:
            img = img[None]  # (1,3,H,W)
        return img, ratio, (pad_w, pad_h)

    # ---------- Masks -> polygons ----------
    @staticmethod
    def _masks2segments(masks):
        """masks: (N,H,W) -> polygon coordinates for each instance"""
        segments = []
        for x in masks.astype("uint8"):
            cs = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0]
            if cs:
                # keep the contour with the most points
                c = np.array(cs[np.argmax([len(i) for i in cs])]).reshape(-1, 2)
            else:
                c = np.zeros((0, 2))
            segments.append(c.astype("float32"))
        return segments

    @staticmethod
    def _crop_mask(masks, boxes):
        """masks: (N,H,W), boxes: (N,4) xyxy"""
        n, h, w = masks.shape
        x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
        r = np.arange(w, dtype=x1.dtype)[None, None, :]
        c = np.arange(h, dtype=x1.dtype)[None, :, None]
        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

    @staticmethod
    def _scale_mask(masks, im0_shape, ratio_pad=None):
        """Scale masks from the network input size back to the original image size."""
        im1_shape = masks.shape[:2]
        if ratio_pad is None:
            gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])
            pad = (
                (im1_shape[1] - im0_shape[1] * gain) / 2,
                (im1_shape[0] - im0_shape[0] * gain) / 2,
            )
        else:
            pad = ratio_pad[1]

        top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))
        bottom = int(round(im1_shape[0] - pad[1] + 0.1))
        right = int(round(im1_shape[1] - pad[0] + 0.1))

        if masks.ndim < 2:
            raise ValueError("masks ndim 应该是 2 或 3")

        masks = masks[top:bottom, left:right]
        masks = cv2.resize(
            masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
        )
        if masks.ndim == 2:
            masks = masks[:, :, None]
        return masks

    def _process_mask(self, protos, masks_in, bboxes, im0_shape):
        """
        protos: (C,Hm,Wm)
        masks_in: (N,C)
        bboxes: (N,4) xyxy
        Returns: (N,H,W) bool
        """
        c, mh, mw = protos.shape
        masks = (
            np.matmul(masks_in, protos.reshape(c, -1))
            .reshape(-1, mh, mw)
            .transpose(1, 2, 0)
        )  # HWN
        masks = np.ascontiguousarray(masks)
        masks = self._scale_mask(masks, im0_shape)  # HWC
        masks = np.einsum("HWN->NHW", masks)  # NHW
        masks = self._crop_mask(masks, bboxes)
        return masks > 0.5

    @staticmethod
    def _get_cid(name):
        for k, v in CLASS_NAMES.items():
            if v == name:
                return k

    @staticmethod
    def _make_color_mask(img_bgr, masks, coords):
        """
        Build a "colored mask image":
        - the background is black
        - each instance region is colored by class (not overlaid on the original image)
        Returns: color_mask (H,W,3) BGR uint8
        """
        h, w = img_bgr.shape[:2]
        color_mask = np.zeros((h, w, 3), dtype=np.uint8)

        N = masks.shape[0]
        for i in range(N):
            m = masks[i]  # (H,W) bool
            inst = coords[i]
            cid = YOLOSeg._get_cid(inst[0])
            # print(f"name: {inst[0]}, cid: {cid}")
            color = CLASS_COLORS.get(cid, (0, 255, 255))  # unconfigured classes fall back to yellow

            # color only the mask region
            color_mask[m] = color

        return color_mask

    # ---------- Main inference entry point ----------
    def detect(self, img_input):
        conf_thres = 0.1
        iou_thres = 0.1
        """
        Input: original BGR image
        Output:
            color_mask: colored mask image (black background, colored defect regions)
            coords: one entry per instance, (class_name, polygon points)
        """
        if isinstance(img_input, str):
            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
        else:
            img_bgr = img_input

        if img_bgr is None:
            raise ValueError("img_bgr is None, 请检查图片读取是否成功")

        im0 = img_bgr.copy()
        im, ratio, (pad_w, pad_h) = self._preprocess(im0)

        # ONNX inference
        input_name = self.session.get_inputs()[0].name
        preds = self.session.run(None, {input_name: im})
        x, protos = preds[0], preds[1]  # x:(1,C,N), protos:(1,32,Hm,Wm)

        # (1,C,N) -> (N,C)
        x = np.einsum("bcn->bnc", x)[0]  # (N, C)

        # Infer the number of mask channels from protos
        nm = int(protos.shape[1])  # usually 32
        C = x.shape[1]
        nc = C - 4 - nm  # number of classes

        # Class scores live in columns [4:4+nc]
        cls_scores = x[:, 4:4 + nc]
        cls_max = np.max(cls_scores, axis=-1)
        keep = cls_max > conf_thres
        x = x[keep]
        cls_scores = cls_scores[keep]

        h0, w0 = im0.shape[:2]

        if x.size == 0:
            # Nothing detected: return an empty color mask and an empty coordinate list
            # (same (color_mask, coords) shape as the normal return path)
            empty_color_mask = np.zeros((h0, w0, 3), dtype=np.uint8)
            return empty_color_mask, []

        conf = cls_max[keep]
        cls_id = np.argmax(cls_scores, axis=-1)
        # Concatenate to [cx,cy,w,h, conf, cls_id, mask_coeffs...]
        x = np.c_[x[:, :4], conf, cls_id, x[:, -nm:]]

        # ===== NMS: OpenCV NMSBoxes expects [x, y, w, h] with the top-left corner =====
        # x[:, :4] is currently [cx, cy, w, h]; convert it to [x, y, w, h] first
        bboxes_xywh = x[:, :4].copy()
        bboxes_xywh[:, 0] = bboxes_xywh[:, 0] - bboxes_xywh[:, 2] / 2  # x = cx - w/2
        bboxes_xywh[:, 1] = bboxes_xywh[:, 1] - bboxes_xywh[:, 3] / 2  # y = cy - h/2

        indices = cv2.dnn.NMSBoxes(
            bboxes_xywh.tolist(), x[:, 4].tolist(), conf_thres, iou_thres
        )

        # Depending on the OpenCV version, indices may be [], [0,1], [[0],[1]] or np.array([...])
        if indices is None or len(indices) == 0:
            empty_color_mask = np.zeros((h0, w0, 3), dtype=np.uint8)
            return empty_color_mask, []

        # Normalize to a flat integer index array
        indices = np.array(indices).reshape(-1)
        x = x[indices]

        # cxcywh -> xyxy (using the filtered x[:, :4])
        x[:, 0:2] -= x[:, 2:4] / 2
        x[:, 2:4] += x[:, 0:2]

        # Remove the padding and scale back to the original image
        x[:, [0, 2]] -= pad_w
        x[:, [1, 3]] -= pad_h
        x[:, :4] /= min(ratio)

        # Clip to the image bounds
        x[:, [0, 2]] = x[:, [0, 2]].clip(0, w0)
        x[:, [1, 3]] = x[:, [1, 3]].clip(0, h0)

        # Decode the masks
        protos = protos[0]  # (32,Hm,Wm)
        bboxes_xyxy = x[:, :4]
        mask_coeffs = x[:, 6:]
        masks = self._process_mask(protos, mask_coeffs, bboxes_xyxy, im0.shape)

        # Masks -> polygons
        segments = self._masks2segments(masks)

        # Pack the coordinate results
        coords = []
        for (x1, y1, x2, y2, conf_i, cls_i), seg in zip(x[:, :6], segments):
            cid = int(cls_i)
            coords.append(
                (
                    self.classes.get(cid, str(cid)),
                    seg.tolist()
                )
            )

        color_mask = self._make_color_mask(im0, masks, coords)

        return color_mask, coords


if __name__ == "__main__":
    img_path = r"D:\Projects\Python\wall\app\core\yolo\test.jpg"
    IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    # ====== Load the model ======
    model = YOLOSeg()
    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
    color_mask, coords = model.detect(img)
    print(color_mask.shape)
    print(coords)
app/core/yolo/model.onnx (new binary file, stored with Git LFS, not shown)
app/main.py (new file, 35 lines)

import os
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from app.services.worker import Worker
from app.routes.health import router as health
from app.routes.analyze import router as submit_analyze
from app.routes.analyze_status import router as get_task_status
from app.routes.analyze_result import router as get_task_result

load_dotenv()
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")
MOCK = os.getenv("MOCK", "false")
MODEL = os.getenv("MODEL", "segformer")
PREPROCESS = os.getenv("PREPROCESS", "sam3")
WORKER = Worker()

os.makedirs(UPLOAD_DIR, exist_ok=True)

app = FastAPI()
app.mount("/" + UPLOAD_DIR, StaticFiles(directory=UPLOAD_DIR), name=UPLOAD_DIR)
app.include_router(health)
app.include_router(submit_analyze)
app.include_router(get_task_status)
app.include_router(get_task_result)
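For local development outside Docker, the same app object can be served directly; a minimal sketch assuming `uvicorn` is installed (fastapi[all] pulls it in):

```python
import uvicorn

if __name__ == "__main__":
    # app.main calls load_dotenv() on import, so the .env settings are honored here as well
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000)
```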
app/routes/__init__.py (new empty file)
app/routes/analyze.py (new file, 54 lines)

import os
import time
import uuid
from typing import Optional, List
from fastapi import APIRouter, UploadFile, File, Response
from app.schemas.analyze import Analyze, AnalyzeData
from app.services.model import TaskStatus, TaskStore

router = APIRouter()


@router.post("/analyze")
async def submit_analyze(response: Response, images: Optional[List[UploadFile]] = File(default=[])):
    from app.main import UPLOAD_DIR, WORKER
    if not images:
        response.status_code = 400
        return Analyze(
            success=False,
            data=AnalyzeData(
                taskId="",
                status=TaskStatus.FAILED.name,
                message="请上传图片",
                estimatedTime="",
                filesReceived=0
            )
        )

    task_id = f"task_{int(time.time() * 1000)}_{uuid.uuid4().hex[:10]}"

    task_dir = os.path.join(UPLOAD_DIR, task_id, "inputs")
    os.makedirs(task_dir, exist_ok=True)

    saved_paths = []
    if images:
        for idx, img in enumerate(images):
            ext = os.path.splitext(img.filename)[1]
            file_path = os.path.join(task_dir, f"{idx}{ext}")
            with open(file_path, "wb") as f:
                f.write(await img.read())
            saved_paths.append(file_path)

    WORKER.task_store[task_id] = TaskStore(images=saved_paths)
    WORKER.task_queue.put(task_id)

    return Analyze(
        success=True,
        data=AnalyzeData(
            taskId=task_id,
            status=TaskStatus.QUEUED.name,
            message=f"已提交 {len(saved_paths)} 张图片,正在分析",
            estimatedTime="",
            filesReceived=len(saved_paths)
        )
    )
app/routes/analyze_result.py (new file, 86 lines)

import json
import os

from fastapi import APIRouter, Response
from app.schemas.analyze_result import AnalyzeResult, AnalyzeResultData, ImageInfo, MaskInfo, ResultItem
from app.services.model import TaskStatus

router = APIRouter()


@router.get("/analyze/result/{task_id}")
async def get_task_result(task_id: str, response: Response):
    from app.main import UPLOAD_DIR, WORKER
    task = WORKER.task_store.get(task_id)

    if not task:
        response.status_code = 404
        return AnalyzeResult(
            success=False,
            data=AnalyzeResultData(
                taskId=task_id,
                status=TaskStatus.NOT_FOUND.name,
                completedAt=None,
                results=None
            )
        )

    if task.status == TaskStatus.COMPLETED.name:
        # Build the result data for a completed task
        result_items = []

        # Input and output directory paths
        input_dir = os.path.join(UPLOAD_DIR, task_id, "inputs")
        output_dir = os.path.join(UPLOAD_DIR, task_id, "outputs")

        for idx, result_data in enumerate(task.result):
            # Parse the coordinate data
            coords_data = json.loads(result_data.get("coords", "[]"))

            # Build the image info
            input_img_path = result_data.get("input_img_path", "")
            output_img_path = result_data.get("output_img_path", "")

            # Build the URL paths
            input_filename = os.path.basename(input_img_path)
            output_filename = os.path.basename(output_img_path)

            image_info = ImageInfo(
                origin=f"/uploads/{task_id}/inputs/{input_filename}",
                image=f"/uploads/{task_id}/outputs/{output_filename}" if output_img_path != "" else "",
            )

            # Build the mask info
            masks = [
                MaskInfo(name=mask["name"], coords=mask["coords"])
                for mask in coords_data
            ]

            result_item = ResultItem(
                id=str(idx),
                images=image_info,
                masks=masks
            )
            result_items.append(result_item)

        response.status_code = 200
        return AnalyzeResult(
            success=True,
            data=AnalyzeResultData(
                taskId=task_id,
                status=task.status,
                completedAt=task.completedAt.isoformat() if task.completedAt else "",
                results=result_items
            )
        )
    else:
        # Other states (processing, failed, etc.)
        return AnalyzeResult(
            success=True,
            data=AnalyzeResultData(
                taskId=task_id,
                status=task.status,
                completedAt=None,
                results=None
            )
        )
app/routes/analyze_status.py (new file, 34 lines)

from fastapi import APIRouter, Response
from app.schemas.analyze_status import AnalyzeStatus, AnalyzeStatusData
from app.services.model import TaskStatus

router = APIRouter()


@router.get("/analyze/status/{task_id}")
async def get_task_status(task_id: str, response: Response):
    from app.main import WORKER
    task = WORKER.task_store.get(task_id)

    if task:
        response.status_code = 200
        return AnalyzeStatus(
            success=True,
            data=AnalyzeStatusData(
                taskId=task_id,
                status=task.status,
                progress=task.progress,
                message=task.message
            )
        )
    else:
        response.status_code = 404
        return AnalyzeStatus(
            success=False,
            data=AnalyzeStatusData(
                taskId=task_id,
                status=TaskStatus.NOT_FOUND.name,
                progress=0,
                message=TaskStatus.NOT_FOUND.value
            )
        )
app/routes/health.py (new file, 15 lines)

from datetime import datetime
from fastapi import APIRouter
from app.schemas.health import Health
from app.services.model import TaskStatus

router = APIRouter()


@router.get("/health", response_model=Health)
async def health():
    return Health(
        status=TaskStatus.OK.name,
        message="算法服务器正常运行",
        timestamp=str(datetime.now())
    )
app/schemas/__init__.py (new empty file)
app/schemas/analyze.py (new file, 18 lines)

from typing import Optional

from pydantic import BaseModel

""" Response structure for submitting a task """


class AnalyzeData(BaseModel):
    taskId: Optional[str] = None
    status: str
    message: str
    estimatedTime: Optional[str] = None
    filesReceived: Optional[int] = None


class Analyze(BaseModel):
    success: bool
    data: AnalyzeData
app/schemas/analyze_result.py (new file, 32 lines)

from typing import List, Optional
from pydantic import BaseModel

""" Response structure for fetching task results """


class ImageInfo(BaseModel):
    origin: str
    image: str


class MaskInfo(BaseModel):
    name: str
    coords: List[List[int]]


class ResultItem(BaseModel):
    id: str
    images: ImageInfo
    masks: List[MaskInfo]


class AnalyzeResultData(BaseModel):
    taskId: str
    status: str
    completedAt: Optional[str] = None
    results: Optional[List[ResultItem]] = None


class AnalyzeResult(BaseModel):
    success: bool
    data: AnalyzeResultData
app/schemas/analyze_status.py (new file, 15 lines)

from pydantic import BaseModel

""" Response structure for querying task status """


class AnalyzeStatusData(BaseModel):
    taskId: str
    status: str
    progress: int
    message: str


class AnalyzeStatus(BaseModel):
    success: bool
    data: AnalyzeStatusData
app/schemas/health.py (new file, 7 lines)

from pydantic import BaseModel


class Health(BaseModel):
    status: str
    message: str
    timestamp: str
app/services/__init__.py (new empty file)
app/services/model.py (new file, 66 lines)

from datetime import datetime
from enum import Enum
from typing import List, Dict, Optional


class TaskStatus(Enum):
    OK = "正常运行"
    QUEUED = "已入队"
    PROCESSING = "处理中"
    COMPLETED = "处理完成"
    FAILED = "处理错误"
    NOT_FOUND = "任务不存在"


class TaskStore:
    def __init__(self, images: List[str]):
        self._status: str = TaskStatus.QUEUED.name
        self._progress: int = 0
        self._images: List[str] = images
        self._result: List[Dict[str, str]] = []
        self._message: str = ""
        self._completedAt: Optional[datetime] = None

    @property
    def images(self):
        return self._images

    @property
    def status(self):
        return self._status

    @status.setter
    def status(self, status: TaskStatus):
        self._status = status

    @property
    def progress(self):
        return self._progress

    @progress.setter
    def progress(self, progress: int):
        self._progress = progress

    @property
    def result(self):
        return self._result

    @result.setter
    def result(self, result: str):
        self._result = result

    @property
    def message(self):
        return self._message

    @message.setter
    def message(self, message: str):
        self._message = message

    @property
    def completedAt(self):
        return self._completedAt

    @completedAt.setter
    def completedAt(self, completedAt: datetime):
        self._completedAt = completedAt
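A brief sketch of the task lifecycle these classes model; the task id, path, and progress values are illustrative only:

```python
from datetime import datetime
from app.services.model import TaskStatus, TaskStore

task = TaskStore(images=["uploads/task_demo/inputs/0.jpg"])  # starts as QUEUED, progress 0
task.status = TaskStatus.PROCESSING.name
task.progress = 50
task.result.append({"input_img_path": "uploads/task_demo/inputs/0.jpg",
                    "output_img_path": "uploads/task_demo/outputs/0.jpg",
                    "coords": "[]"})
task.status = TaskStatus.COMPLETED.name
task.completedAt = datetime.now()
print(task.status, task.progress, task.completedAt.isoformat())
```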
app/services/worker.py (new file, 80 lines)

import json
import os
import cv2
import threading
from datetime import datetime
from queue import Queue
from typing import Dict

from app.core.model import Model
from app.core.preprocess import Preprocess
from app.services.model import TaskStatus, TaskStore


class Worker:
    def __init__(self):
        self.detection = Model().getModel()
        self.preprocess = Preprocess().getPreprocess()
        self.task_queue = Queue()
        self.task_store: Dict[str, TaskStore] = {}

        threading.Thread(target=self.worker, daemon=True).start()

    def worker(self):
        from app.main import UPLOAD_DIR
        while True:
            task_id = self.task_queue.get()
            if task_id is None:
                break

            task = self.task_store.get(task_id)
            if not task:
                continue

            try:
                task.status = TaskStatus.PROCESSING.name
                task.progress = 0
                print(f"开始处理任务 {task_id}...")

                # Create the output directory
                output_dir = os.path.join(UPLOAD_DIR, task_id, "outputs")
                os.makedirs(output_dir, exist_ok=True)

                # Get the label list for the images
                image_labels = self.preprocess.preprocess(task.images)  # returns a list of 0/1: 0 = skip, 1 = run detection

                for idx, (input_img_path, label) in enumerate(zip(task.images, image_labels)):
                    print(f"处理任务 {task_id}, 处理图片 {input_img_path}...")

                    if label == 0:
                        # Label 0: skip model detection, leave the output path and coordinates empty
                        task.result.append(
                            {"input_img_path": input_img_path, "output_img_path": "", "coords": "[]"}
                        )
                    else:
                        # Run model detection
                        img_res, coords_res = self.detection.detect(input_img_path)

                        coords_res = [{"name": name, "coords": coords} for name, coords in coords_res]
                        coords_json = json.dumps(coords_res, ensure_ascii=False)

                        out_img_path = os.path.join(output_dir, f"{idx}.jpg")
                        cv2.imwrite(out_img_path, img_res)

                        task.result.append(
                            {"input_img_path": input_img_path, "output_img_path": out_img_path, "coords": coords_json}
                        )

                    task.progress = int((idx + 1) / len(task.images) * 100)

                task.status = TaskStatus.COMPLETED.name
                task.completedAt = datetime.now()
                task.message = "处理完成"

                print(f"任务 {task_id} 处理完成")

            except Exception as e:
                task.status = TaskStatus.FAILED.name
                task.message = str(e)
            finally:
                self.task_queue.task_done()
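A minimal sketch of how the routes drive this worker, mirroring what app/routes/analyze.py does; the task id and image path here are made up for illustration:

```python
from app.services.model import TaskStore
from app.services.worker import Worker

worker = Worker()  # starts the background thread and loads the detection and preprocess models
task_id = "task_demo_0001"  # illustrative; real ids combine a timestamp and a UUID fragment
worker.task_store[task_id] = TaskStore(images=["uploads/task_demo_0001/inputs/0.jpg"])
worker.task_queue.put(task_id)  # the daemon thread picks the task up and processes it
```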
requirements.txt (new file, 15 lines)

fastapi[all]==0.122.0
albumentations==2.0.8
albucore==0.0.24
numpy==2.2.6
onnxruntime-gpu==1.23.2
opencv-python==4.12.0.88
PyYAML>=6.0
scipy>=1.10.0
pydantic>=2.10.0
typing-extensions>=4.12.0
python-dotenv~=1.2.1
opencv-python
matplotlib
pandas
tqdm==4.67.1