
How to Fine-Tune the SAM Model: A Complete Guide from Environment Setup to Training Implementation


Many readers have asked what format the data annotations should be in, so this section was added to answer that question.

Running the demo script provided at the end of this article generates a sample dataset in the expected annotation format:

    python sam-data-setup.py

The dataset directory contains an images folder, a masks folder, and annotations.txt.


The images folder holds the original images (randomly generated in the demo). Put your own images in this folder.


The masks folder holds the corresponding mask images, with the file names matching the images and the mask suffix applied. Put the label masks for your own data in this folder.


annotations.txt stores the bounding-box coordinates for each image.
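
Each line of annotations.txt follows the pattern filename,x1,y1,x2,y2, which is the format written by the demo script and read by the training scripts below. For the generated demo data it looks roughly like this (the coordinates will differ because the shapes are placed randomly):

    sample_0.jpg,137,199,301,363
    sample_1.jpg,92,156,268,332
    sample_2.jpg,210,105,318,213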


Introduction

The Segment Anything Model (SAM) is a powerful image segmentation model released by Meta AI. Although the pretrained model performs well, in specific domains (such as medical imaging or industrial inspection) it may need fine-tuning to reach better performance. This article walks through how to fine-tune SAM, covering environment setup, data preparation, and the training implementation.

Contents

  1. Environment Setup
  2. Project Structure
  3. Data Preparation
  4. Model Fine-Tuning
  5. Training Procedure
  6. Notes and Optimization Suggestions
  7. Model Prediction and Visualization

1. Environment Setup

First, set up the correct Python environment and dependencies. Using a virtual environment to manage dependencies is recommended:

    # Create and activate a virtual environment
    python -m venv sam_env
    # Windows:
    .\sam_env\Scripts\activate
    # Linux/Mac:
    source sam_env/bin/activate

    # Install dependencies
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    pip install opencv-python
    pip install git+https://github.com/facebookresearch/segment-anything.git
    pip install numpy matplotlib

    # Download the pretrained checkpoint
    # Windows PowerShell:
    Invoke-WebRequest -Uri "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" -OutFile "sam_vit_b_01ec64.pth"
    # Linux/Mac:
    wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
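
Before moving on, a quick sanity check (not part of the original article) can confirm that PyTorch sees the GPU and that the segment-anything package and the downloaded checkpoint load correctly:

    # verify_env.py -- minimal, illustrative environment check
    import torch
    from segment_anything import sam_model_registry

    print("PyTorch:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())

    # Loading the checkpoint verifies the download completed and matches the vit_b architecture
    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
    print("SAM loaded with", sum(p.numel() for p in sam.parameters()) / 1e6, "M parameters")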

2. Project Structure

The recommended project structure is as follows:

    project_root/
    ├── stamps/
    │   ├── images/          # training images
    │   ├── masks/           # segmentation masks
    │   └── annotations.txt  # bounding-box annotations
    ├── checkpoints/         # model checkpoints
    ├── setup_sam_data.py    # data preparation script
    └── sam_finetune.py      # training script

3. Data Preparation

To train the model, we need to prepare the following data:

  • Training images
  • Segmentation masks
  • Bounding-box annotations

Here is the data preparation script:

import os
import numpy as np
import cv2
from pathlib import Path


def create_project_structure():
    """Create the directory structure required by the project."""
    directories = [
        './stamps/images',
        './stamps/masks',
        './checkpoints'
    ]
    for dir_path in directories:
        Path(dir_path).mkdir(parents=True, exist_ok=True)
    return directories


def create_sample_data(num_samples=5):
    """Create sample training data."""
    annotations = []
    for i in range(num_samples):
        # Create a sample image
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        center_x = np.random.randint(150, 350)
        center_y = np.random.randint(150, 350)
        radius = np.random.randint(50, 100)
        # Draw the object
        cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
        # Create the mask
        mask = np.zeros((500, 500), dtype=np.uint8)
        cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        # Save the files
        cv2.imwrite(f'./stamps/images/sample_{i}.jpg', image)
        cv2.imwrite(f'./stamps/masks/sample_{i}_mask.png', mask)
        # Compute the bounding box
        x1 = max(0, center_x - radius)
        y1 = max(0, center_y - radius)
        x2 = min(500, center_x + radius)
        y2 = min(500, center_y + radius)
        annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
    # Save the annotation file
    with open('./stamps/annotations.txt', 'w') as f:
        f.writelines(annotations)

4. Model Fine-Tuning

4.1 Dataset Class Implementation

First, implement the custom dataset class:

import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from segment_anything.utils.transforms import ResizeLongestSide


class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = ResizeLongestSide(1024)
        # Load annotations
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        # Load and preprocess the image
        image = cv2.imread(os.path.join(self.image_dir, ann['image']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(self.mask_dir,
                                       ann['image'].replace('.jpg', '_mask.png')),
                          cv2.IMREAD_GRAYSCALE)
        mask = mask.astype(np.float32) / 255.0
        # Image processing
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        input_image = input_image.astype(np.float32) / 255.0
        input_image = torch.from_numpy(input_image).permute(2, 0, 1)
        # Normalization (ImageNet statistics)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        # Bounding box and mask
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch
        }
4.2 Training Function Implementation

The core of the training function:

from torch.utils.data import DataLoader
from segment_anything import sam_model_registry


def train_sam(
    model_type='vit_b',
    checkpoint_path='sam_vit_b_01ec64.pth',
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Initialize the model
    sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam_model.to(device)
    # Prepare the data and optimizer
    dataset = StampDataset(image_dir='./stamps/images',
                           mask_dir='./stamps/masks',
                           bbox_file='./stamps/annotations.txt')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss()
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Prepare the batch
            input_image = batch['image'].to(device)
            original_size = batch['original_size']
            bbox = batch['bbox'].to(device)
            gt_mask = batch['mask'].to(device)
            # Forward pass (image and prompt encoders run without gradients)
            with torch.no_grad():
                image_embedding = sam_model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
            # Generate predictions
            mask_predictions, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            # Post-processing
            upscaled_masks = sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size[0]
            ).to(device)
            binary_masks = torch.sigmoid(upscaled_masks)
            # Compute the loss and optimize
            loss = loss_fn(binary_masks, gt_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        # Epoch statistics
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
        # Save a checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_file = f'./checkpoints/sam_finetuned_epoch_{epoch+1}.pth'
            torch.save(sam_model.state_dict(), checkpoint_file)

5. Training Procedure

The full training procedure is as follows:

  1. Prepare the environment and data:

     python setup_sam_data.py

  2. Start training:

     python sam_finetune.py

6. Notes and Optimization Suggestions

  1. Data preprocessing:
    • Make sure the image data type is correct (float32)
    • Apply appropriate normalization
    • Keep image sizes consistent
  2. Training optimization:
    • Adjust batch_size to fit your GPU memory
    • Tune the learning rate
    • Consider a learning-rate scheduler
    • Add validation-set evaluation
    • Implement early stopping
  3. Possible improvements:
    • Add data augmentation
    • Use a different loss function (see the sketch after this list)
    • Implement multi-GPU training
    • Add training-process visualization
    • Implement model validation and testing
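
As one concrete direction for the loss-function suggestion above, the MSE loss used in train_sam could be swapped for a combined Dice + BCE loss, a common choice for segmentation. This is a minimal sketch under that assumption, not part of the original scripts:

    import torch

    def dice_bce_loss(pred_logits, gt_mask, eps=1e-6):
        """Combined Dice + BCE loss; pred_logits are raw mask logits, gt_mask is in [0, 1]."""
        pred = torch.sigmoid(pred_logits)
        # Dice term: measures overlap between the predicted and ground-truth masks
        intersection = (pred * gt_mask).sum(dim=(-2, -1))
        union = pred.sum(dim=(-2, -1)) + gt_mask.sum(dim=(-2, -1))
        dice = 1 - (2 * intersection + eps) / (union + eps)
        # BCE term: pixel-wise classification loss computed on the logits
        bce = torch.nn.functional.binary_cross_entropy_with_logits(pred_logits, gt_mask)
        return dice.mean() + bce

    # Usage inside the training loop (replacing loss_fn):
    #   loss = dice_bce_loss(upscaled_masks, gt_mask)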

7. Model Prediction and Visualization

After fine-tuning, we need a convenient way to run the model for prediction and visualize the results. The full implementation follows:

7.1 Predictor Class Implementation

First, we wrap model loading, image preprocessing, and prediction into a predictor class:

class SAMPredictor:
    def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
        self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
        self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam_model.to(self.device)
        self.transform = ResizeLongestSide(1024)

This class provides a simple interface for loading the model and running predictions. Its main responsibilities are:

  • Model loading and device configuration
  • Image preprocessing
  • Mask prediction
  • Post-processing
7.2 Visualization Function

To better present the prediction results, we implement a visualization function:

def visualize_prediction(image, mask, bbox, confidence, save_path=None):
    plt.figure(figsize=(15, 5))
    # Show the original image, the predicted mask, and the overlay
    ...

This function displays, side by side:

  • The original image (with the bounding box)
  • The predicted segmentation mask
  • An overlay of the result on the image
7.3 Usage Example

Here is a complete example of how to use these tools:

# Initialize the predictor
predictor = SAMPredictor("./checkpoints/sam_finetuned_final.pth")
# Load a test image
image = cv2.imread("test_image.jpg")
bbox = [x1, y1, x2, y2]  # bounding-box coordinates
# Predict
mask, confidence = predictor.predict(image, bbox)
# Visualize
visualize_prediction(image, mask, bbox, confidence, "result.png")
7.4 Notes

Keep the following points in mind when using the predictor:

  1. Input image handling:
    • Make sure the image format is correct (RGB)
    • Keep image sizes consistent
    • Use the correct data type and value range
  2. Bounding-box format:
    • Use the [x1, y1, x2, y2] format
    • Make sure the coordinates lie within the image
    • Coordinate values are floats
  3. Performance optimization:
    • Batched prediction
    • GPU memory management
    • Result caching (see the sketch after this list)
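
On the batching and caching points above, the dominant cost of SAM inference is the image-encoder forward pass, so reusing one image embedding across several box prompts on the same image is usually the biggest win. Below is a minimal, illustrative sketch, assuming the preprocess_image and resize_bbox helpers of the full SAMPredictor implementation given in the quick-deployment section; it is not part of the original code:

    class EmbeddingCachingPredictor(SAMPredictor):
        """Illustrative subclass: reuses the image embedding across multiple box prompts."""

        @torch.no_grad()
        def predict_boxes_cached(self, image, bboxes):
            input_image, original_size, _ = self.preprocess_image(image)
            # Encode the image once; this step dominates inference time and GPU memory
            image_embedding = self.sam_model.image_encoder(input_image)
            masks = []
            for bbox in bboxes:
                resized_bbox = self.resize_bbox(bbox, original_size)
                bbox_torch = torch.tensor(resized_bbox, dtype=torch.float,
                                          device=self.device).unsqueeze(0)
                sparse, dense = self.sam_model.prompt_encoder(
                    points=None, boxes=bbox_torch, masks=None)
                low_res_masks, _ = self.sam_model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse,
                    dense_prompt_embeddings=dense,
                    multimask_output=False,
                )
                upscaled = self.sam_model.postprocess_masks(
                    low_res_masks, input_size=input_image.shape[-2:],
                    original_size=original_size)
                masks.append((torch.sigmoid(upscaled) > 0.5)[0, 0].cpu().numpy())
            return masks
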
7.5 Possible Improvements
  1. Batch processing:

     def predict_batch(self, images, bboxes):
         results = []
         for image, bbox in zip(images, bboxes):
             mask, conf = self.predict(image, bbox)
             results.append((mask, conf))
         return results

  2. Multiple bounding-box support:

     def predict_multiple_boxes(self, image, bboxes):
         masks = []
         for bbox in bboxes:
             mask, _ = self.predict(image, bbox)
             masks.append(mask)
         return np.stack(masks)

  3. Interactive visualization:

     def interactive_visualization(image, predictor):
         def onclick(event):
             if event.button == 1:  # left mouse click
                 bbox = [event.xdata - 50, event.ydata - 50,
                         event.xdata + 50, event.ydata + 50]
                 mask, conf = predictor.predict(image, bbox)
                 visualize_prediction(image, mask, bbox, conf)
         fig, ax = plt.subplots()
         ax.imshow(image)
         fig.canvas.mpl_connect('button_press_event', onclick)
         plt.show()

These tools and examples should help you understand and use the fine-tuned SAM model. You can further optimize and extend them to fit your needs.

Conclusion

With the steps above we have implemented a fine-tuning workflow for SAM. This implementation is meant as a starting point that you can optimize and extend for your own needs. In practice you may need to adjust the data preprocessing, the loss function, and the training strategy for the task at hand.

建议在使用时注意以下几点:

  1. 确保训练数据质量
  2. 合理设置训练参数
  3. 定期保存检查点
  4. 监控训练过程
  5. 适当使用数据增强

I hope this tutorial helps with your project. If you run into any problems, feel free to discuss.


Quick deployment:

Download the following three scripts, set up the runtime environment, and run them in order:

# sam-data-setup.py
import os
import numpy as np
import cv2
from pathlib import Path


def create_project_structure():
    """Create the directory structure required by the project."""
    # Create the main directories
    directories = [
        './stamps/images',
        './stamps/masks',
        './checkpoints'
    ]
    for dir_path in directories:
        Path(dir_path).mkdir(parents=True, exist_ok=True)
    return directories


def create_sample_data(num_samples=5):
    """Create sample training data."""
    # Create sample images and masks
    annotations = []
    for i in range(num_samples):
        # Create a sample image (500x500)
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        # Add a sample stamp (a circle at a random position)
        center_x = np.random.randint(150, 350)
        center_y = np.random.randint(150, 350)
        radius = np.random.randint(50, 100)
        # Draw the stamp
        cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
        # Create the corresponding mask
        mask = np.zeros((500, 500), dtype=np.uint8)
        cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        # Save the image and mask
        image_path = f'./stamps/images/sample_{i}.jpg'
        mask_path = f'./stamps/masks/sample_{i}_mask.png'
        cv2.imwrite(image_path, image)
        cv2.imwrite(mask_path, mask)
        # Compute the bounding box
        x1 = max(0, center_x - radius)
        y1 = max(0, center_y - radius)
        x2 = min(500, center_x + radius)
        y2 = min(500, center_y + radius)
        # Append to the annotation list
        annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
    # Save the annotation file
    with open('./stamps/annotations.txt', 'w') as f:
        f.writelines(annotations)


def main():
    print("Creating project structure...")
    directories = create_project_structure()
    for dir_path in directories:
        print(f"Created directory: {dir_path}")
    print("\nCreating sample training data...")
    create_sample_data()
    print("Sample data created!")
    print("\nProject structure:")
    for root, dirs, files in os.walk('./stamps'):
        level = root.replace('./stamps', '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        sub_indent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{sub_indent}{f}")


if __name__ == '__main__':
    main()
# sam_finetune_decoder.py
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
import cv2
import os


class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file, target_size=(1024, 1024)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.target_size = target_size
        self.transform = ResizeLongestSide(1024)  # SAM default size
        # Load bbox annotations
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })

    def resize_with_bbox(self, image, mask, bbox):
        """Resize the image, mask, and bounding box."""
        h, w = image.shape[:2]
        target_h, target_w = self.target_size
        # Scaling factors
        scale_x = target_w / w
        scale_y = target_h / h
        # Resize the image
        resized_image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        # Resize the mask
        if mask is not None:
            resized_mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
        else:
            resized_mask = None
        # Rescale the bounding box
        resized_bbox = [
            bbox[0] * scale_x,  # x1
            bbox[1] * scale_y,  # y1
            bbox[2] * scale_x,  # x2
            bbox[3] * scale_y   # y2
        ]
        return resized_image, resized_mask, resized_bbox

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        # Load image
        image = cv2.imread(os.path.join(self.image_dir, ann['image']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Load mask
        mask_name = ann['image'].replace('.jpg', '_mask.png')
        mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
        mask = mask.astype(np.float32) / 255.0
        # First resize everything to a uniform size
        image, mask, bbox = self.resize_with_bbox(image, mask, ann['bbox'])
        # Prepare the image
        original_size = self.target_size
        input_image = self.transform.apply_image(image)
        # Convert to float32 and normalize to 0-1 range
        input_image = input_image.astype(np.float32) / 255.0
        # Convert to tensor and normalize according to ImageNet stats
        input_image = torch.from_numpy(input_image).permute(2, 0, 1).contiguous()
        # Apply ImageNet normalization
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        # Prepare bbox
        bbox = self.transform.apply_boxes(np.array([bbox]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
        # Prepare mask
        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch
        }


def train_sam(
    model_type='vit_b',
    checkpoint_path='sam_vit_b_01ec64.pth',
    image_dir='./stamps/images',
    mask_dir='./stamps/masks',
    bbox_file='./stamps/annotations.txt',
    output_dir='./checkpoints',
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5
):
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Initialize model
    sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam_model.to(device)
    # Prepare dataset
    dataset = StampDataset(image_dir, mask_dir, bbox_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Setup optimizer
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
    # Loss function
    loss_fn = torch.nn.MSELoss()
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move inputs to device
            input_image = batch['image'].to(device)
            original_size = batch['original_size']
            bbox = batch['bbox'].to(device)
            gt_mask = batch['mask'].to(device)
            # Print shapes and types for debugging
            if batch_idx == 0 and epoch == 0:
                print(f"Input image shape: {input_image.shape}")
                print(f"Input image type: {input_image.dtype}")
                print(f"Input image range: [{input_image.min():.2f}, {input_image.max():.2f}]")
            # Get image embedding and prompt embeddings (without gradient)
            with torch.no_grad():
                image_embedding = sam_model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
            # Generate mask prediction
            mask_predictions, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            # Upscale masks to original size
            upscaled_masks = sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size[0]
            ).to(device)
            # Convert to binary mask
            binary_masks = torch.sigmoid(upscaled_masks)
            # Calculate loss
            loss = loss_fn(binary_masks, gt_mask)
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_file = os.path.join(output_dir, f'sam_finetuned_epoch_{epoch+1}.pth')
            torch.save(sam_model.state_dict(), checkpoint_file)
            print(f'Checkpoint saved: {checkpoint_file}')
    # Save final model
    final_checkpoint = os.path.join(output_dir, 'sam_finetuned_final.pth')
    torch.save(sam_model.state_dict(), final_checkpoint)
    print(f'Final model saved to {final_checkpoint}')


if __name__ == '__main__':
    # Create output directory if it doesn't exist
    os.makedirs('./checkpoints', exist_ok=True)
    # Start training
    train_sam()
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
import cv2
from pathlib import Path


class SAMPredictor:
    def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
        """
        Initialize the SAM predictor.

        Args:
            checkpoint_path: path to the model weights
            model_type: model type ("vit_h", "vit_l", "vit_b")
            device: device to use ("cuda" or "cpu")
        """
        self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
        print(f"Using device: {self.device}")
        # Load the model
        self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam_model.to(self.device)
        # Create the image transform
        self.transform = ResizeLongestSide(1024)

    def resize_bbox(self, bbox, original_size, target_size=(1024, 1024)):
        """
        Rescale bounding-box coordinates to match the resized image.

        Args:
            bbox: original bounding box [x1, y1, x2, y2]
            original_size: original image size (height, width)
            target_size: target image size (height, width)

        Returns:
            resized_bbox: rescaled bounding-box coordinates
        """
        orig_h, orig_w = original_size
        target_h, target_w = target_size
        # Scaling factors
        scale_x = target_w / orig_w
        scale_y = target_h / orig_h
        # Rescale the coordinates
        x1, y1, x2, y2 = bbox
        resized_bbox = [
            x1 * scale_x,
            y1 * scale_y,
            x2 * scale_x,
            y2 * scale_y
        ]
        return resized_bbox

    def preprocess_image(self, image):
        """Preprocess the input image."""
        # Keep the original size
        original_size = image.shape[:2]
        # Make sure the image is RGB
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Resize the image
        image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
        # Convert to float32 and normalize
        input_image = image.astype(np.float32) / 255.0
        # Convert to tensor and add a batch dimension
        input_image = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0)
        # ImageNet normalization
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        return input_image.to(self.device), original_size, image

    def predict(self, image, bbox):
        """
        Predict the segmentation mask for a single image.

        Args:
            image: image as a numpy array
            bbox: bounding box in [x1, y1, x2, y2] format

        Returns:
            binary_mask: binarized segmentation mask
            confidence: prediction confidence
        """
        # Preprocess the image
        input_image, original_size, resized_image = self.preprocess_image(image)
        # Rescale the bounding box
        resized_bbox = self.resize_bbox(bbox, original_size)
        print(resized_bbox, image.shape, resized_image.shape)
        # Prepare the bounding box
        bbox_torch = torch.tensor(resized_bbox, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            # Get the image embedding
            image_embedding = self.sam_model.image_encoder(input_image)
            # Get the prompt embeddings
            sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
                points=None,
                boxes=bbox_torch,
                masks=None,
            )
            # Generate the mask prediction
            mask_predictions, iou_predictions = self.sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            # Post-process the masks
            upscaled_masks = self.sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size
            ).to(self.device)
        # Convert to a binary mask
        binary_mask = torch.sigmoid(upscaled_masks) > 0.5
        return binary_mask[0, 0].cpu().numpy(), iou_predictions[0, 0].item()


def visualize_prediction(image, mask, bbox, confidence, save_path=None):
    """
    Visualize a prediction.

    Args:
        image: original image
        mask: predicted mask
        bbox: bounding-box coordinates
        confidence: prediction confidence
        save_path: path to save the figure (optional)
    """
    # Create the figure
    plt.figure(figsize=(15, 5))
    # Show the original image
    plt.subplot(131)
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title('Original Image')
    # Draw the bounding box
    x1, y1, x2, y2 = map(int, bbox)
    plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-', linewidth=2)
    plt.axis('off')
    # Show the predicted mask
    plt.subplot(132)
    plt.imshow(mask, cmap='gray')
    plt.title(f'Predicted Mask\nConfidence: {confidence:.2f}')
    plt.axis('off')
    # Show the overlay
    plt.subplot(133)
    overlay = image.copy()
    overlay[mask > 0] = overlay[mask > 0] * 0.7 + np.array([0, 255, 0], dtype=np.uint8) * 0.3
    plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    plt.title('Overlay')
    plt.axis('off')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"Result saved to: {save_path}")
    plt.show()


def main():
    # Configuration
    checkpoint_path = "./checkpoints/sam_finetuned_final.pth"  # use the fine-tuned model
    test_image_path = "./stamps/images/sample_0.jpg"
    output_dir = "./predictions"
    # Create the output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Initialize the predictor
    predictor = SAMPredictor(checkpoint_path)
    # Load the test image
    image = cv2.imread(test_image_path)
    # image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
    # Read the bounding box (a sample box here; in practice it may come from your annotation file)
    with open('./stamps/annotations.txt', 'r') as f:
        first_line = f.readline().strip()
        _, x1, y1, x2, y2 = first_line.split(',')
        bbox = [float(x1), float(y1), float(x2), float(y2)]
    print(bbox)
    # Run prediction
    mask, confidence = predictor.predict(image, bbox)
    # Visualize the result
    save_path = str(Path(output_dir) / "prediction_result.png")
    visualize_prediction(image, mask, bbox, confidence, save_path)


if __name__ == "__main__":
    main()

Running the script saves the prediction visualization to ./predictions/prediction_result.png.






Supplement 2:

The sections above fine-tune only the decoder; the code below additionally fine-tunes the encoder:

Notes:

  • Fine-tuning the encoder requires considerably more compute and training time
  • A larger training set is needed to avoid overfitting
  • Use a validation set to monitor performance and guard against model degradation
  • More training epochs may be needed before the model converges
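
If fine-tuning the full encoder is too heavy for your GPU, one common compromise, not used in the script below, is to unfreeze only the last few transformer blocks of the ViT image encoder. A minimal sketch, assuming the standard segment_anything vit_b model layout:

    import torch
    from segment_anything import sam_model_registry

    sam = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

    # Freeze the whole image encoder and the prompt encoder first
    for p in sam.image_encoder.parameters():
        p.requires_grad = False
    for p in sam.prompt_encoder.parameters():
        p.requires_grad = False
    # Unfreeze only the last two transformer blocks and the neck to cut memory and compute
    for p in sam.image_encoder.blocks[-2:].parameters():
        p.requires_grad = True
    for p in sam.image_encoder.neck.parameters():
        p.requires_grad = True

    # Give the optimizer only the parameters that still require gradients
    # (this includes the mask decoder, which stays trainable by default)
    trainable = [p for p in sam.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-6)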

import torch
import numpy as np
from per_segment_anything import sam_model_registry, SamPredictor
from per_segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import cv2
import os
from tqdm import tqdm
import logging
import json
from datetime import datetime
from train_setimage import preprocess

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform if transform else ResizeLongestSide(1024)
        # Load the annotation file
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        # Load the image
        image_path = os.path.join(self.image_dir, ann['image'])
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Load the mask
        mask_name = ann['image'].replace('.jpg', '_mask.png')
        mask_path = os.path.join(self.mask_dir, mask_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")
        mask = mask.astype(np.float32) / 255.0
        # Prepare the image
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        input_image = input_image.astype(np.float32) / 255.0
        # Convert to tensor
        input_image = torch.from_numpy(input_image).permute(2, 0, 1)
        # Use preprocess to handle ImageNet normalization and padding
        input_image = preprocess(input_image)
        print(f"Processed image shape: {input_image.shape}")
        # Prepare the bbox
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float)
        # Prepare the mask
        mask_torch = torch.from_numpy(mask).float()
        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch,
            'image_path': image_path  # for debugging
        }


class SAMFineTuner:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.setup_model()
        self.setup_datasets()
        self.setup_training()
        # Create the output directory
        os.makedirs(config['output_dir'], exist_ok=True)
        # Save the configuration
        config_path = os.path.join(config['output_dir'], 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

    def setup_model(self):
        logger.info(f"Loading SAM model: {self.config['model_type']}")
        self.model = sam_model_registry[self.config['model_type']](
            checkpoint=self.config['checkpoint_path']
        )
        self.model.to(self.device)

    def setup_datasets(self):
        logger.info("Setting up datasets")
        self.train_dataset = StampDataset(
            self.config['train_image_dir'],
            self.config['train_mask_dir'],
            self.config['train_bbox_file']
        )
        # Load training data in batches
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=self.config['num_workers'],
            pin_memory=True
        )
        # Validation set
        if self.config.get('val_bbox_file'):
            self.val_dataset = StampDataset(
                self.config['val_image_dir'],
                self.config['val_mask_dir'],
                self.config['val_bbox_file']
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=self.config['batch_size'],
                shuffle=False,
                num_workers=self.config['num_workers'],
                pin_memory=True
            )

    def setup_training(self):
        logger.info("Setting up training components")
        # Use separate learning rates for the encoder and the decoder
        self.optimizer = torch.optim.Adam([
            {'params': self.model.image_encoder.parameters(),
             'lr': self.config['encoder_lr']},
            {'params': self.model.mask_decoder.parameters(),
             'lr': self.config['decoder_lr']}
        ])
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True
        )
        self.loss_fn = torch.nn.MSELoss()
        self.scaler = GradScaler()
        # Track the best model
        self.best_loss = float('inf')

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0  # running total of the epoch loss
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch + 1}')
        for batch_idx, batch in enumerate(pbar):
            # Move the batch to the GPU
            input_image = batch['image'].to(self.device)
            bbox = batch['bbox'].to(self.device)
            gt_mask = batch['mask'].to(self.device)
            self.optimizer.zero_grad()
            with autocast():
                # Forward pass (the image encoder is trained here)
                image_embedding = self.model.image_encoder(input_image)
                with torch.no_grad():
                    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                        points=None,
                        boxes=bbox,
                        masks=None,
                    )
                mask_predictions, _ = self.model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                upscaled_masks = self.model.postprocess_masks(
                    mask_predictions,
                    input_size=input_image.shape[-2:],
                    original_size=batch['original_size']
                ).to(self.device)
                binary_masks = torch.sigmoid(upscaled_masks)
                loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def validate(self):
        if not hasattr(self, 'val_loader'):
            return None
        self.model.eval()
        total_loss = 0
        for batch in tqdm(self.val_loader, desc='Validating'):
            input_image = batch['image'].to(self.device)
            bbox = batch['bbox'].to(self.device)
            gt_mask = batch['mask'].to(self.device)
            with autocast():
                image_embedding = self.model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
                mask_predictions, _ = self.model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                upscaled_masks = self.model.postprocess_masks(
                    mask_predictions,
                    input_size=input_image.shape[-2:],
                    original_size=batch['original_size']
                ).to(self.device)
                binary_masks = torch.sigmoid(upscaled_masks)
                loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
            total_loss += loss.item()
        return total_loss / len(self.val_loader)

    def save_checkpoint(self, epoch, loss, is_best=False):
        # Save the full training state (for resuming training)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'config': self.config
        }
        # Save the full checkpoint
        checkpoint_path = os.path.join(
            self.config['output_dir'],
            f'checkpoint_epoch_{epoch + 1}.pth'
        )
        torch.save(checkpoint, checkpoint_path)
        # If this is the best model so far, also save weights in a compatible format
        if is_best:
            # Full checkpoint
            best_checkpoint_path = os.path.join(self.config['output_dir'], 'best_checkpoint.pth')
            torch.save(checkpoint, best_checkpoint_path)
            # Additionally save clean model weights (compatible with the original SAM format)
            best_model_path = os.path.join(self.config['output_dir'], 'best_model_sam_format.pth')
            torch.save(self.model.state_dict(), best_model_path)
            logger.info(f"Saved best model with loss: {loss:.4f}")

    def train(self):
        logger.info("Starting training")
        for epoch in range(self.config['num_epochs']):
            train_loss = self.train_epoch(epoch)
            logger.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}")
            val_loss = self.validate()
            if val_loss is not None:
                logger.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")
                self.scheduler.step(val_loss)
                is_best = val_loss < self.best_loss
                if is_best:
                    self.best_loss = val_loss
            else:
                is_best = False
                self.scheduler.step(train_loss)
            if (epoch + 1) % self.config['save_interval'] == 0:
                self.save_checkpoint(
                    epoch,
                    val_loss if val_loss is not None else train_loss,
                    is_best
                )
        # Save a final SAM-compatible model after training
        final_model_path = os.path.join(self.config['output_dir'], 'final_model_sam_format.pth')
        torch.save(self.model.state_dict(), final_model_path)
        logger.info(f"Saved final model in SAM-compatible format: {final_model_path}")


def main():
    # Training configuration
    config = {
        'model_type': 'vit_b',
        'checkpoint_path': './checkpoints/sam_vit_b_01ec64.pth',
        'train_image_dir': './stamps/images',
        'train_mask_dir': './stamps/masks',
        'train_bbox_file': './stamps/annotations.txt',
        'val_image_dir': './stamps/val_images',
        'val_mask_dir': './stamps/val_masks',
        'val_bbox_file': './stamps/val_annotations.txt',
        'output_dir': f'./outputs/sam_finetune_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
        'num_epochs': 1,
        'batch_size': 1,
        'num_workers': 4,
        'encoder_lr': 1e-6,
        'decoder_lr': 1e-5,
        'save_interval': 5
    }
    # Create the trainer and start training
    trainer = SAMFineTuner(config)
    trainer.train()


if __name__ == '__main__':
    main()