Dataset preparation from HF

Data will be downloaded from Hugging Face and then processed into the format we want.

# from cv_tools.core import *
# Import here so this cell runs standalone (previously `load_dataset` was
# used before its import further down, raising a NameError on a fresh run).
from datasets import load_dataset

# Breast-cancer segmentation dataset: 'image' (RGB) + 'label' (mask) pairs.
dataset = load_dataset("nielsr/breast-cancer", split="train")
dataset
Dataset({
    features: ['image', 'label'],
    num_rows: 130
})
# Peek at the first image/mask pair: both are 256x256 PIL images,
# the image in RGB mode and the mask in 32-bit integer ('I') mode.
sample = dataset[0]
img = sample['image']
msk = sample['label']
img.size, img.mode, msk.size, msk.mode
((256, 256), 'RGB', (256, 256), 'I')
from datasets import load_dataset

# Electron-microscopy segmentation dataset on the HF Hub; the same repo
# provides both splits, so name it once instead of repeating the string.
_EM_REPO = "hasangoni/Electron_microscopy_dataset"
training_dataset = load_dataset(_EM_REPO, split="train")
validation_dataset = load_dataset(_EM_REPO, split="test")
training_dataset
Dataset({
    features: ['image', 'label'],
    num_rows: 1642
})

source

get_bounding_box

 get_bounding_box (ground_truth_map:numpy.ndarray)

Get bounding box coordinates from mask image

Type Details
ground_truth_map ndarray mask image type cv2

testing get_bounding_box


source

SAMDataset

 SAMDataset (dataset:torch.utils.data.dataset.Dataset,
             processor:transformers.models.sam.processing_sam.SamProcessor
             )

Creating dataset for SAM Training

Type Details
dataset Dataset pytorch dataset
processor SamProcessor hf model processor

Creating pytorch dataset

# Wrap the raw HF splits with the SAM processor so each item yields
# model-ready tensors (pixel_values, input_boxes, ground_truth_mask).
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')

train_dataset = SAMDataset(dataset=training_dataset, processor=processor)
val_dataset = SAMDataset(dataset=validation_dataset, processor=processor)

# Sanity check: bring one sample back to HWC numpy for visual inspection.
trn_ = train_dataset[0]['pixel_values'].to('cpu').numpy().transpose(1, 2, 0)
val_ = val_dataset[0]['pixel_values'].to('cpu').numpy().transpose(1, 2, 0)
#show_(trn_)
# Batch the SAM datasets; shuffle only the training split.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Print the per-key tensor shapes of a single (un-batched) example.
example = train_dataset[0]
for key, value in example.items():
    print(key, value.shape)
     
model = SamModel.from_pretrained("facebook/sam-vit-base")

# Fine-tune only the mask decoder: freeze both the image and prompt encoders.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)
pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([1, 4])
ground_truth_mask (256, 256)

Creating pytorch dataloader

# Fix: the SAMDataset instances are named train_dataset / val_dataset;
# train_ds / val_ds were never defined and raised a NameError (see below).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[149], line 2
      1 train_dataloader = DataLoader(
----> 2     train_ds, 
      3     batch_size=2, 
      4     shuffle=True)
      5 val_dataloader = DataLoader(
      6     val_ds, 
      7     batch_size=2, 
      8     shuffle=False)

NameError: name 'train_ds' is not defined

testing pytorch dataloader

# Pull one batch and verify the collated (batch-first) tensor shapes.
batch = next(iter(train_dataloader))
for key, tensor in batch.items():
    print(key, tensor.shape)
pixel_values torch.Size([2, 3, 1024, 1024])
original_sizes torch.Size([2, 2])
reshaped_input_sizes torch.Size([2, 2])
input_boxes torch.Size([2, 1, 4])
ground_truth_mask torch.Size([2, 256, 256])

Loading the model

model = SamModel.from_pretrained('facebook/sam-vit-base')

# Only the mask decoder is trained: both encoders stay frozen.
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)
NUM_EPOCHS = 2
T_0 = int(0.5 * NUM_EPOCHS)  # first cosine restart after half the run
ITERS = len(train_dataloader)

# Optimise only the (unfrozen) mask-decoder parameters.
optimizer = AdamW(
    model.mask_decoder.parameters(),
    lr=0.001,
    weight_decay=0.0001)

# Deliberately train on CPU (very small GPU memory); the original computed
# the cuda check and then unconditionally overwrote it with "cpu" anyway.
# Use: device = "cuda" if torch.cuda.is_available() else "cpu" on a real GPU.
device = "cpu"
model.to(device)

scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=T_0,
    eta_min=0.00001)

validate

 validate (model:transformers.models.sam.modeling_sam.SamModel,
           dataloader:torch.utils.data.dataloader.DataLoader, loss_fn:<mod
           ule'monai.losses.dice'from'/opt/hostedtoolcache/Python/3.10.14/
           x64/lib/python3.10/site-packages/monai/losses/dice.py'>,
           device:str='cpu')
