wall_docker/app/core/sam3/preprocess.py
Boen_Shi 6a2e046884 feat(api): 添加图像分析功能和相关路由接口
- 新增 analyze、analyze_result、analyze_status 和 health 路由
- 实现图像上传和任务提交功能
- 添加任务状态查询和结果获取接口
- 集成 segformer 和 yolo 模型进行图像检测
- 实现 SAM3 预处理功能用于图像预处理判断
- 添加模型选择配置支持 segformer 和 yolo
- 实现任务队列管理和异步处理机制
- 添加 Dockerfile 用于容器化部署
- 配置环境变量和 gitignore 规则
- 创建数据模型定义 API 响应结构
2026-01-27 11:59:45 +08:00

173 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from sam3.model_builder import build_sam3_image_model
from sam3.train.data.collator import collate_fn_api as collate
from sam3.model.utils.misc import copy_data_to_device
from sam3.train.data.sam3_image_dataset import (
Datapoint, Image as SAMImage, FindQueryLoaded, InferenceMetadata
)
from sam3.train.transforms.basic_for_api import ComposeAPI, RandomResizeAPI, ToTensorAPI, NormalizeAPI
from sam3.eval.postprocessors import PostProcessImage
# NOTE(review): set after `import torch`; CUDA_VISIBLE_DEVICES only takes effect
# if CUDA has not been initialized yet — confirm no CUDA call happens on import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"
# ===== Configuration =====
CKPT_PATH = os.path.join(os.getcwd(), "app/core/sam3", "sam3.pt")  # checkpoint path, resolved against the current working directory
DEVICE = "cuda:0"
BATCH_SIZE = 12  # batch size; should be configurable from the frontend
NUM_WORKERS = 12  # number of image-loading workers; consider exposing to the frontend
CONF_TH = 0.5  # detection confidence threshold passed to PostProcessImage
RATIO_TH = 0.5  # wall-area ratio threshold; larger filters more, but too large hurts close-up images
_GLOBAL_ID = 1  # monotonically increasing query id; incremented by SAM3.add_text_prompt
# Text prompts queried per image; an image is labeled "wall" if the union of
# all prompt masks covers at least RATIO_TH of the pixels.
PROMPTS = [
    "wall",
    "building wall",
    "building facade",
    "building exterior wall",
    "exterior building facade",
]
# NOTE(review): IMG_EXTS is unused in this chunk — presumably consumed by a caller; verify.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
# ============
class ImgPathList(Dataset):
    """Map-style dataset over an explicit list of image file paths.

    Each item is the pair ``(path, image)`` where the image is decoded
    with PIL and converted to RGB.
    """

    def __init__(self, img_paths: list):
        """
        Args:
            img_paths (list): list of image file paths to serve.
        """
        self.paths = img_paths  # keep the caller-provided ordering

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, i):
        p = self.paths[i]
        # Open inside a context manager so the underlying file handle is
        # released promptly (PIL opens lazily and can otherwise leak the
        # descriptor). convert("RGB") forces a full decode, so the returned
        # image remains valid after the file is closed.
        with Image.open(p) as img:
            return p, img.convert("RGB")
