Many readers have asked what format the annotated data should take, so this note has been added to answer it.
Running the demo script provided at the end of this article generates sample data in the expected annotation format:
- python setup_sam_data.py
-
Under the dataset directory, place an images folder, a masks folder, and an annotations.txt file.
images holds the original images (randomly generated in this demo); put your own images in this folder.
masks holds the corresponding mask images, named with the matching suffix (for example _mask.png); put the label masks for your own data in this folder.
annotations.txt holds the bounding-box coordinates for each image, one line per image (see the example below).
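For reference, each line of annotations.txt has the form image_name,x1,y1,x2,y2, matching what the demo script writes. A hypothetical file for two images might look like this (the coordinates are made up purely for illustration):
- sample_0.jpg,120,85,310,275
- sample_1.jpg,60,140,220,300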
Segment Anything Model (SAM) is a powerful image segmentation model from Meta AI. Although the pretrained model performs very well, specific domains (such as medical imaging or industrial inspection) may require fine-tuning for better results. This article walks through how to fine-tune SAM, covering environment setup, data preparation, and the training implementation.
First, we need to set up a suitable Python environment and install the dependencies. A virtual environment is recommended for managing them:
- # Create and activate a virtual environment
- python -m venv sam_env
- # Windows:
- .\sam_env\Scripts\activate
- # Linux/Mac:
- source sam_env/bin/activate
-
- # Install dependencies
- pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
- pip install opencv-python
- pip install git+https://github.com/facebookresearch/segment-anything.git
- pip install numpy matplotlib
-
- # Download the pretrained model
- # Windows PowerShell:
- Invoke-WebRequest -Uri "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" -OutFile "sam_vit_b_01ec64.pth"
- # Linux/Mac:
- wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
-
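Before moving on, a quick sanity check can confirm that PyTorch sees the GPU and that the SAM package and checkpoint load correctly. This is a minimal sketch (the file name check_env.py is only a suggestion, and it assumes the checkpoint was downloaded into the current directory):
- # check_env.py -- optional environment sanity check
- import torch
- from segment_anything import sam_model_registry
-
- print("PyTorch:", torch.__version__)
- print("CUDA available:", torch.cuda.is_available())
-
- # Loading the checkpoint also verifies that the download is complete and not corrupted
- sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
- print("SAM vit_b loaded, parameter count:", sum(p.numel() for p in sam.parameters()))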
The recommended project structure is as follows:
- project_root/
- ├── stamps/
- │ ├── images/ # training images
- │ ├── masks/ # segmentation masks
- │ └── annotations.txt # bounding-box annotations
- ├── checkpoints/ # model checkpoints
- ├── setup_sam_data.py # data preparation script
- └── sam_finetune.py # training script
-
To train the model, we need three things: training images, the corresponding segmentation masks, and bounding-box annotations.
The data-preparation script below creates this layout and fills it with a small synthetic dataset:
- import os
- import numpy as np
- import cv2
- from pathlib import Path
-
- def create_project_structure():
- """创建项目所需的目录结构"""
- directories = [
- './stamps/images',
- './stamps/masks',
- './checkpoints'
- ]
-
- for dir_path in directories:
- Path(dir_path).mkdir(parents=True, exist_ok=True)
-
- return directories
-
- def create_sample_data(num_samples=5):
- """创建示例训练数据"""
- annotations = []
-
- for i in range(num_samples):
- # Create a sample image
- image = np.ones((500, 500, 3), dtype=np.uint8) * 255
- center_x = np.random.randint(150, 350)
- center_y = np.random.randint(150, 350)
- radius = np.random.randint(50, 100)
-
- # Draw the object
- cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
-
- # Create the mask
- mask = np.zeros((500, 500), dtype=np.uint8)
- cv2.circle(mask, (center_x, center_y), radius, 255, -1)
-
- # Save the files
- cv2.imwrite(f'./stamps/images/sample_{i}.jpg', image)
- cv2.imwrite(f'./stamps/masks/sample_{i}_mask.png', mask)
-
- # Compute the bounding box
- x1 = max(0, center_x - radius)
- y1 = max(0, center_y - radius)
- x2 = min(500, center_x + radius)
- y2 = min(500, center_y + radius)
-
- annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
-
- # Save the annotation file
- with open('./stamps/annotations.txt', 'w') as f:
- f.writelines(annotations)
-
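If you bring your own images and masks, as described at the top of this article, you still need to fill annotations.txt with bounding boxes. A minimal sketch for deriving a box from a binary mask is shown below; it assumes the mask is a single-channel image whose foreground pixels are above 127, and the helper name bbox_from_mask is ours, not part of the original scripts:
- import cv2
- import numpy as np
-
- def bbox_from_mask(mask_path):
-     """Return [x1, y1, x2, y2] covering all foreground pixels of a binary mask."""
-     mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
-     ys, xs = np.where(mask > 127)
-     if len(xs) == 0:
-         raise ValueError(f"Mask has no foreground pixels: {mask_path}")
-     return [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
-
- # Example: append one annotation line for your own sample
- # x1, y1, x2, y2 = bbox_from_mask('./stamps/masks/sample_0_mask.png')
- # with open('./stamps/annotations.txt', 'a') as f:
- #     f.write(f'sample_0.jpg,{x1},{y1},{x2},{y2}\n')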
First, we implement a custom dataset class (the imports it needs, such as torch.utils.data.Dataset and ResizeLongestSide, appear in the full scripts at the end of the article):
- class StampDataset(Dataset):
- def __init__(self, image_dir, mask_dir, bbox_file):
- self.image_dir = image_dir
- self.mask_dir = mask_dir
- self.transform = ResizeLongestSide(1024)
-
- # Load the annotations
- self.annotations = []
- with open(bbox_file, 'r') as f:
- for line in f:
- img_name, x1, y1, x2, y2 = line.strip().split(',')
- self.annotations.append({
- 'image': img_name,
- 'bbox': [float(x1), float(y1), float(x2), float(y2)]
- })
-
- def __len__(self):
- return len(self.annotations)
-
- def __getitem__(self, idx):
- ann = self.annotations[idx]
-
- # Load the image and the mask
- image = cv2.imread(os.path.join(self.image_dir, ann['image']))
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- mask = cv2.imread(os.path.join(self.mask_dir,
- ann['image'].replace('.jpg', '_mask.png')),
- cv2.IMREAD_GRAYSCALE)
- mask = mask.astype(np.float32) / 255.0
-
- # Image preprocessing
- original_size = image.shape[:2]
- input_image = self.transform.apply_image(image)
- input_image = input_image.astype(np.float32) / 255.0
- input_image = torch.from_numpy(input_image).permute(2, 0, 1)
-
- # Normalize with ImageNet statistics
- mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
- std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
- input_image = (input_image - mean) / std
-
- # Process the bounding box and mask
- bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
- bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
- mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
-
- return {
- 'image': input_image.float(),
- 'original_size': original_size,
- 'bbox': bbox_torch,
- 'mask': mask_torch
- }
-
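Before training, it can help to check what a single sample looks like. A minimal sketch, assuming the demo data from the setup script exists and the imports from the full training script are in place:
- dataset = StampDataset(image_dir='./stamps/images',
-                        mask_dir='./stamps/masks',
-                        bbox_file='./stamps/annotations.txt')
- sample = dataset[0]
- print(sample['image'].shape)  # torch.Size([3, 1024, 1024]) for the square 500x500 demo images
- print(sample['bbox'].shape)   # torch.Size([1, 4])
- print(sample['mask'].shape)   # torch.Size([1, 500, 500]); the mask keeps its original resolution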
The core implementation of the training function:
- def train_sam(
- model_type='vit_b',
- checkpoint_path='sam_vit_b_01ec64.pth',
- num_epochs=10,
- batch_size=1,
- learning_rate=1e-5
- ):
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- # Initialize the model
- sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
- sam_model.to(device)
-
- # Prepare the data and the optimizer
- dataset = StampDataset(image_dir='./stamps/images',
- mask_dir='./stamps/masks',
- bbox_file='./stamps/annotations.txt')
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
- optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
- loss_fn = torch.nn.MSELoss()
-
- # Training loop
- for epoch in range(num_epochs):
- total_loss = 0
- for batch_idx, batch in enumerate(dataloader):
- # Prepare the data
- input_image = batch['image'].to(device)
- original_size = batch['original_size']
- bbox = batch['bbox'].to(device)
- gt_mask = batch['mask'].to(device)
-
- # Forward pass
- with torch.no_grad():
- image_embedding = sam_model.image_encoder(input_image)
- sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
- points=None,
- boxes=bbox,
- masks=None,
- )
-
- # Generate predictions
- mask_predictions, _ = sam_model.mask_decoder(
- image_embeddings=image_embedding,
- image_pe=sam_model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=False,
- )
-
- # Post-process
- upscaled_masks = sam_model.postprocess_masks(
- mask_predictions,
- input_size=input_image.shape[-2:],
- original_size=original_size[0]
- ).to(device)
-
- binary_masks = torch.sigmoid(upscaled_masks)
-
- # Compute the loss and optimize
- loss = loss_fn(binary_masks, gt_mask)
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- total_loss += loss.item()
-
- if batch_idx % 10 == 0:
- print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
-
- # Epoch statistics
- avg_loss = total_loss / len(dataloader)
- print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
-
- # Save a checkpoint
- if (epoch + 1) % 5 == 0:
- checkpoint_file = f'./checkpoints/sam_finetuned_epoch_{epoch+1}.pth'
- torch.save(sam_model.state_dict(), checkpoint_file)
-
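The training loop above uses a plain MSE loss on the sigmoid outputs, which keeps the code short. For segmentation, a combined binary cross-entropy and Dice loss is a common alternative; the sketch below (not part of the original scripts) could replace the loss_fn call, taking the un-sigmoided upscaled_masks as input:
- import torch
-
- def dice_bce_loss(pred_logits, gt_mask, eps=1e-6):
-     """BCE-with-logits plus soft Dice loss, computed on raw mask logits."""
-     bce = torch.nn.functional.binary_cross_entropy_with_logits(pred_logits, gt_mask)
-     probs = torch.sigmoid(pred_logits)
-     intersection = (probs * gt_mask).sum(dim=(-2, -1))
-     union = probs.sum(dim=(-2, -1)) + gt_mask.sum(dim=(-2, -1))
-     dice = 1 - (2 * intersection + eps) / (union + eps)
-     return bce + dice.mean()
-
- # Inside the training loop, instead of loss = loss_fn(binary_masks, gt_mask):
- # loss = dice_bce_loss(upscaled_masks, gt_mask)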
The complete training process is as follows:
- python setup_sam_data.py
-
- python sam_finetune.py
-
After fine-tuning the model, we need a convenient way to run predictions with it and visualize the results. The complete implementation is given below.
First, we wrap a predictor class that handles model loading, image preprocessing, and prediction:
- class SAMPredictor:
- def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
- self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
- self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
- self.sam_model.to(self.device)
- self.transform = ResizeLongestSide(1024)
-
This class provides a simple interface for loading the model and running predictions. Its main features are loading the fine-tuned weights, resizing and normalizing input images to the 1024x1024 input SAM expects, rescaling bounding boxes to match, and returning a binary mask together with its predicted IoU confidence.
To better present the prediction results, we implement a visualization function:
- def visualize_prediction(image, mask, bbox, confidence, save_path=None):
- plt.figure(figsize=(15, 5))
- # Show the original image, the predicted mask, and the overlay
- ...
-
This function displays, side by side, the original image with the bounding box drawn on it, the predicted mask with its confidence score, and the mask overlaid on the image.
A complete example of how to use these tools:
- # Initialize the predictor
- predictor = SAMPredictor("./checkpoints/sam_finetuned_final.pth")
-
- # Read a test image
- image = cv2.imread("test_image.jpg")
- bbox = [x1, y1, x2, y2] # bounding-box coordinates (fill in your own values)
-
- # Run prediction
- mask, confidence = predictor.predict(image, bbox)
-
- # Visualize
- visualize_prediction(image, mask, bbox, confidence, "result.png")
-
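predict returns a boolean NumPy mask at the original image resolution, so saving or post-processing it takes only a couple of lines (the file names here are placeholders):
- import numpy as np
-
- # Save the binary mask as an 8-bit PNG
- cv2.imwrite("mask.png", mask.astype(np.uint8) * 255)
-
- # Optionally keep only the masked region of the original image
- cut_out = image.copy()
- cut_out[~mask] = 0
- cv2.imwrite("cutout.png", cut_out)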
The predictor can also be extended for more advanced use. The sketches below add batch prediction and multi-box prediction (both written as extra methods of SAMPredictor), plus a simple click-driven interactive visualization:
- def predict_batch(self, images, bboxes):
- results = []
- for image, bbox in zip(images, bboxes):
- mask, conf = self.predict(image, bbox)
- results.append((mask, conf))
- return results
-
- def predict_multiple_boxes(self, image, bboxes):
- masks = []
- for bbox in bboxes:
- mask, _ = self.predict(image, bbox)
- masks.append(mask)
- return np.stack(masks)
-
- def interactive_visualization(image, predictor):
- def onclick(event):
- if event.button == 1: # left mouse button
- bbox = [event.xdata-50, event.ydata-50,
- event.xdata+50, event.ydata+50]
- mask, _ = predictor.predict(image, bbox)
- visualize_prediction(image, mask, bbox)
-
- fig, ax = plt.subplots()
- ax.imshow(image)
- fig.canvas.mpl_connect('button_press_event', onclick)
- plt.show()
-
These tools and examples should help you understand and use the fine-tuned SAM model. You can further optimize and extend them according to your own needs.
The steps above implement the fine-tuning process for SAM. Treat this implementation as a baseline to optimize and improve on: in real applications you will likely need to adjust the data preprocessing, the loss function, and the training strategy for your specific task. I hope this tutorial helps with your project; questions and discussion are welcome.
Download the following three scripts, set up the runtime environment as described above, and run them in order:
- # setup_sam_data.py
- import os
- import numpy as np
- import cv2
- from pathlib import Path
-
- def create_project_structure():
- """创建项目所需的目录结构"""
- # 创建主目录
- directories = [
- './stamps/images',
- './stamps/masks',
- './checkpoints'
- ]
-
- for dir_path in directories:
- Path(dir_path).mkdir(parents=True, exist_ok=True)
-
- return directories
-
- def create_sample_data(num_samples=5):
- """创建示例训练数据"""
- # 创建示例图像和掩码
- annotations = []
-
- for i in range(num_samples):
- # Create a sample image (500x500)
- image = np.ones((500, 500, 3), dtype=np.uint8) * 255
- # Add a sample stamp (a circle at a random position)
- center_x = np.random.randint(150, 350)
- center_y = np.random.randint(150, 350)
- radius = np.random.randint(50, 100)
-
- # Draw the stamp
- cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
-
- # Create the corresponding mask
- mask = np.zeros((500, 500), dtype=np.uint8)
- cv2.circle(mask, (center_x, center_y), radius, 255, -1)
-
- # Save the image and the mask
- image_path = f'./stamps/images/sample_{i}.jpg'
- mask_path = f'./stamps/masks/sample_{i}_mask.png'
-
- cv2.imwrite(image_path, image)
- cv2.imwrite(mask_path, mask)
-
- # Compute the bounding box
- x1 = max(0, center_x - radius)
- y1 = max(0, center_y - radius)
- x2 = min(500, center_x + radius)
- y2 = min(500, center_y + radius)
-
- # Append to the annotation list
- annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
-
- # Save the annotation file
- with open('./stamps/annotations.txt', 'w') as f:
- f.writelines(annotations)
-
- def main():
- print("Creating project structure...")
- directories = create_project_structure()
- for dir_path in directories:
- print(f"Created directory: {dir_path}")
-
- print("\nCreating sample training data...")
- create_sample_data()
- print("Sample data created!")
-
- print("\nProject structure:")
- for root, dirs, files in os.walk('./stamps'):
- level = root.replace('./stamps', '').count(os.sep)
- indent = ' ' * 4 * level
- print(f"{indent}{os.path.basename(root)}/")
- sub_indent = ' ' * 4 * (level + 1)
- for f in files:
- print(f"{sub_indent}{f}")
-
- if __name__ == '__main__':
- main()
-
- # sam_finetune_decoder.py
- import torch
- import numpy as np
- from segment_anything import sam_model_registry, SamPredictor
- from segment_anything.utils.transforms import ResizeLongestSide
- from torch.utils.data import Dataset, DataLoader
- import cv2
- import os
-
- class StampDataset(Dataset):
- def __init__(self, image_dir, mask_dir, bbox_file, target_size=(1024, 1024)):
- self.image_dir = image_dir
- self.mask_dir = mask_dir
- self.target_size = target_size
- self.transform = ResizeLongestSide(1024) # SAM default size
-
- # Load bbox annotations
- self.annotations = []
- with open(bbox_file, 'r') as f:
- for line in f:
- img_name, x1, y1, x2, y2 = line.strip().split(',')
- self.annotations.append({
- 'image': img_name,
- 'bbox': [float(x1), float(y1), float(x2), float(y2)]
- })
-
- def resize_with_bbox(self, image, mask, bbox):
- """调整图像、掩码和边界框的大小"""
- h, w = image.shape[:2]
- target_h, target_w = self.target_size
-
- # Compute the scaling factors
- scale_x = target_w / w
- scale_y = target_h / h
-
- # Resize the image
- resized_image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
-
- # Resize the mask
- if mask is not None:
- resized_mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
- else:
- resized_mask = None
-
- # Resize the bounding box
- resized_bbox = [
- bbox[0] * scale_x, # x1
- bbox[1] * scale_y, # y1
- bbox[2] * scale_x, # x2
- bbox[3] * scale_y # y2
- ]
-
- return resized_image, resized_mask, resized_bbox
-
- def __len__(self):
- return len(self.annotations)
-
- def __getitem__(self, idx):
- ann = self.annotations[idx]
-
- # Load image
- image = cv2.imread(os.path.join(self.image_dir, ann['image']))
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
- # Load mask
- mask_name = ann['image'].replace('.jpg', '_mask.png')
- mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
- mask = mask.astype(np.float32) / 255.0
-
- # First resize the image to a uniform size
- image, mask, bbox = self.resize_with_bbox(image, mask, ann['bbox'])
-
- # Prepare the image
- original_size = self.target_size
- input_image = self.transform.apply_image(image)
-
- # Convert to float32 and normalize to 0-1 range
- input_image = input_image.astype(np.float32) / 255.0
-
- # Convert to tensor and normalize according to ImageNet stats
- input_image = torch.from_numpy(input_image).permute(2, 0, 1).contiguous()
-
- # Apply ImageNet normalization
- mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
- std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
- input_image = (input_image - mean) / std
-
- # Prepare bbox
- bbox = self.transform.apply_boxes(np.array([bbox]), original_size)[0]
- bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
-
- # Prepare mask
- mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
-
- return {
- 'image': input_image.float(),
- 'original_size': original_size,
- 'bbox': bbox_torch,
- 'mask': mask_torch
- }
-
- def train_sam(
- model_type='vit_b',
- checkpoint_path='sam_vit_b_01ec64.pth',
- image_dir='./stamps/images',
- mask_dir='./stamps/masks',
- bbox_file='./stamps/annotations.txt',
- output_dir='./checkpoints',
- num_epochs=10,
- batch_size=1,
- learning_rate=1e-5
- ):
- # Setup device
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- print(f"Using device: {device}")
-
- # Initialize model
- sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
- sam_model.to(device)
-
- # Prepare dataset
- dataset = StampDataset(image_dir, mask_dir, bbox_file)
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-
- # Setup optimizer
- optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
-
- # Loss function
- loss_fn = torch.nn.MSELoss()
-
- # Training loop
- for epoch in range(num_epochs):
- total_loss = 0
- for batch_idx, batch in enumerate(dataloader):
- # Move inputs to device
- input_image = batch['image'].to(device)
- original_size = batch['original_size']
- bbox = batch['bbox'].to(device)
- gt_mask = batch['mask'].to(device)
-
- # Print shapes and types for debugging
- if batch_idx == 0 and epoch == 0:
- print(f"Input image shape: {input_image.shape}")
- print(f"Input image type: {input_image.dtype}")
- print(f"Input image range: [{input_image.min():.2f}, {input_image.max():.2f}]")
-
- # Get image embedding (without gradient)
- with torch.no_grad():
- image_embedding = sam_model.image_encoder(input_image)
-
- # Get prompt embeddings
- sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
- points=None,
- boxes=bbox,
- masks=None,
- )
-
- # Generate mask prediction
- mask_predictions, iou_predictions = sam_model.mask_decoder(
- image_embeddings=image_embedding,
- image_pe=sam_model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=False,
- )
-
- # Upscale masks to original size
- upscaled_masks = sam_model.postprocess_masks(
- mask_predictions,
- input_size=input_image.shape[-2:],
- original_size=original_size[0]
- ).to(device)
-
- # Convert to binary mask
- binary_masks = torch.sigmoid(upscaled_masks)
-
- # Calculate loss
- loss = loss_fn(binary_masks, gt_mask)
-
- # Optimize
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- total_loss += loss.item()
-
- if batch_idx % 10 == 0:
- print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
-
- avg_loss = total_loss / len(dataloader)
- print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
-
- # Save checkpoint
- if (epoch + 1) % 5 == 0:
- checkpoint_file = os.path.join(output_dir, f'sam_finetuned_epoch_{epoch+1}.pth')
- torch.save(sam_model.state_dict(), checkpoint_file)
- print(f'Checkpoint saved: {checkpoint_file}')
-
- # Save final model
- final_checkpoint = os.path.join(output_dir, 'sam_finetuned_final.pth')
- torch.save(sam_model.state_dict(), final_checkpoint)
- print(f'Final model saved to {final_checkpoint}')
-
- if __name__ == '__main__':
- # Create output directory if it doesn't exist
- os.makedirs('./checkpoints', exist_ok=True)
-
- # Start training
- train_sam()
-
-
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from segment_anything import sam_model_registry, SamPredictor
- from segment_anything.utils.transforms import ResizeLongestSide
- import cv2
- from pathlib import Path
-
- class SAMPredictor:
- def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
- """
- Initialize the SAM predictor
- Args:
- checkpoint_path: path to the model weights
- model_type: model type ("vit_h", "vit_l", "vit_b")
- device: device to use ("cuda" or "cpu")
- """
- self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
- print(f"Using device: {self.device}")
-
- # Load the model
- self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
- self.sam_model.to(self.device)
-
- # Create the image transform
- self.transform = ResizeLongestSide(1024)
-
- def resize_bbox(self, bbox, original_size, target_size=(1024, 1024)):
- """
- Adjust the bounding-box coordinates to match the resized image
- Args:
- bbox: original bounding box [x1, y1, x2, y2]
- original_size: original image size (height, width)
- target_size: target image size (height, width)
- Returns:
- resized_bbox: the adjusted bounding-box coordinates
- """
- orig_h, orig_w = original_size
- target_h, target_w = target_size
-
- # Compute the scaling factors
- scale_x = target_w / orig_w
- scale_y = target_h / orig_h
-
- # Scale the bounding-box coordinates
- x1, y1, x2, y2 = bbox
- resized_bbox = [
- x1 * scale_x,
- y1 * scale_y,
- x2 * scale_x,
- y2 * scale_y
- ]
-
- return resized_bbox
-
- def preprocess_image(self, image):
- """预处理输入图像"""
- # 保存原始尺寸
- original_size = image.shape[:2]
-
- # Ensure the image is in RGB format
- if len(image.shape) == 2:
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
- elif image.shape[2] == 4:
- image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
- elif len(image.shape) == 3 and image.shape[2] == 3:
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
- # Resize the image
- image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
-
- # Convert to float32 and normalize to [0, 1]
- input_image = image.astype(np.float32) / 255.0
-
- # Convert to a tensor and add the batch dimension
- input_image = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0)
-
- # Normalize with ImageNet statistics
- mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
- std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
- input_image = (input_image - mean) / std
-
- return input_image.to(self.device), original_size, image
-
- def predict(self, image, bbox):
- """
- Predict the segmentation mask for a single image
- Args:
- image: image as a numpy array
- bbox: bounding box in [x1, y1, x2, y2] format
- Returns:
- binary_mask: the binarized segmentation mask
- confidence: the prediction confidence
- """
- # Preprocess the image
- input_image, original_size, resized_image = self.preprocess_image(image)
-
- # Resize the bounding box
- resized_bbox = self.resize_bbox(bbox, original_size)
- print(resized_bbox, image.shape, resized_image.shape)
-
- # Prepare the bounding-box tensor
- bbox_torch = torch.tensor(resized_bbox, dtype=torch.float, device=self.device).unsqueeze(0)
-
- # Get the image embedding
- with torch.no_grad():
- image_embedding = self.sam_model.image_encoder(input_image)
-
- # Get the prompt embeddings
- sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
- points=None,
- boxes=bbox_torch,
- masks=None,
- )
-
- # Generate the mask prediction
- mask_predictions, iou_predictions = self.sam_model.mask_decoder(
- image_embeddings=image_embedding,
- image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=False,
- )
-
- # Post-process the masks
- upscaled_masks = self.sam_model.postprocess_masks(
- mask_predictions,
- input_size=input_image.shape[-2:],
- original_size=original_size
- ).to(self.device)
-
- # Convert to a binary mask
- binary_mask = torch.sigmoid(upscaled_masks) > 0.5
-
- return binary_mask[0, 0].cpu().numpy(), iou_predictions[0, 0].item()
-
- def visualize_prediction(image, mask, bbox, confidence, save_path=None):
- """
- Visualize the prediction result
- Args:
- image: original image
- mask: predicted mask
- bbox: bounding-box coordinates
- confidence: prediction confidence
- save_path: path to save the figure (optional)
- """
- # Create the figure
- plt.figure(figsize=(15, 5))
-
- # Show the original image
- plt.subplot(131)
- plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
- plt.title('Original Image')
- # Draw the bounding box
- x1, y1, x2, y2 = map(int, bbox)
- plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-', linewidth=2)
- plt.axis('off')
-
- # Show the predicted mask
- plt.subplot(132)
- plt.imshow(mask, cmap='gray')
- plt.title(f'Predicted Mask\nConfidence: {confidence:.2f}')
- plt.axis('off')
-
- # Show the overlay
- plt.subplot(133)
- overlay = image.copy()
- overlay[mask > 0] = overlay[mask > 0] * 0.7 + np.array([0, 255, 0], dtype=np.uint8) * 0.3
- plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
- plt.title('Overlay')
- plt.axis('off')
-
- plt.tight_layout()
-
- if save_path:
- plt.savefig(save_path)
- print(f"结果已保存到: {save_path}")
-
- plt.show()
-
- def main():
- # Configuration
- checkpoint_path = "./checkpoints/sam_finetuned_final.pth" # use the fine-tuned model
- test_image_path = "./stamps/images/sample_0.jpg"
- output_dir = "./predictions"
-
- # Create the output directory
- Path(output_dir).mkdir(parents=True, exist_ok=True)
-
- # Initialize the predictor
- predictor = SAMPredictor(checkpoint_path)
-
- # Read the test image
- image = cv2.imread(test_image_path)
- # image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
-
- # Read a bounding box (here taken from the demo annotation file; adapt this to your own source)
- with open('./stamps/annotations.txt', 'r') as f:
- first_line = f.readline().strip()
- _, x1, y1, x2, y2 = first_line.split(',')
- bbox = [float(x1), float(y1), float(x2), float(y2)]
- print(bbox)
-
- # Run prediction
- mask, confidence = predictor.predict(image, bbox)
-
- # Visualize the result
- save_path = str(Path(output_dir) / "prediction_result.png")
- visualize_prediction(image, mask, bbox, confidence, save_path)
-
- if __name__ == "__main__":
- main()
-
Running the prediction script saves the visualization to ./predictions/prediction_result.png.
The code above fine-tunes only the mask decoder; the script below additionally fine-tunes the image encoder. Note that it imports from per_segment_anything and from a local train_setimage module (for the preprocess helper), so adjust those imports to match your own setup.
Notes (a lightweight partial-unfreezing sketch follows this list, before the full script):
Fine-tuning the encoder requires considerably more compute and training time.
A larger training dataset is needed to avoid overfitting.
Use a validation set to monitor performance and guard against model degradation.
More training epochs may be required before the model converges.
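If full encoder fine-tuning is too heavy for your hardware, a common compromise is to unfreeze only the last few transformer blocks of the image encoder. The sketch below is an illustration only: it assumes sam_model.image_encoder.blocks is the ViT block list, as in the official segment_anything implementation, and it is not part of the original scripts:
- def set_encoder_trainable(sam_model, num_trainable_blocks=2):
-     """Freeze the image encoder except for its last num_trainable_blocks transformer blocks."""
-     for p in sam_model.image_encoder.parameters():
-         p.requires_grad = False
-     for block in sam_model.image_encoder.blocks[-num_trainable_blocks:]:
-         for p in block.parameters():
-             p.requires_grad = True
-
- # The optimizer should then only receive trainable parameters, for example:
- # optimizer = torch.optim.Adam(
- #     [p for p in sam_model.parameters() if p.requires_grad], lr=1e-6)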
- import torch
- import numpy as np
- from per_segment_anything import sam_model_registry, SamPredictor
- from per_segment_anything.utils.transforms import ResizeLongestSide
- from torch.utils.data import Dataset, DataLoader
- from torch.cuda.amp import autocast, GradScaler
- import cv2
- import os
- from tqdm import tqdm
- import logging
- import json
- from datetime import datetime
- from train_setimage import preprocess
- os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
-
- class StampDataset(Dataset):
- def __init__(self, image_dir, mask_dir, bbox_file, transform=None):
- self.image_dir = image_dir
- self.mask_dir = mask_dir
- self.transform = transform if transform else ResizeLongestSide(1024)
-
- # Load the annotation file
- self.annotations = []
- with open(bbox_file, 'r') as f:
- for line in f:
- img_name, x1, y1, x2, y2 = line.strip().split(',')
- self.annotations.append({
- 'image': img_name,
- 'bbox': [float(x1), float(y1), float(x2), float(y2)]
- })
-
- def __len__(self):
- return len(self.annotations)
-
- def __getitem__(self, idx):
- ann = self.annotations[idx]
-
- # Read the image
- image_path = os.path.join(self.image_dir, ann['image'])
- image = cv2.imread(image_path)
- if image is None:
- raise ValueError(f"Failed to load image: {image_path}")
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
- # Read the mask
- mask_name = ann['image'].replace('.jpg', '_mask.png')
- mask_path = os.path.join(self.mask_dir, mask_name)
- mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
- if mask is None:
- raise ValueError(f"Failed to load mask: {mask_path}")
- mask = mask.astype(np.float32) / 255.0
-
- # Prepare the image
- original_size = image.shape[:2]
- input_image = self.transform.apply_image(image)
- input_image = input_image.astype(np.float32) / 255.0
- # Convert to a tensor and apply ImageNet normalization
- input_image = torch.from_numpy(input_image).permute(2, 0, 1)
- # Use preprocess to handle ImageNet normalization and padding
- input_image = preprocess(input_image)
- print(f"Processed image shape: {input_image.shape}")
-
- # Prepare the bbox
- bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
- bbox_torch = torch.tensor(bbox, dtype=torch.float)
-
- # Prepare the mask
- mask_torch = torch.from_numpy(mask).float()
-
- return {
- 'image': input_image.float(),
- 'original_size': original_size,
- 'bbox': bbox_torch,
- 'mask': mask_torch,
- 'image_path': image_path # for debugging
- }
-
-
- class SAMFineTuner:
- def __init__(self, config):
- self.config = config
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- self.setup_model()
- self.setup_datasets()
- self.setup_training()
-
- # Create the output directory
- os.makedirs(config['output_dir'], exist_ok=True)
-
- # Save the config
- config_path = os.path.join(config['output_dir'], 'config.json')
- with open(config_path, 'w') as f:
- json.dump(config, f, indent=4)
-
- def setup_model(self):
- logger.info(f"Loading SAM model: {self.config['model_type']}")
- self.model = sam_model_registry[self.config['model_type']](
- checkpoint=self.config['checkpoint_path']
- )
- self.model.to(self.device)
-
- def setup_datasets(self):
- logger.info("Setting up datasets")
- self.train_dataset = StampDataset(
- self.config['train_image_dir'],
- self.config['train_mask_dir'],
- self.config['train_bbox_file']
- )
- # Load training data in batches
- self.train_loader = DataLoader(
- self.train_dataset,
- batch_size=self.config['batch_size'],
- shuffle=True,
- num_workers=self.config['num_workers'],
- pin_memory=True
- )
- # Validation set
- if self.config.get('val_bbox_file'):
- self.val_dataset = StampDataset(
- self.config['val_image_dir'],
- self.config['val_mask_dir'],
- self.config['val_bbox_file']
- )
- self.val_loader = DataLoader(
- self.val_dataset,
- batch_size=self.config['batch_size'],
- shuffle=False,
- num_workers=self.config['num_workers'],
- pin_memory=True
- )
-
- def setup_training(self):
- logger.info("Setting up training components")
- # Use separate learning rates for the encoder and the decoder
- self.optimizer = torch.optim.Adam([
- {'params': self.model.image_encoder.parameters(),
- 'lr': self.config['encoder_lr']},
- {'params': self.model.mask_decoder.parameters(),
- 'lr': self.config['decoder_lr']}
- ])
-
- self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- self.optimizer,
- mode='min',
- factor=0.5,
- patience=5,
- verbose=True
- )
-
- self.loss_fn = torch.nn.MSELoss()
- self.scaler = GradScaler()
-
- # Track the best model
- self.best_loss = float('inf')
-
- def train_epoch(self, epoch):
- self.model.train()
- total_loss = 0 # running loss for this epoch
-
- pbar = tqdm(self.train_loader, desc=f'Epoch {epoch + 1}')
- for batch_idx, batch in enumerate(pbar):
- # Move the data to the GPU
- input_image = batch['image'].to(self.device)
- bbox = batch['bbox'].to(self.device)
- gt_mask = batch['mask'].to(self.device)
-
- self.optimizer.zero_grad()
-
- with autocast():
- # Forward pass
- image_embedding = self.model.image_encoder(input_image)
-
- with torch.no_grad():
- sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
- points=None,
- boxes=bbox,
- masks=None,
- )
-
- mask_predictions, _ = self.model.mask_decoder(
- image_embeddings=image_embedding,
- image_pe=self.model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=False,
- )
-
- upscaled_masks = self.model.postprocess_masks(
- mask_predictions,
- input_size=input_image.shape[-2:],
- original_size=batch['original_size']
- ).to(self.device)
-
- binary_masks = torch.sigmoid(upscaled_masks)
- loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
-
- # Backward pass
- self.scaler.scale(loss).backward()
- self.scaler.unscale_(self.optimizer)
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
- self.scaler.step(self.optimizer)
- self.scaler.update()
-
- total_loss += loss.item()
- pbar.set_postfix({'loss': loss.item()})
-
- return total_loss / len(self.train_loader)
-
- @torch.no_grad()
- def validate(self):
- if not hasattr(self, 'val_loader'):
- return None
-
- self.model.eval()
- total_loss = 0
-
- for batch in tqdm(self.val_loader, desc='Validating'):
- input_image = batch['image'].to(self.device)
- bbox = batch['bbox'].to(self.device)
- gt_mask = batch['mask'].to(self.device)
-
- with autocast():
- image_embedding = self.model.image_encoder(input_image)
- sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
- points=None,
- boxes=bbox,
- masks=None,
- )
-
- mask_predictions, _ = self.model.mask_decoder(
- image_embeddings=image_embedding,
- image_pe=self.model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=False,
- )
-
- upscaled_masks = self.model.postprocess_masks(
- mask_predictions,
- input_size=input_image.shape[-2:],
- original_size=batch['original_size']
- ).to(self.device)
-
- binary_masks = torch.sigmoid(upscaled_masks)
- loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
-
- total_loss += loss.item()
-
- return total_loss / len(self.val_loader)
-
-
- def save_checkpoint(self, epoch, loss, is_best=False):
- # Save the full training state (for resuming training)
- checkpoint = {
- 'epoch': epoch,
- 'model_state_dict': self.model.state_dict(),
- 'optimizer_state_dict': self.optimizer.state_dict(),
- 'scheduler_state_dict': self.scheduler.state_dict(),
- 'loss': loss,
- 'config': self.config
- }
-
- # Save the full checkpoint
- checkpoint_path = os.path.join(
- self.config['output_dir'],
- f'checkpoint_epoch_{epoch + 1}.pth'
- )
- torch.save(checkpoint, checkpoint_path)
-
- # If this is the best model, also save the weights in a SAM-compatible format
- if is_best:
- # Save the full checkpoint
- best_checkpoint_path = os.path.join(self.config['output_dir'], 'best_checkpoint.pth')
- torch.save(checkpoint, best_checkpoint_path)
-
- # Additionally save clean model weights (compatible with the original SAM format)
- best_model_path = os.path.join(self.config['output_dir'], 'best_model_sam_format.pth')
- torch.save(self.model.state_dict(), best_model_path)
- logger.info(f"Saved best model with loss: {loss:.4f}")
-
- def train(self):
- logger.info("Starting training")
- for epoch in range(self.config['num_epochs']):
- train_loss = self.train_epoch(epoch)
- logger.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}")
-
- val_loss = self.validate()
- if val_loss is not None:
- logger.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")
- self.scheduler.step(val_loss)
- is_best = val_loss < self.best_loss
- if is_best:
- self.best_loss = val_loss
- else:
- is_best = False
- self.scheduler.step(train_loss)
-
- if (epoch + 1) % self.config['save_interval'] == 0:
- self.save_checkpoint(
- epoch,
- val_loss if val_loss is not None else train_loss,
- is_best
- )
-
- # After training, save the final model in the SAM-compatible format
- final_model_path = os.path.join(self.config['output_dir'], 'final_model_sam_format.pth')
- torch.save(self.model.state_dict(), final_model_path)
- logger.info(f"Saved final model in SAM-compatible format: {final_model_path}")
-
-
- def main():
- # Training configuration
- config = {
- 'model_type': 'vit_b',
- 'checkpoint_path': './checkpoints/sam_vit_b_01ec64.pth',
- 'train_image_dir': './stamps/images',
- 'train_mask_dir': './stamps/masks',
- 'train_bbox_file': './stamps/annotations.txt',
-
- 'val_image_dir': './stamps/val_images',
- 'val_mask_dir': './stamps/val_masks',
- 'val_bbox_file': './stamps/val_annotations.txt',
- 'output_dir': f'./outputs/sam_finetune_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
- 'num_epochs': 1,
- 'batch_size': 1,
- 'num_workers': 4,
- 'encoder_lr': 1e-6,
- 'decoder_lr': 1e-5,
- 'save_interval': 5
- }
-
- # Create the trainer and start training
- trainer = SAMFineTuner(config)
- trainer.train()
-
-
- if __name__ == '__main__':
- main()
-
-