num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# DiceCE on raw logits: sigmoid=True applies the activation inside the loss.
seg_loss = monai.losses.DiceCELoss(
    sigmoid=True,
    squared_pred=True,
    reduction='mean')
model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # forward pass: one box prompt per image, single-mask output
        # (per-batch debug prints removed — they cluttered and slowed the loop)
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # compute loss; squeeze the per-prompt dim, unsqueeze a channel dim
        # on the target so both are (B, 1, H, W)
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean Training Loss: {mean(epoch_losses)}')
    validation_loss = validate(
        model=model,
        dataloader=val_dataloader,
        loss_fn=seg_loss,
        device=device)

    print(f'Validation Loss: {validation_loss}')
  0%|          | 0/65 [00:00<?, ?it/s]
torch.Size([2, 3, 1024, 1024])
  0%|          | 0/65 [00:03<?, ?it/s]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[44], line 16
     13 for batch in tqdm(train_dataloader):
     14   # forward pass
     15   print(batch['pixel_values'].shape)
---> 16   outputs = model(pixel_values=batch["pixel_values"].to(device),
     17                   input_boxes=batch["input_boxes"].to(device),
     18                   multimask_output=False)
     19   print(f'outputs shape {outputs.pred_masks.shape}')
     21   # compute loss

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1358, in SamModel.forward(self, pixel_values, input_points, input_labels, input_boxes, input_masks, image_embeddings, multimask_output, attention_similarity, target_embedding, output_attentions, output_hidden_states, return_dict, **kwargs)
   1355 vision_hidden_states = None
   1357 if pixel_values is not None:
-> 1358     vision_outputs = self.vision_encoder(
   1359         pixel_values,
   1360         output_attentions=output_attentions,
   1361         output_hidden_states=output_hidden_states,
   1362         return_dict=return_dict,
   1363     )
   1364     image_embeddings = vision_outputs[0]
   1366     if output_hidden_states:

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1046, in SamVisionEncoder.forward(self, pixel_values, output_attentions, output_hidden_states, return_dict)
   1041     layer_outputs = self._gradient_checkpointing_func(
   1042         layer_module.__call__,
   1043         hidden_states,
   1044     )
   1045 else:
-> 1046     layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
   1048 hidden_states = layer_outputs[0]
   1050 if output_attentions:

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:939, in SamVisionLayer.forward(self, hidden_states, output_attentions)
    936     height, width = hidden_states.shape[1], hidden_states.shape[2]
    937     hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
--> 939 hidden_states, attn_weights = self.attn(
    940     hidden_states=hidden_states,
    941     output_attentions=output_attentions,
    942 )
    943 # Reverse window partition
    944 if self.window_size > 0:

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:842, in SamVisionAttention.forward(self, hidden_states, output_attentions)
    839 attn_weights = (query * self.scale) @ key.transpose(-2, -1)
    841 if self.use_rel_pos:
--> 842     attn_weights = self.add_decomposed_rel_pos(
    843         attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
    844     )
    846 attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
    848 attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:825, in SamVisionAttention.add_decomposed_rel_pos(self, attn, query, rel_pos_h, rel_pos_w, q_size, k_size)
    823 attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
    824 attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
--> 825 attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
    826 return attn

KeyboardInterrupt: 

pt_train

 pt_train (train_dataloader:torch.utils.data.dataloader.DataLoader,
           model:transformers.models.sam.modeling_sam.SamModel,
           optimizer:torch.optim.optimizer.Optimizer,
           device:Optional[str]='cpu', epoch_n:int=2)
# Quick smoke-train: two epochs through the training loader on `device`.
pt_train(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    device=device,
    epoch_n=2)
Epoch 1
23it [08:41, 22.89s/it]
The Kernel crashed while executing code in the current cell or a previous cell. 

Please review the code in the cell(s) to identify a possible cause of the failure. 

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. 

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.

validate

 validate (model:transformers.models.sam.modeling_sam.SamModel,
           dataloader:torch.utils.data.dataloader.DataLoader, loss_fn:<mod
           ule'monai.losses'from'/opt/hostedtoolcache/Python/3.10.14/x64/l
           ib/python3.10/site-packages/monai/losses/__init__.py'>,
           device:str='cpu')

*Validate the model using a validation dataloader.

Parameters: - model: The PyTorch model to validate. - dataloader: DataLoader for validation data. - loss_fn: Loss function used for validation. - device: Device to run validation on (‘cuda’ or ‘cpu’).*

Type Default Details
model SamModel SAM model
dataloader DataLoader Torch dataloader
loss_fn monai.losses Monai loss function
device str cpu whether to use cpu or gpu