class SAM3:
    """SAM3-based wall detector used as an image-preprocessing gate.

    Loads a SAM3 image model once, then ``preprocess`` scores a list of
    image paths: for each image it queries every text prompt in ``PROMPTS``,
    unions the predicted masks, and labels the image 1 when the wall-pixel
    ratio reaches ``RATIO_TH``, else 0.
    """

    def __init__(self):
        # Target device for inference and mask tensors.
        self.dev = torch.device(DEVICE)
        # Converts raw model output into per-query detection dicts;
        # masks are kept on-device (to_cpu=False) and not RLE-encoded.
        self.postprocessor = PostProcessImage(
            max_dets_per_img=-1,
            iou_type="segm",
            use_original_sizes_box=True,
            use_original_sizes_mask=True,
            convert_mask_to_rle=False,
            detection_threshold=CONF_TH,
            to_cpu=False,
        )
        # Per-datapoint transform: resize to a 1008x1008 square, tensorize,
        # then normalize to the [-1, 1] range the model expects.
        self.transform = ComposeAPI(
            transforms=[
                RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
                ToTensorAPI(),
                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        # Build from the local checkpoint (no HuggingFace download) and
        # freeze in eval mode.
        self.model = build_sam3_image_model(
            checkpoint_path=CKPT_PATH, load_from_HF=False, device=DEVICE
        ).to(DEVICE).eval()

    def preprocess(self, image_path_list):
        """Classify each image path as wall (1) / not-wall (0).

        Args:
            image_path_list: list of image file paths.

        Returns:
            list[int]: one 0/1 label per input path, in DataLoader order
            (shuffle=False, so this matches the input order).
        """
        labels = []
        loader = DataLoader(
            ImgPathList(image_path_list),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            # Identity collate: keep (names, PIL images) as plain lists so the
            # SAM3 datapoint/transform pipeline below does the real batching.
            collate_fn=self.collate_fn,
        )
        with torch.inference_mode():
            for names, images in loader:
                datapoints = []
                name2qids = {}  # name -> [qid, ...] query ids issued for that image
                for name, img in zip(names, images):
                    dp = self.create_empty_datapoint()
                    self.set_image(dp, img)
                    # One query per text prompt; qids index into `processed` below.
                    qids = [self.add_text_prompt(dp, p) for p in PROMPTS]
                    name2qids[name] = qids
                    datapoints.append(self.transform(dp))
                # Collate under a throwaway key, then move the batch to GPU.
                batch = collate(datapoints, dict_key="dummy")["dummy"]
                batch = copy_data_to_device(batch, self.dev, non_blocking=True)
                output = self.model(batch)
                # Maps each query id to a detection dict for that (image, prompt) pair.
                processed = self.postprocessor.process_results(output, batch.find_metadatas)
                for name in names:
                    any_masks = []
                    for qid in name2qids[name]:
                        res = processed[qid]
                        m = res.get("masks", None)  # expected shape: [N, H, W]
                        if m is None:
                            # NOTE(review): this fallback is a 0-d bool tensor; if some
                            # prompts return [H, W] masks and others None for the same
                            # image, torch.stack below would see mismatched shapes and
                            # fail — confirm postprocessor output, or pad to [H, W].
                            any_masks.append(torch.zeros(1, 1, device=self.dev, dtype=torch.bool).squeeze())
                        else:
                            if not torch.is_tensor(m):
                                m = torch.as_tensor(m, device=self.dev)
                            any_masks.append(m.any(0))  # union over detections -> [H, W]
                    # Union across all prompts for this image.
                    wall_mask = torch.stack(any_masks, 0).any(0)  # [H, W] bool
                    ratio = wall_mask.float().mean().item()  # fraction of wall pixels
                    lab = 1 if ratio >= RATIO_TH else 0
                    labels.append(lab)
                    print(f"{name} | wall_ratio={ratio:.4f} -> {lab}")  # optional debug logging; can be removed
        # NOTE(review): source indentation was lost; `return labels` is placed
        # after the batch loop as the only reading that returns all results.
        return labels

    @staticmethod
    def add_text_prompt(datapoint, text_query):
        """Attach a text query to `datapoint` and return its unique query id.

        Uses the module-level _GLOBAL_ID counter so query ids are unique
        across images and batches within the process.
        """
        global _GLOBAL_ID
        assert len(datapoint.images) == 1, "please set the image first"
        w, h = datapoint.images[0].size
        datapoint.find_queries.append(
            FindQueryLoaded(
                query_text=text_query,
                image_id=0,
                object_ids_output=[],
                is_exhaustive=True,
                query_processing_order=0,
                inference_metadata=InferenceMetadata(
                    coco_image_id=_GLOBAL_ID,
                    original_image_id=_GLOBAL_ID,
                    original_category_id=1,
                    original_size=[w, h],
                    object_id=0,
                    frame_index=0,
                ),
            )
        )
        _GLOBAL_ID += 1
        return _GLOBAL_ID - 1  # the id actually assigned to this query

    @staticmethod
    def create_empty_datapoint():
        """Return a fresh Datapoint with no image and no queries."""
        return Datapoint(find_queries=[], images=[])

    @staticmethod
    def set_image(datapoint, pil_image):
        """Wrap `pil_image` as the datapoint's single image (size is [H, W])."""
        w, h = pil_image.size
        datapoint.images = [SAMImage(data=pil_image, objects=[], size=[h, w])]  # size uses [H, W] order

    @staticmethod
    def collate_fn(batch):
        """Identity collate: split (path, image) pairs into two parallel lists."""
        names, imgs = zip(*batch)
        return list(names), list(imgs)