import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from sam3.eval.postprocessors import PostProcessImage
from sam3.model.utils.misc import copy_data_to_device
from sam3.model_builder import build_sam3_image_model
from sam3.train.data.collator import collate_fn_api as collate
from sam3.train.data.sam3_image_dataset import (
    Datapoint,
    FindQueryLoaded,
    Image as SAMImage,
    InferenceMetadata,
)
from sam3.train.transforms.basic_for_api import (
    ComposeAPI,
    NormalizeAPI,
    RandomResizeAPI,
    ToTensorAPI,
)

# Both variables must be set before CUDA is initialized (torch initializes it lazily).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"

# ===== Configuration =====
CKPT_PATH = os.path.join(os.getcwd(), "app/core/sam3", "sam3.pt")
DEVICE = "cuda:0"
BATCH_SIZE = 12   # batch size; to be set from the frontend
NUM_WORKERS = 12  # number of image-loading workers; decide whether the frontend should expose this
CONF_TH = 0.5     # detection confidence threshold
RATIO_TH = 0.5    # wall-area ratio threshold; higher filters more, but too high hurts close-up shots
_GLOBAL_ID = 1    # running id assigned to each text query
PROMPTS = [
    "wall",
    "building wall",
    "building facade",
    "building exterior wall",
    "exterior building facade",
]
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
# =========================


class ImgPathList(Dataset):
    """Dataset that yields (path, PIL image) pairs from a list of image paths."""

    def __init__(self, img_paths: list):
        """
        Args:
            img_paths (list): a list of image file paths.
        """
        self.paths = img_paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        p = self.paths[i]
        img = Image.open(p).convert("RGB")  # open the image and force RGB
        return p, img  # return the path together with the image


class SAM3:
    def __init__(self):
        self.dev = torch.device(DEVICE)
        self.postprocessor = PostProcessImage(
            max_dets_per_img=-1,
            iou_type="segm",
            use_original_sizes_box=True,
            use_original_sizes_mask=True,
            convert_mask_to_rle=False,
            detection_threshold=CONF_TH,
            to_cpu=False,
        )
        self.transform = ComposeAPI(
            transforms=[
                RandomResizeAPI(sizes=1008, max_size=1008, square=True, consistent_transform=False),
                ToTensorAPI(),
                NormalizeAPI(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.model = (
            build_sam3_image_model(checkpoint_path=CKPT_PATH, load_from_HF=False, device=DEVICE)
            .to(DEVICE)
            .eval()
        )

    def preprocess(self, image_path_list):
        """Run every prompt in PROMPTS on each image and return a 0/1 wall label per image."""
        labels = []
        loader = DataLoader(
            ImgPathList(image_path_list),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )
        with torch.inference_mode():
            for names, images in loader:
                datapoints = []
                name2qids = {}  # name -> [qid, ...]
                for name, img in zip(names, images):
                    dp = self.create_empty_datapoint()
                    self.set_image(dp, img)
                    qids = [self.add_text_prompt(dp, p) for p in PROMPTS]
                    name2qids[name] = qids
                    datapoints.append(self.transform(dp))
                batch = collate(datapoints, dict_key="dummy")["dummy"]
                batch = copy_data_to_device(batch, self.dev, non_blocking=True)
                output = self.model(batch)
                processed = self.postprocessor.process_results(output, batch.find_metadatas)
                for name in names:
                    any_masks = []
                    for qid in name2qids[name]:
                        res = processed[qid]
                        m = res.get("masks", None)  # expected shape: [N, H, W]
                        if m is None:
                            # No detections for this prompt; skip it. (A 0-dim zero
                            # placeholder cannot be stacked with [H, W] masks
                            # from other prompts.)
                            continue
                        if not torch.is_tensor(m):
                            m = torch.as_tensor(m, device=self.dev)
                        any_masks.append(m.any(0))  # union over detections -> [H, W]
                    if any_masks:
                        wall_mask = torch.stack(any_masks, 0).any(0)  # union over prompts -> [H, W] bool
                        ratio = wall_mask.float().mean().item()
                    else:
                        ratio = 0.0  # no prompt produced any mask
                    lab = 1 if ratio >= RATIO_TH else 0
                    labels.append(lab)
                    print(f"{name} | wall_ratio={ratio:.4f} -> {lab}")  # debug output; can be removed
        return labels

    @staticmethod
    def add_text_prompt(datapoint, text_query):
        """Attach a text query to the datapoint and return its globally unique query id."""
        global _GLOBAL_ID
        assert len(datapoint.images) == 1, "please set the image first"
        w, h = datapoint.images[0].size
        datapoint.find_queries.append(
            FindQueryLoaded(
                query_text=text_query,
                image_id=0,
                object_ids_output=[],
                is_exhaustive=True,
                query_processing_order=0,
                inference_metadata=InferenceMetadata(
                    coco_image_id=_GLOBAL_ID,
                    original_image_id=_GLOBAL_ID,
                    original_category_id=1,
                    original_size=[w, h],
                    object_id=0,
                    frame_index=0,
                ),
            )
        )
        _GLOBAL_ID += 1
        return _GLOBAL_ID - 1

    @staticmethod
    def create_empty_datapoint():
        return Datapoint(find_queries=[], images=[])

    @staticmethod
    def set_image(datapoint, pil_image):
        w, h = pil_image.size
        datapoint.images = [SAMImage(data=pil_image, objects=[], size=[h, w])]  # size is [H, W]

    @staticmethod
    def collate_fn(batch):
        # Keep paths and PIL images as plain lists; the SAM3 datapoint pipeline
        # does its own batching later via `collate`.
        names, imgs = zip(*batch)
        return list(names), list(imgs)
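
# A minimal usage sketch (an illustration, not part of the module): gather images
# from a hypothetical "./images" folder by filtering on IMG_EXTS, then run the
# wall classifier end to end. The folder name and summary line are assumptions.
if __name__ == "__main__":
    from pathlib import Path

    img_dir = Path("./images")  # hypothetical input folder
    paths = sorted(str(p) for p in img_dir.iterdir() if p.suffix.lower() in IMG_EXTS)

    sam3 = SAM3()  # loads the checkpoint from CKPT_PATH onto DEVICE
    labels = sam3.preprocess(paths)  # one 0/1 label per image, in input order
    print(f"{sum(labels)}/{len(labels)} images labeled as wall")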