train_and_validate

 train_and_validate (model:transformers.models.sam.modeling_sam.SamModel,
                     num_epochs:int,
                     optimizer:torch.optim.optimizer.Optimizer, scheduler:
                     <module'torch.optim.lr_scheduler'from'/opt/hostedtool
                     cache/Python/3.10.14/x64/lib/python3.10/site-
                     packages/torch/optim/lr_scheduler.py'>, train_dataloa
                     der:torch.utils.data.dataloader.DataLoader, val_datal
                     oader:torch.utils.data.dataloader.DataLoader, loss_fn
                     :<module'monai.losses'from'/opt/hostedtoolcache/Pytho
                     n/3.10.14/x64/lib/python3.10/site-
                     packages/monai/losses/__init__.py'>,
                     device:str='cpu')

Train and validate a model with the given parameters.

Type Default Details
model SamModel SAM model
num_epochs int Number of epochs to train for
optimizer Optimizer Optimizer to use
scheduler torch.optim.lr_scheduler Learning rate scheduler
train_dataloader DataLoader DataLoader for training data
val_dataloader DataLoader DataLoader for validation data
loss_fn monai.losses Loss function used for training
device str cpu Device to train on (‘cuda’ or ‘cpu’)
#model = SamModel.from_pretrained('facebook/sam-vit-base')
# Optimise only the (unfrozen) mask-decoder parameters.
optimizer = AdamW(
    model.mask_decoder.parameters(),
    lr=0.001,
    weight_decay=0.0001)

scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=10,
    T_mult=2,
    eta_min=0.00001)

# The model outputs raw logits, so the loss must apply sigmoid itself —
# matches the DiceCELoss configuration used in the earlier training loop
# (the previous bare DiceCELoss() left sigmoid=False, scoring logits as-is).
seg_loss = monai.losses.DiceCELoss(
    sigmoid=True,
    squared_pred=True,
    reduction='mean')
device = 'cpu'
train_and_validate(
    model=model,
    num_epochs=2,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=seg_loss,
    device=device
)
Epoch 1/2:   0%|          | 1/821 [00:36<8:21:37, 36.70s/it]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[161], line 1
----> 1 train_and_validate(
      2     model=model,
      3     num_epochs=2,
      4     optimizer=optimizer,
      5     scheduler=scheduler,
      6     train_dataloader=train_dataloader,
      7     val_dataloader=val_dataloader,
      8     loss_fn=seg_loss,
      9     device=device
     10 )

Cell In[159], line 26, in train_and_validate(model, num_epochs, optimizer, scheduler, train_dataloader, val_dataloader, loss_fn, device)
     20 progress_bar = tqdm(
     21     train_dataloader, 
     22     desc=f'Epoch {epoch+1}/{num_epochs}', total=len(train_dataloader))
     24 for batch in progress_bar:
     25     # Forward pass
---> 26     outputs = model(pixel_values=batch["pixel_values"].to(device),
     27                     input_boxes=batch["input_boxes"].to(device),
     28                     multimask_output=False)
     30     # Compute loss
     31     predicted_masks = outputs.pred_masks.squeeze(1)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1358, in SamModel.forward(self, pixel_values, input_points, input_labels, input_boxes, input_masks, image_embeddings, multimask_output, attention_similarity, target_embedding, output_attentions, output_hidden_states, return_dict, **kwargs)
   1355 vision_hidden_states = None
   1357 if pixel_values is not None:
-> 1358     vision_outputs = self.vision_encoder(
   1359         pixel_values,
   1360         output_attentions=output_attentions,
   1361         output_hidden_states=output_hidden_states,
   1362         return_dict=return_dict,
   1363     )
   1364     image_embeddings = vision_outputs[0]
   1366     if output_hidden_states:

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1046, in SamVisionEncoder.forward(self, pixel_values, output_attentions, output_hidden_states, return_dict)
   1041     layer_outputs = self._gradient_checkpointing_func(
   1042         layer_module.__call__,
   1043         hidden_states,
   1044     )
   1045 else:
-> 1046     layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
   1048 hidden_states = layer_outputs[0]
   1050 if output_attentions:

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:939, in SamVisionLayer.forward(self, hidden_states, output_attentions)
    936     height, width = hidden_states.shape[1], hidden_states.shape[2]
    937     hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
--> 939 hidden_states, attn_weights = self.attn(
    940     hidden_states=hidden_states,
    941     output_attentions=output_attentions,
    942 )
    943 # Reverse window partition
    944 if self.window_size > 0:

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:846, in SamVisionAttention.forward(self, hidden_states, output_attentions)
    841 if self.use_rel_pos:
    842     attn_weights = self.add_decomposed_rel_pos(
    843         attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
    844     )
--> 846 attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
    848 attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    850 attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)

File ~/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py:1860, in softmax(input, dim, _stacklevel, dtype)
   1858     ret = input.softmax(dim)
   1859 else:
-> 1860     ret = input.softmax(dim, dtype=dtype)
   1861 return ret

KeyboardInterrupt: