feat(api): add image analysis functionality and related route endpoints

- Add analyze, analyze_result, analyze_status, and health routes
- Implement image upload and task submission
- Add endpoints for querying task status and fetching results
- Integrate the segformer and yolo models for image detection
- Implement SAM3 preprocessing to decide whether an image needs detection
- Add model selection configuration supporting segformer and yolo
- Implement task queue management and asynchronous processing
- Add a Dockerfile for containerized deployment
- Configure environment variables and gitignore rules
- Create data models defining the API response structures
commit 6a2e046884
4 .env Normal file
@@ -0,0 +1,4 @@
UPLOAD_DIR=uploads
MOCK=false
MODEL=yolo #segformer, yolo
PREPROCESS=sam3
4 .gitattributes vendored Normal file
@@ -0,0 +1,4 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.onnx.data filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
6 .gitignore vendored Normal file
@@ -0,0 +1,6 @@
.idea/
docker-test/
tests/
uploads/
Wall Docker 镜像使用教程.md
Wall Docker 镜像使用教程.pdf
29 Dockerfile Normal file
@@ -0,0 +1,29 @@
FROM python:3.11

# Install system dependencies and clean the apt cache
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6 \
    && rm -rf /var/lib/apt/lists/*

# Set the working directory
WORKDIR /code

# Copy the Python dependency list
COPY requirements.txt /code/requirements.txt

# Install Python dependencies (mirror index for faster downloads)
RUN pip install --no-cache-dir --upgrade -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \
    && pip install --no-cache-dir --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu130

# Copy the application code
COPY ./app /code/app

# Remove unneeded model files to save disk space
RUN rm -rf /code/app/core/*.onnx /code/app/core/*.data /code/app/core/*.pt

# Expose the port and start the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
56 README.md Normal file
@@ -0,0 +1,56 @@
### Wall Docker image usage guide

---

> Model files are persisted on the host so that models can be updated later without recreating the container, and configuration is kept in a single config file.

1. Load the docker image: `docker load -i wall.tar`

2. Go to the persistence directory, create a `.env` file, and add the following:

```env
UPLOAD_DIR=uploads
MOCK=false
MODEL=segformer #segformer, yolo; only segformer is bundled in the current image
```

3. Extract the model archive into the `core` folder:

```bash
tar -xvf core.tar
```

4. Run the docker image:

```bash
sudo docker run -d \
  --name [docker_container_name] \
  --gpus all \
  -p [local_port]:80 \
  -v $(pwd)/uploads:/code/uploads \
  -v $(pwd)/core:/code/app/core \
  -v $(pwd)/.env:/code/.env \
  wall
```

5. To update the model later, just overwrite the files inside `core`, adjust the `.env` configuration, and keep the container running as before.

---

> If you do not need model file persistence, there is no need to extract the model archive.

1. Load the docker image: `docker load -i wall.tar`

2. Run the docker image:

```bash
sudo docker run -d \
  --name [docker_container_name] \
  --gpus all \
  -p [local_port]:80 \
  -v $(pwd)/core:/code/app/core \
  -v $(pwd)/.env:/code/.env \
  wall
```
0 app/__init__.py Normal file
0 app/core/__init__.py Normal file
16 app/core/model.py Normal file
@@ -0,0 +1,16 @@
from app.core.segformer.detect import Detection as SegFormer, DetectionMock as SegFormerMock
from app.core.yolo.detect import YOLOSeg


class Model:
    def __init__(self):
        from app.main import MODEL
        if MODEL == "segformer":
            print("Using the SegFormer model")
            self.detection = SegFormer()
        elif MODEL == "yolo":
            print("Using the YOLO model")
            self.detection = YOLOSeg()

    def getModel(self):
        return self.detection
12 app/core/preprocess.py Normal file
@@ -0,0 +1,12 @@
from app.core.sam3.preprocess import SAM3


class Preprocess:
    def __init__(self):
        from app.main import PREPROCESS
        if PREPROCESS == "sam3":
            print("Using SAM3 for preprocessing checks")
            self.preprocess = SAM3()

    def getPreprocess(self):
        return self.preprocess
0 app/core/sam3/__init__.py Normal file
172 app/core/sam3/preprocess.py Normal file
@@ -0,0 +1,172 @@
import os
import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sam3.model_builder import build_sam3_image_model
from sam3.train.data.collator import collate_fn_api as collate
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.data.sam3_image_dataset import (
    Datapoint, Image as SAMImage, FindQueryLoaded, InferenceMetadata
)
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
from sam3.eval.postprocessors import PostProcessImage

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"

# ===== Configuration =====
CKPT_PATH = os.path.join(os.getcwd(), "app/core/sam3", "sam3.pt")
DEVICE = "cuda:0"

BATCH_SIZE = 12   # batch size; should eventually be configurable from the frontend
NUM_WORKERS = 12  # number of image-loading workers; may also be exposed to the frontend
CONF_TH = 0.5
RATIO_TH = 0.5    # threshold: larger values filter out more images, but too large hurts close-up shots
_GLOBAL_ID = 1

PROMPTS = [
    "wall",
    "building wall",
    "building facade",
    "building exterior wall",
    "exterior building facade",
]
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


# ============


class ImgPathList(Dataset):
    def __init__(self, img_paths: list):
        """
        Initialize ImgPathList with a list of image paths.

        Args:
            img_paths (list): a list containing image paths
        """
        self.paths = img_paths  # use the provided list of paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        p = self.paths[i]                    # use the path from the list directly
        img = Image.open(p).convert("RGB")   # open the image and convert it to RGB
        return p, img                        # return the image path and the image itself


class SAM3:
    def __init__(self):
        self.dev = torch.device(DEVICE)
        self.postprocessor = PostProcessImage(
            max_dets_per_img=-1,
            iou_type="segm",
            use_original_sizes_box=True,
            use_original_sizes_mask=True,
            convert_mask_to_rle=False,
            detection_threshold=CONF_TH,
            to_cpu=False,
        )
        self.transform = ComposeAPI(
            transforms=[
                RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
                ToTensorAPI(),
                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.model = build_sam3_image_model(
            checkpoint_path=CKPT_PATH, load_from_HF=False, device=DEVICE
        ).to(DEVICE).eval()

    def preprocess(self, image_path_list):
        labels = []

        loader = DataLoader(
            ImgPathList(image_path_list),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

        with torch.inference_mode():
            for names, images in loader:
                datapoints = []
                name2qids = {}  # name -> [qid, ...]
                for name, img in zip(names, images):
                    dp = self.create_empty_datapoint()
                    self.set_image(dp, img)

                    qids = [self.add_text_prompt(dp, p) for p in PROMPTS]
                    name2qids[name] = qids

                    datapoints.append(self.transform(dp))

                batch = collate(datapoints, dict_key="dummy")["dummy"]
                batch = copy_data_to_device(batch, self.dev, non_blocking=True)
                output = self.model(batch)

                processed = self.postprocessor.process_results(output, batch.find_metadatas)

                for name in names:
                    any_masks = []
                    for qid in name2qids[name]:
                        res = processed[qid]
                        m = res.get("masks", None)  # expected shape: [N, H, W]
                        if m is None:
                            any_masks.append(torch.zeros(1, 1, device=self.dev, dtype=torch.bool).squeeze())
                        else:
                            if not torch.is_tensor(m):
                                m = torch.as_tensor(m, device=self.dev)
                            any_masks.append(m.any(0))  # [H, W]

                    wall_mask = torch.stack(any_masks, 0).any(0)  # [H, W] bool
                    ratio = wall_mask.float().mean().item()
                    lab = 1 if ratio >= RATIO_TH else 0
                    labels.append(lab)
                    print(f"{name} | wall_ratio={ratio:.4f} -> {lab}")  # optional debug output

        return labels

    @staticmethod
    def add_text_prompt(datapoint, text_query):
        global _GLOBAL_ID
        assert len(datapoint.images) == 1, "please set the image first"
        w, h = datapoint.images[0].size
        datapoint.find_queries.append(
            FindQueryLoaded(
                query_text=text_query,
                image_id=0,
                object_ids_output=[],
                is_exhaustive=True,
                query_processing_order=0,
                inference_metadata=InferenceMetadata(
                    coco_image_id=_GLOBAL_ID,
                    original_image_id=_GLOBAL_ID,
                    original_category_id=1,
                    original_size=[w, h],
                    object_id=0,
                    frame_index=0,
                ),
            )
        )
        _GLOBAL_ID += 1
        return _GLOBAL_ID - 1

    @staticmethod
    def create_empty_datapoint():
        return Datapoint(find_queries=[], images=[])

    @staticmethod
    def set_image(datapoint, pil_image):
        w, h = pil_image.size
        datapoint.images = [SAMImage(data=pil_image, objects=[], size=[h, w])]  # size uses [H, W]

    @staticmethod
    def collate_fn(batch):
        names, imgs = zip(*batch)
        return list(names), list(imgs)
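A minimal usage sketch of this preprocessing gate (not part of the commit): the image paths below are hypothetical, and a GPU with the `sam3` package and the bundled `sam3.pt` checkpoint is assumed. A label of 1 means "enough wall is visible, run defect detection"; 0 means "skip".

```python
# Hypothetical paths; labels follow the wall-pixel ratio vs. RATIO_TH rule above.
from app.core.sam3.preprocess import SAM3

gate = SAM3()
labels = gate.preprocess(["uploads/task_x/inputs/0.jpg", "uploads/task_x/inputs/1.jpg"])
print(labels)  # e.g. [1, 0]
```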
BIN app/core/sam3/sam3.pt (Stored with Git LFS) Normal file
Binary file not shown.
0 app/core/segformer/__init__.py Normal file
99 app/core/segformer/detect.py Normal file
@@ -0,0 +1,99 @@
import os

import cv2
import numpy as np
import albumentations as A
import onnxruntime as ort


class Detection:
    def __init__(self):
        self.CLASS_NAMES = ["background", "Hollowing", "Water seepage", "Cracking"]
        self.PALETTE = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
        self.tfm = A.Compose([
            A.Resize(512, 512),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        real_path = os.path.join(os.path.dirname(__file__), "segformer.onnx")
        self.model = ort.InferenceSession(
            real_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        print("ONNX model loaded!")

    def get_contours(self, mask_np):
        res = []
        for idx in range(1, len(self.CLASS_NAMES)):
            name = self.CLASS_NAMES[idx]
            binary = (mask_np == idx).astype(np.uint8) * 255
            cnts, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for c in cnts:
                if cv2.contourArea(c) < 20:
                    continue
                eps = 0.002 * cv2.arcLength(c, True)
                poly = cv2.approxPolyDP(c, eps, True)
                if len(poly) >= 3:
                    res.append((name, poly.reshape(-1, 2).tolist()))
        return res

    def detect(self, img_input):
        # Read the image
        if isinstance(img_input, str):
            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
        else:
            img_bgr = img_input

        h, w = img_bgr.shape[:2]
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Preprocess the input
        x = self.tfm(image=img_rgb)["image"]
        # HWC -> CHW and add a batch dimension
        x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
        x = np.expand_dims(x, axis=0).astype(np.float32)  # batch x C x H x W

        # ONNX inference
        inp_name = self.model.get_inputs()[0].name
        out_name = self.model.get_outputs()[0].name
        out = self.model.run([out_name], {inp_name: x})[0]

        # Take the per-pixel class prediction
        pred = out.argmax(axis=1)[0].astype(np.uint8)

        mask = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
        mask_rgb = self.PALETTE[mask]
        coords = self.get_contours(mask)
        # print(coords)
        return mask_rgb, coords


class DetectionMock:
    def __init__(self):
        self.onnx_path = "segformer.onnx"

    def detect(self, img_input):
        if isinstance(img_input, str):
            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
        else:
            img_bgr = img_input

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        return img_rgb, [
            ('Cracking', [[771, 928], [757, 935]]),
            ('Cracking', [[254, 740], [251, 942]]),
            ('Cracking', [[764, 420], [764, 424]]),
            ('Cracking', [[257, 238], [251, 245]]),
            ('Cracking', [[1436, 145], [1401, 145]])
        ]


if __name__ == "__main__":
    img_path = "test.jpg"

    detection = Detection()
    mask_res, coords_res = detection.detect(img_path)

    print("Mask Shape:", mask_res.shape)
    print("Detections:", coords_res)
    cv2.imwrite("res_onnx.png", cv2.cvtColor(mask_res, cv2.COLOR_RGB2BGR))
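A short visualization sketch for the (mask, polygons) output above; it is not part of the commit, and the file names `test.jpg`/`overlay.png` are hypothetical.

```python
# Draw the returned class polygons on top of the original image.
import cv2
import numpy as np
from app.core.segformer.detect import Detection

img = cv2.imdecode(np.fromfile("test.jpg", dtype=np.uint8), cv2.IMREAD_COLOR)
mask_rgb, coords = Detection().detect(img)

for name, points in coords:
    pts = np.array(points, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 0, 255), thickness=2)
    org = (int(pts[0][0][0]), int(pts[0][0][1]))
    cv2.putText(img, name, org, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

cv2.imwrite("overlay.png", img)
```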
BIN app/core/segformer/segformer.onnx (Stored with Git LFS) Normal file
Binary file not shown.
BIN app/core/segformer/segformer_b2.onnx.data (Stored with Git LFS) Normal file
Binary file not shown.
0 app/core/yolo/__init__.py Normal file
291 app/core/yolo/detect.py Normal file
@@ -0,0 +1,291 @@
import os

import cv2
import numpy as np
import onnxruntime as ort

CLASS_NAMES = {
    0: "Hollowing",
    1: "Water seepage",
    2: "Cracking",
}

CLASS_COLORS = {
    0: (0, 0, 255),   # Hollowing -> red
    1: (0, 255, 0),   # Water seepage -> green
    2: (255, 0, 0),   # Cracking -> blue
}

IMG_SIZE = 640


class YOLOSeg:
    def __init__(self, onnx_path: str = "model.onnx", imgsz: int = IMG_SIZE):
        real_path = os.path.join(os.path.dirname(__file__), onnx_path)
        self.session = ort.InferenceSession(
            real_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            if ort.get_device() == "GPU"
            else ["CPUExecutionProvider"],
        )
        self.ndtype = (
            np.half
            if self.session.get_inputs()[0].type == "tensor(float16)"
            else np.single
        )
        self.imgsz = imgsz
        self.classes = CLASS_NAMES

    # ---------- Preprocessing: letterbox ----------
    def _preprocess(self, img_bgr):
        h0, w0 = img_bgr.shape[:2]
        new_shape = (self.imgsz, self.imgsz)

        r = min(new_shape[0] / h0, new_shape[1] / w0)
        ratio = (r, r)
        new_unpad = (int(round(w0 * r)), int(round(h0 * r)))
        pad_w = (new_shape[1] - new_unpad[0]) / 2
        pad_h = (new_shape[0] - new_unpad[1]) / 2

        if (w0, h0) != new_unpad:
            img = cv2.resize(img_bgr, new_unpad, interpolation=cv2.INTER_LINEAR)
        else:
            img = img_bgr.copy()

        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right,
            borderType=cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )

        # HWC -> CHW, BGR -> RGB, /255
        img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype)
        img = img / 255.0
        if img.ndim == 3:
            img = img[None]  # (1,3,H,W)
        return img, ratio, (pad_w, pad_h)

    # ---------- mask -> polygons ----------
    @staticmethod
    def _masks2segments(masks):
        """masks: (N,H,W) -> polygon coordinates for each instance"""
        segments = []
        for x in masks.astype("uint8"):
            cs = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0]
            if cs:
                # keep the contour with the most points
                c = np.array(cs[np.argmax([len(i) for i in cs])]).reshape(-1, 2)
            else:
                c = np.zeros((0, 2))
            segments.append(c.astype("float32"))
        return segments

    @staticmethod
    def _crop_mask(masks, boxes):
        """masks: (N,H,W), boxes: (N,4) xyxy"""
        n, h, w = masks.shape
        x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
        r = np.arange(w, dtype=x1.dtype)[None, None, :]
        c = np.arange(h, dtype=x1.dtype)[None, :, None]
        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

    @staticmethod
    def _scale_mask(masks, im0_shape, ratio_pad=None):
        """Rescale masks from the feature-map size back to the original image size"""
        im1_shape = masks.shape[:2]
        if ratio_pad is None:
            gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])
            pad = (
                (im1_shape[1] - im0_shape[1] * gain) / 2,
                (im1_shape[0] - im0_shape[0] * gain) / 2,
            )
        else:
            pad = ratio_pad[1]

        top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))
        bottom = int(round(im1_shape[0] - pad[1] + 0.1))
        right = int(round(im1_shape[1] - pad[0] + 0.1))

        if masks.ndim < 2:
            raise ValueError("masks ndim should be 2 or 3")

        masks = masks[top:bottom, left:right]
        masks = cv2.resize(
            masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
        )
        if masks.ndim == 2:
            masks = masks[:, :, None]
        return masks

    def _process_mask(self, protos, masks_in, bboxes, im0_shape):
        """
        protos: (C,Hm,Wm)
        masks_in: (N,C)
        bboxes: (N,4) xyxy
        returns: (N,H,W) bool
        """
        c, mh, mw = protos.shape
        masks = (
            np.matmul(masks_in, protos.reshape(c, -1))
            .reshape(-1, mh, mw)
            .transpose(1, 2, 0)
        )  # HWN
        masks = np.ascontiguousarray(masks)
        masks = self._scale_mask(masks, im0_shape)  # HWC
        masks = np.einsum("HWN->NHW", masks)  # NHW
        masks = self._crop_mask(masks, bboxes)
        return masks > 0.5

    @staticmethod
    def _get_cid(name):
        for k, v in CLASS_NAMES.items():
            if v == name:
                return k

    @staticmethod
    def _make_color_mask(img_bgr, masks, coords):
        """
        Build a "colored mask image":
        - black background
        - each instance region colored by class (not blended onto the original image)
        returns: color_mask (H,W,3) BGR uint8
        """
        h, w = img_bgr.shape[:2]
        color_mask = np.zeros((h, w, 3), dtype=np.uint8)

        N = masks.shape[0]
        for i in range(N):
            m = masks[i]  # (H,W) bool
            inst = coords[i]
            cid = YOLOSeg._get_cid(inst[0])
            # print(f"name: {inst[0]}, cid: {cid}")
            color = CLASS_COLORS.get(cid, (0, 255, 255))  # unconfigured classes fall back to yellow

            # color only the mask region
            color_mask[m] = color

        return color_mask

    # ---------- Main inference entry point ----------
    def detect(self, img_input):
        conf_thres = 0.1
        iou_thres = 0.1
        """
        Input: original BGR image
        Output:
            masks: (N,H,W) bool masks
            coords: list; each instance holds class_name, confidence, points (polygon)
            color_mask: colored mask image (black background with colored defect regions)
        """
        if isinstance(img_input, str):
            img_bgr = cv2.imdecode(np.fromfile(img_input, dtype=np.uint8), cv2.IMREAD_COLOR)
        else:
            img_bgr = img_input

        if img_bgr is None:
            raise ValueError("img_bgr is None; check that the image was read successfully")

        im0 = img_bgr.copy()
        im, ratio, (pad_w, pad_h) = self._preprocess(im0)

        # ONNX inference
        input_name = self.session.get_inputs()[0].name
        preds = self.session.run(None, {input_name: im})
        x, protos = preds[0], preds[1]  # x:(1,C,N), protos:(1,32,Hm,Wm)

        # (1,C,N) -> (N,C)
        x = np.einsum("bcn->bnc", x)[0]  # (N, C)

        # Infer the number of mask channels from protos
        nm = int(protos.shape[1])  # usually 32
        C = x.shape[1]
        nc = C - 4 - nm  # number of classes

        # Class score range [4:4+nc]
        cls_scores = x[:, 4:4 + nc]
        cls_max = np.max(cls_scores, axis=-1)
        keep = cls_max > conf_thres
        x = x[keep]
        cls_scores = cls_scores[keep]

        h0, w0 = im0.shape[:2]

        if x.size == 0:
            # Nothing detected: return empty masks, empty coords, empty color mask
            empty_masks = np.zeros((0, h0, w0), dtype=bool)
            empty_color_mask = np.zeros((h0, w0, 3), dtype=np.uint8)
            return empty_masks, [], empty_color_mask

        conf = cls_max[keep]
        cls_id = np.argmax(cls_scores, axis=-1)
        # Assemble [cx,cy,w,h, conf, cls_id, mask_coeffs...]
        x = np.c_[x[:, :4], conf, cls_id, x[:, -nm:]]

        # ===== NMS: OpenCV NMSBoxes expects [x, y, w, h] with top-left coordinates =====
        # x[:, :4] currently holds [cx, cy, w, h]; convert it to [x, y, w, h] first
        bboxes_xywh = x[:, :4].copy()
        bboxes_xywh[:, 0] = bboxes_xywh[:, 0] - bboxes_xywh[:, 2] / 2  # x = cx - w/2
        bboxes_xywh[:, 1] = bboxes_xywh[:, 1] - bboxes_xywh[:, 3] / 2  # y = cy - h/2

        indices = cv2.dnn.NMSBoxes(
            bboxes_xywh.tolist(), x[:, 4].tolist(), conf_thres, iou_thres
        )

        # Depending on the OpenCV version, indices may be [], [0,1], [[0],[1]], or np.array([...])
        if indices is None or len(indices) == 0:
            empty_masks = np.zeros((0, h0, w0), dtype=bool)
            empty_color_mask = np.zeros((h0, w0, 3), dtype=np.uint8)
            return empty_masks, [], empty_color_mask

        # Normalize to a flat integer index array
        indices = np.array(indices).reshape(-1)
        x = x[indices]

        # cxcywh -> xyxy (using the filtered x[:, :4])
        x[:, 0:2] -= x[:, 2:4] / 2
        x[:, 2:4] += x[:, 0:2]

        # Remove padding and rescale to the original image
        x[:, [0, 2]] -= pad_w
        x[:, [1, 3]] -= pad_h
        x[:, :4] /= min(ratio)

        # Clip to the image bounds
        x[:, [0, 2]] = x[:, [0, 2]].clip(0, w0)
        x[:, [1, 3]] = x[:, [1, 3]].clip(0, h0)

        # Decode masks
        protos = protos[0]  # (32,Hm,Wm)
        bboxes_xyxy = x[:, :4]
        mask_coeffs = x[:, 6:]
        masks = self._process_mask(protos, mask_coeffs, bboxes_xyxy, im0.shape)

        # Masks -> polygons
        segments = self._masks2segments(masks)

        # Pack up the coordinate results
        coords = []
        for (x1, y1, x2, y2, conf_i, cls_i), seg in zip(x[:, :6], segments):
            cid = int(cls_i)
            coords.append(
                (
                    self.classes.get(cid, str(cid)),
                    seg.tolist()
                )
            )

        color_mask = self._make_color_mask(im0, masks, coords)

        return color_mask, coords


if __name__ == "__main__":
    img_path = r"D:\Projects\Python\wall\app\core\yolo\test.jpg"
    IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    # ====== Load the model ======
    model = YOLOSeg()
    img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR)
    color_mask, coords = model.detect(img)
    print(color_mask.shape)
    print(coords)
BIN app/core/yolo/model.onnx (Stored with Git LFS) Normal file
Binary file not shown.
35 app/main.py Normal file
@@ -0,0 +1,35 @@
import os
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from app.services.worker import Worker
from app.routes.health import router as health
from app.routes.analyze import router as submit_analyze
from app.routes.analyze_status import router as get_task_status
from app.routes.analyze_result import router as get_task_result

load_dotenv()
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")
MOCK = os.getenv("MOCK", "false")
MODEL = os.getenv("MODEL", "segformer")
PREPROCESS = os.getenv("PREPROCESS", "sam3")
WORKER = Worker()

os.makedirs(UPLOAD_DIR, exist_ok=True)

app = FastAPI()
app.mount("/" + UPLOAD_DIR, StaticFiles(directory=UPLOAD_DIR), name=UPLOAD_DIR)
app.include_router(health)
app.include_router(submit_analyze)
app.include_router(get_task_status)
app.include_router(get_task_result)
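For local development outside the container, a minimal entry-point sketch (uvicorn ships with fastapi[all]; the port mirrors the Dockerfile CMD). The file name run_dev.py is hypothetical and not part of this commit.

```python
# Hypothetical run_dev.py at the repository root.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=80, reload=True)
```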
0 app/routes/__init__.py Normal file
54 app/routes/analyze.py Normal file
@@ -0,0 +1,54 @@
import os
import time
import uuid
from typing import Optional, List
from fastapi import APIRouter, UploadFile, File, Response
from app.schemas.analyze import Analyze, AnalyzeData
from app.services.model import TaskStatus, TaskStore

router = APIRouter()


@router.post("/analyze")
async def submit_analyze(response: Response, images: Optional[List[UploadFile]] = File(default=[])):
    from app.main import UPLOAD_DIR, WORKER
    if not images:
        response.status_code = 400
        return Analyze(
            success=False,
            data=AnalyzeData(
                taskId="",
                status=TaskStatus.FAILED.name,
                message="Please upload at least one image",
                estimatedTime="",
                filesReceived=0
            )
        )

    task_id = f"task_{int(time.time() * 1000)}_{uuid.uuid4().hex[:10]}"

    task_dir = os.path.join(UPLOAD_DIR, task_id, "inputs")
    os.makedirs(task_dir, exist_ok=True)

    saved_paths = []
    if images:
        for idx, img in enumerate(images):
            ext = os.path.splitext(img.filename)[1]
            file_path = os.path.join(task_dir, f"{idx}{ext}")
            with open(file_path, "wb") as f:
                f.write(await img.read())
            saved_paths.append(file_path)

    WORKER.task_store[task_id] = TaskStore(images=saved_paths)
    WORKER.task_queue.put(task_id)

    return Analyze(
        success=True,
        data=AnalyzeData(
            taskId=task_id,
            status=TaskStatus.QUEUED.name,
            message=f"Received {len(saved_paths)} image(s); analysis has started",
            estimatedTime="",
            filesReceived=len(saved_paths)
        )
    )
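A client-side sketch for this endpoint (not part of the commit): it assumes the service is reachable at http://localhost:80 and that httpx, pulled in by fastapi[all], is available; the file names are hypothetical.

```python
# Submit two images to POST /analyze and keep the returned taskId for polling.
import httpx

files = [
    ("images", ("wall1.jpg", open("wall1.jpg", "rb"), "image/jpeg")),
    ("images", ("wall2.jpg", open("wall2.jpg", "rb"), "image/jpeg")),
]
resp = httpx.post("http://localhost:80/analyze", files=files)
task_id = resp.json()["data"]["taskId"]
print(task_id)
```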
86 app/routes/analyze_result.py Normal file
@@ -0,0 +1,86 @@
import json
import os

from fastapi import APIRouter, Response
from app.schemas.analyze_result import AnalyzeResult, AnalyzeResultData, ImageInfo, MaskInfo, ResultItem
from app.services.model import TaskStatus

router = APIRouter()


@router.get("/analyze/result/{task_id}")
async def get_task_result(task_id: str, response: Response):
    from app.main import UPLOAD_DIR, WORKER
    task = WORKER.task_store.get(task_id)

    if not task:
        response.status_code = 404
        return AnalyzeResult(
            success=False,
            data=AnalyzeResultData(
                taskId=task_id,
                status=TaskStatus.NOT_FOUND.name,
                completedAt=None,
                results=None
            )
        )

    if task.status == TaskStatus.COMPLETED.name:
        # Build the result payload for a completed task
        result_items = []

        # Input and output directory paths
        input_dir = os.path.join(UPLOAD_DIR, task_id, "inputs")
        output_dir = os.path.join(UPLOAD_DIR, task_id, "outputs")

        for idx, result_data in enumerate(task.result):
            # Parse the coordinate data
            coords_data = json.loads(result_data.get("coords", "[]"))

            # Build the image info
            input_img_path = result_data.get("input_img_path", "")
            output_img_path = result_data.get("output_img_path", "")

            # Build the URL paths
            input_filename = os.path.basename(input_img_path)
            output_filename = os.path.basename(output_img_path)

            image_info = ImageInfo(
                origin=f"/uploads/{task_id}/inputs/{input_filename}",
                image=f"/uploads/{task_id}/outputs/{output_filename}" if output_img_path != "" else "",
            )

            # Build the mask info
            masks = [
                MaskInfo(name=mask["name"], coords=mask["coords"])
                for mask in coords_data
            ]

            result_item = ResultItem(
                id=str(idx),
                images=image_info,
                masks=masks
            )
            result_items.append(result_item)

        response.status_code = 200
        return AnalyzeResult(
            success=True,
            data=AnalyzeResultData(
                taskId=task_id,
                status=task.status,
                completedAt=task.completedAt.isoformat() if task.completedAt else "",
                results=result_items
            )
        )
    else:
        # Any other status (processing, failed, etc.)
        return AnalyzeResult(
            success=True,
            data=AnalyzeResultData(
                taskId=task_id,
                status=task.status,
                completedAt=None,
                results=None
            )
        )
34 app/routes/analyze_status.py Normal file
@@ -0,0 +1,34 @@
from fastapi import APIRouter, Response
from app.schemas.analyze_status import AnalyzeStatus, AnalyzeStatusData
from app.services.model import TaskStatus

router = APIRouter()


@router.get("/analyze/status/{task_id}")
async def get_task_status(task_id: str, response: Response):
    from app.main import WORKER
    task = WORKER.task_store.get(task_id)

    if task:
        response.status_code = 200
        return AnalyzeStatus(
            success=True,
            data=AnalyzeStatusData(
                taskId=task_id,
                status=task.status,
                progress=task.progress,
                message=task.message
            )
        )
    else:
        response.status_code = 404
        return AnalyzeStatus(
            success=False,
            data=AnalyzeStatusData(
                taskId=task_id,
                status=TaskStatus.NOT_FOUND.name,
                progress=0,
                message=TaskStatus.NOT_FOUND.value
            )
        )
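Continuing the submission sketch from the analyze route above, a hedged polling sketch against the status and result endpoints (same localhost/httpx assumptions; `task_id` comes from the earlier block).

```python
# Poll GET /analyze/status/{task_id}, then fetch GET /analyze/result/{task_id}.
import time
import httpx

BASE = "http://localhost:80"

while True:
    status = httpx.get(f"{BASE}/analyze/status/{task_id}").json()["data"]
    print(status["status"], status["progress"])
    if status["status"] in ("COMPLETED", "FAILED"):
        break
    time.sleep(2)

result = httpx.get(f"{BASE}/analyze/result/{task_id}").json()
for item in result["data"].get("results") or []:
    print(item["images"]["image"], [m["name"] for m in item["masks"]])
```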
15 app/routes/health.py Normal file
@@ -0,0 +1,15 @@
from datetime import datetime
from fastapi import APIRouter
from app.schemas.health import Health
from app.services.model import TaskStatus

router = APIRouter()


@router.get("/health", response_model=Health)
async def health():
    return Health(
        status=TaskStatus.OK.name,
        message="Algorithm server is running normally",
        timestamp=str(datetime.now())
    )
0 app/schemas/__init__.py Normal file
18 app/schemas/analyze.py Normal file
@@ -0,0 +1,18 @@
from typing import Optional

from pydantic import BaseModel

""" Response structure for task submission """


class AnalyzeData(BaseModel):
    taskId: Optional[str] = None
    status: str
    message: str
    estimatedTime: Optional[str] = None
    filesReceived: Optional[int] = None


class Analyze(BaseModel):
    success: bool
    data: AnalyzeData
32 app/schemas/analyze_result.py Normal file
@@ -0,0 +1,32 @@
from typing import List, Optional
from pydantic import BaseModel

""" Response structure for fetching task results """


class ImageInfo(BaseModel):
    origin: str
    image: str


class MaskInfo(BaseModel):
    name: str
    coords: List[List[int]]


class ResultItem(BaseModel):
    id: str
    images: ImageInfo
    masks: List[MaskInfo]


class AnalyzeResultData(BaseModel):
    taskId: str
    status: str
    completedAt: Optional[str] = None
    results: Optional[List[ResultItem]] = None


class AnalyzeResult(BaseModel):
    success: bool
    data: AnalyzeResultData
15 app/schemas/analyze_status.py Normal file
@@ -0,0 +1,15 @@
from pydantic import BaseModel

""" Response structure for querying task status """


class AnalyzeStatusData(BaseModel):
    taskId: str
    status: str
    progress: int
    message: str


class AnalyzeStatus(BaseModel):
    success: bool
    data: AnalyzeStatusData
7 app/schemas/health.py Normal file
@@ -0,0 +1,7 @@
from pydantic import BaseModel


class Health(BaseModel):
    status: str
    message: str
    timestamp: str
0 app/services/__init__.py Normal file
66 app/services/model.py Normal file
@@ -0,0 +1,66 @@
from datetime import datetime
from enum import Enum
from typing import List, Dict, Optional


class TaskStatus(Enum):
    OK = "Running normally"
    QUEUED = "Queued"
    PROCESSING = "Processing"
    COMPLETED = "Completed"
    FAILED = "Processing failed"
    NOT_FOUND = "Task not found"


class TaskStore:
    def __init__(self, images: List[str]):
        self._status: str = TaskStatus.QUEUED.name
        self._progress: int = 0
        self._images: List[str] = images
        self._result: List[Dict[str, str]] = []
        self._message: str = ""
        self._completedAt: Optional[datetime] = None

    @property
    def images(self):
        return self._images

    @property
    def status(self):
        return self._status

    @status.setter
    def status(self, status: str):
        self._status = status

    @property
    def progress(self):
        return self._progress

    @progress.setter
    def progress(self, progress: int):
        self._progress = progress

    @property
    def result(self):
        return self._result

    @result.setter
    def result(self, result: List[Dict[str, str]]):
        self._result = result

    @property
    def message(self):
        return self._message

    @message.setter
    def message(self, message: str):
        self._message = message

    @property
    def completedAt(self):
        return self._completedAt

    @completedAt.setter
    def completedAt(self, completedAt: datetime):
        self._completedAt = completedAt
80 app/services/worker.py Normal file
@@ -0,0 +1,80 @@
import json
import os
import cv2
import threading
from datetime import datetime
from queue import Queue
from typing import Dict

from app.core.model import Model
from app.core.preprocess import Preprocess
from app.services.model import TaskStatus, TaskStore


class Worker:
    def __init__(self):
        self.detection = Model().getModel()
        self.preprocess = Preprocess().getPreprocess()
        self.task_queue = Queue()
        self.task_store: Dict[str, TaskStore] = {}

        threading.Thread(target=self.worker, daemon=True).start()

    def worker(self):
        from app.main import UPLOAD_DIR
        while True:
            task_id = self.task_queue.get()
            if task_id is None:
                break

            task = self.task_store.get(task_id)
            if not task:
                continue

            try:
                task.status = TaskStatus.PROCESSING.name
                task.progress = 0
                print(f"Started processing task {task_id}...")

                # Create the output directory
                output_dir = os.path.join(UPLOAD_DIR, task_id, "outputs")
                os.makedirs(output_dir, exist_ok=True)

                # Get a label for each image
                image_labels = self.preprocess.preprocess(task.images)  # list of 0s and 1s: 0 means skip, 1 means run detection

                for idx, (input_img_path, label) in enumerate(zip(task.images, image_labels)):
                    print(f"Task {task_id}: processing image {input_img_path}...")

                    if label == 0:
                        # Label 0: skip model detection; output path and coordinates stay empty
                        task.result.append(
                            {"input_img_path": input_img_path, "output_img_path": "", "coords": "[]"}
                        )
                    else:
                        # Run model detection
                        img_res, coords_res = self.detection.detect(input_img_path)

                        coords_res = [{"name": name, "coords": coords} for name, coords in coords_res]
                        coords_json = json.dumps(coords_res, ensure_ascii=False)

                        out_img_path = os.path.join(output_dir, f"{idx}.jpg")
                        cv2.imwrite(out_img_path, img_res)

                        task.result.append(
                            {"input_img_path": input_img_path, "output_img_path": out_img_path, "coords": coords_json}
                        )

                    task.progress = int((idx + 1) / len(task.images) * 100)

                task.status = TaskStatus.COMPLETED.name
                task.completedAt = datetime.now()
                task.message = "Completed"

                print(f"Task {task_id} completed")

            except Exception as e:
                task.status = TaskStatus.FAILED.name
                task.message = str(e)
            finally:
                self.task_queue.task_done()
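The worker loop exits when it pulls a None sentinel from the queue; nothing in this commit ever sends one, so a shutdown hook could look like the hypothetical sketch below (for example wired to FastAPI's shutdown event).

```python
# Hypothetical graceful-shutdown hook; not part of this commit.
from app.main import app, WORKER

@app.on_event("shutdown")
def stop_worker():
    WORKER.task_queue.put(None)  # sentinel: the worker thread breaks out of its loop
```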
15 requirements.txt Normal file
@@ -0,0 +1,15 @@
fastapi[all]==0.122.0
albumentations==2.0.8
albucore==0.0.24
numpy==2.2.6
onnxruntime-gpu==1.23.2
opencv-python==4.12.0.88
PyYAML>=6.0
scipy>=1.10.0
pydantic>=2.10.0
typing-extensions>=4.12.0
python-dotenv~=1.2.1
opencv-python
matplotlib
pandas
tqdm==4.67.1