变更摘要:
- 新增 analyze、analyze_result、analyze_status 和 health 路由
- 实现图像上传和任务提交功能
- 添加任务状态查询和结果获取接口
- 集成 segformer 和 yolo 模型进行图像检测
- 实现 SAM3 预处理功能用于图像预处理判断
- 添加模型选择配置,支持 segformer 和 yolo
- 实现任务队列管理和异步处理机制
- 添加 Dockerfile 用于容器化部署
- 配置环境变量和 gitignore 规则
- 创建数据模型,定义 API 响应结构

173 lines · 5.8 KiB · Python
import os
|
||
import torch
|
||
from pathlib import Path
|
||
from PIL import Image
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from sam3.model_builder import build_sam3_image_model
|
||
from sam3.train.data.collator import collate_fn_api as collate
|
||
from sam3.model.utils.misc import copy_data_to_device
|
||
from sam3.train.data.sam3_image_dataset import (
|
||
Datapoint, Image as SAMImage, FindQueryLoaded, InferenceMetadata
|
||
)
|
||
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
|
||
from sam3.eval.postprocessors import PostProcessImage
|
||
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"

# ===== Configuration =====
# SAM3 checkpoint path, resolved relative to the process working directory.
CKPT_PATH = os.path.join(os.getcwd(), "app/core/sam3", "sam3.pt")
DEVICE = "cuda:0"

BATCH_SIZE = 12   # inference batch size; intended to become a frontend setting
NUM_WORKERS = 12  # DataLoader workers for image loading; possibly frontend-configurable

CONF_TH = 0.5     # detection confidence threshold passed to the postprocessor
# Wall-area ratio threshold: higher filters out more images, but too high
# starts rejecting legitimate close-up shots.
RATIO_TH = 0.5

# Monotonically increasing id used to tag each text query (see SAM3.add_text_prompt).
_GLOBAL_ID = 1

# Prompt variants all describing a building wall; the masks produced for every
# prompt are OR-ed together per image before computing the wall-area ratio.
PROMPTS = [
    "wall",
    "building wall",
    "building facade",
    "building exterior wall",
    "exterior building facade",
]
# File extensions considered to be images.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


# ============
||
|
||
class ImgPathList(Dataset):
    """Thin ``Dataset`` over an explicit list of image file paths.

    Each item is the pair ``(path, PIL.Image)`` with the image converted to
    RGB.  Paths are served in the order given; they are not validated up
    front, so a bad path only raises when its item is first accessed.
    """

    def __init__(self, img_paths: list):
        """
        Args:
            img_paths (list): image file paths to serve, in order.
        """
        self.paths = img_paths

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, i):
        p = self.paths[i]
        # Force RGB so downstream transforms always see 3 channels.
        img = Image.open(p).convert("RGB")
        return p, img
class SAM3:
    """Wrapper around the SAM3 image model that flags "wall" images.

    ``preprocess`` runs every prompt in ``PROMPTS`` against each image,
    unions the resulting masks, and labels an image 1 when the wall area
    covers at least ``RATIO_TH`` of the frame, else 0.
    """

    def __init__(self):
        # NOTE(review): construction is eager and heavy — it requires the
        # checkpoint at CKPT_PATH and a CUDA device to be present.
        self.dev = torch.device(DEVICE)
        self.postprocessor = PostProcessImage(
            max_dets_per_img=-1,
            iou_type="segm",
            use_original_sizes_box=True,
            use_original_sizes_mask=True,
            convert_mask_to_rle=False,
            detection_threshold=CONF_TH,
            to_cpu=False,
        )
        self.transform = ComposeAPI(
            transforms=[
                RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
                ToTensorAPI(),
                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.model = build_sam3_image_model(
            checkpoint_path=CKPT_PATH, load_from_HF=False, device=DEVICE
        ).to(DEVICE).eval()

    def preprocess(self, image_path_list):
        """Label each image path as wall-dominant (1) or not (0).

        Args:
            image_path_list (list[str]): image paths to classify.

        Returns:
            list[int]: one 0/1 label per input, in DataLoader order
            (same as input order since ``shuffle=False``).
        """
        labels = []

        loader = DataLoader(
            ImgPathList(image_path_list),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

        with torch.inference_mode():
            for names, images in loader:
                datapoints = []
                name2qids = {}  # image name -> query ids issued for it
                for name, img in zip(names, images):
                    dp = self.create_empty_datapoint()
                    self.set_image(dp, img)

                    # One query per prompt; their masks are OR-ed below.
                    qids = [self.add_text_prompt(dp, p) for p in PROMPTS]
                    name2qids[name] = qids

                    datapoints.append(self.transform(dp))

                batch = collate(datapoints, dict_key="dummy")["dummy"]
                batch = copy_data_to_device(batch, self.dev, non_blocking=True)
                output = self.model(batch)

                # assumes process_results keys its output by the query ids we
                # issued (coco_image_id) — TODO confirm against PostProcessImage
                processed = self.postprocessor.process_results(output, batch.find_metadatas)

                for name in names:
                    any_masks = []
                    for qid in name2qids[name]:
                        res = processed[qid]
                        m = res.get("masks", None)  # expected shape [N, H, W]
                        if m is None:
                            # No detections for this prompt: nothing to union.
                            # BUGFIX: the previous code appended a 0-dim
                            # placeholder tensor here, which made torch.stack
                            # fail against the [H, W] masks of other prompts.
                            continue
                        if not torch.is_tensor(m):
                            m = torch.as_tensor(m, device=self.dev)
                        any_masks.append(m.any(0))  # collapse detections -> [H, W]

                    if any_masks:
                        wall_mask = torch.stack(any_masks, 0).any(0)  # [H, W] bool
                        ratio = wall_mask.float().mean().item()
                    else:
                        ratio = 0.0  # no prompt produced any mask
                    lab = 1 if ratio >= RATIO_TH else 0
                    labels.append(lab)
                    print(f"{name} | wall_ratio={ratio:.4f} -> {lab}")  # debug trace; safe to remove

        return labels

    @staticmethod
    def add_text_prompt(datapoint, text_query):
        """Append a text find-query to *datapoint* and return its unique query id."""
        global _GLOBAL_ID
        assert len(datapoint.images) == 1, "please set the image first"
        w, h = datapoint.images[0].size
        qid = _GLOBAL_ID
        datapoint.find_queries.append(
            FindQueryLoaded(
                query_text=text_query,
                image_id=0,
                object_ids_output=[],
                is_exhaustive=True,
                query_processing_order=0,
                inference_metadata=InferenceMetadata(
                    coco_image_id=qid,
                    original_image_id=qid,
                    original_category_id=1,
                    original_size=[w, h],
                    object_id=0,
                    frame_index=0,
                ),
            )
        )
        _GLOBAL_ID += 1
        return qid

    @staticmethod
    def create_empty_datapoint():
        """Return a Datapoint with no images and no queries attached yet."""
        return Datapoint(find_queries=[], images=[])

    @staticmethod
    def set_image(datapoint, pil_image):
        """Attach *pil_image* as the datapoint's only image (size stored as [H, W])."""
        w, h = pil_image.size
        datapoint.images = [SAMImage(data=pil_image, objects=[], size=[h, w])]

    @staticmethod
    def collate_fn(batch):
        """Collate (path, image) pairs into parallel lists of paths and images."""
        names, imgs = zip(*batch)
        return list(names), list(imgs)