#from cv_tools.core import *
# Dataset preparation from HF
Data will be downloaded from Hugging Face and then processed to get it into the format we want.
# Small breast-cancer segmentation dataset used for initial exploration
# (130 rows, features ['image', 'label'] per the recorded output below).
dataset = load_dataset("nielsr/breast-cancer", split="train")
datasetDataset({
features: ['image', 'label'],
num_rows: 130
})
# Grab the first sample to inspect image/mask sizes and PIL modes.
img = dataset[0]['image']
msk = dataset[0]['label']
# Recorded result: ((256, 256), 'RGB', (256, 256), 'I')
# (the image is 256x256 RGB; the mask is a 32-bit integer image).
img.size, img.mode, msk.size, msk.mode
from datasets import load_dataset

# Electron-microscopy segmentation dataset from the Hugging Face Hub.
# The "test" split is used here as the validation set.
training_dataset = load_dataset(
    "hasangoni/Electron_microscopy_dataset",
    split="train")
validation_dataset = load_dataset(
    "hasangoni/Electron_microscopy_dataset",
    split="test")
# training_dataset -> Dataset(features=['image', 'label'], num_rows=1642)
get_bounding_box
get_bounding_box (ground_truth_map:numpy.ndarray)
Get bounding box coordinates from mask image
| Type | Details | |
|---|---|---|
| ground_truth_map | ndarray | mask image type cv2 |
testing get_bounding_box
SAMDataset
SAMDataset (dataset:torch.utils.data.dataset.Dataset, processor:transformers.models.sam.processing_sam.SamProcessor )
Creating dataset for SAM Training
| Type | Details | |
|---|---|---|
| dataset | Dataset | pytorch dataset |
| processor | SamProcessor | hf model processor |
Creating pytorch dataset
# HF processor that resizes/normalizes images into SAM's expected input format.
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')

train_dataset = SAMDataset(
    dataset=training_dataset,
    processor=processor)
val_dataset = SAMDataset(
    dataset=validation_dataset,
    processor=processor)

# Sanity check: bring one processed sample back to HWC layout for viewing.
trn_ = np.transpose(train_dataset[0]['pixel_values'].to('cpu').numpy(), (1, 2, 0))
val_ = np.transpose(val_dataset[0]['pixel_values'].to('cpu').numpy(), (1, 2, 0))
# show_(trn_)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False)

# Print the per-key shapes of a single (uncollated) sample.
example = train_dataset[0]
for k, v in example.items():
    print(k, v.shape)
model = SamModel.from_pretrained("facebook/sam-vit-base")
# Fine-tune only the mask decoder: freeze the vision and prompt encoders.
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
# Recorded per-sample shapes from the cell above:
#   pixel_values          torch.Size([3, 1024, 1024])
#   original_sizes        torch.Size([2])
#   reshaped_input_sizes  torch.Size([2])
#   input_boxes           torch.Size([1, 4])
#   ground_truth_mask     (256, 256)
Creating pytorch dataloader
# FIX: the original cell raised NameError -- the datasets created above are
# named `train_dataset` / `val_dataset`, not `train_ds` / `val_ds`.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False)
testing pytorch dataloader
# Pull one collated batch to verify the batched tensor shapes.
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, v.shape)
# Recorded: pixel_values [2,3,1024,1024], original_sizes [2,2],
# reshaped_input_sizes [2,2], input_boxes [2,1,4], ground_truth_mask [2,256,256]
loading Model
model = SamModel.from_pretrained('facebook/sam-vit-base')
# Only the mask decoder is trained; freeze both encoders.
for name, param in model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

NUM_EPOCHS = 2
T_0 = int(0.5 * NUM_EPOCHS)       # restart period for cosine annealing
ITERS = len(train_dataloader)

optimizer = AdamW(
    model.mask_decoder.parameters(),  # decoder-only parameters
    lr=0.001,
    weight_decay=0.0001)

device = "cuda" if torch.cuda.is_available() else "cpu"
# In case of very small GPU memory, fall back to CPU.
device = "cpu"
model.to(device)

scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=T_0,
    eta_min=0.00001)
validate (model:transformers.models.sam.modeling_sam.SamModel, dataloader:torch.utils.data.dataloader.DataLoader, loss_fn:<mod ule'monai.losses.dice'from'/opt/hostedtoolcache/Python/3.10.14/ x64/lib/python3.10/site-packages/monai/losses/dice.py'>, device:str='cpu')
num_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Dice + cross-entropy; sigmoid=True because the mask decoder emits raw logits.
seg_loss = monai.losses.DiceCELoss(
    sigmoid=True,
    squared_pred=True,
    reduction='mean')

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Forward pass: prompt SAM with the ground-truth bounding box only,
        # single mask per prompt.
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)

        # pred_masks is [B, 1, 1, H, W]; squeeze the per-prompt dim to [B, 1, H, W]
        # and add a channel dim to the ground truth to match.
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Backward pass + decoder-only parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean Training Loss: {mean(epoch_losses)}')
    validation_loss = validate(
        model=model,
        dataloader=val_dataloader,
        loss_fn=seg_loss,
        device=device)
    print(f'Validation Loss: {validation_loss}')
torch.Size([2, 3, 1024, 1024])
0%| | 0/65 [00:03<?, ?it/s]
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[44], line 16 13 for batch in tqdm(train_dataloader): 14 # forward pass 15 print(batch['pixel_values'].shape) ---> 16 outputs = model(pixel_values=batch["pixel_values"].to(device), 17 input_boxes=batch["input_boxes"].to(device), 18 multimask_output=False) 19 print(f'outputs shape {outputs.pred_masks.shape}') 21 # compute loss File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1358, in SamModel.forward(self, pixel_values, input_points, input_labels, input_boxes, input_masks, image_embeddings, multimask_output, attention_similarity, target_embedding, output_attentions, output_hidden_states, return_dict, **kwargs) 1355 vision_hidden_states = None 1357 if pixel_values is not None: -> 1358 vision_outputs = self.vision_encoder( 1359 pixel_values, 1360 output_attentions=output_attentions, 1361 output_hidden_states=output_hidden_states, 1362 return_dict=return_dict, 1363 ) 1364 image_embeddings = vision_outputs[0] 1366 if output_hidden_states: File 
~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1046, in SamVisionEncoder.forward(self, pixel_values, output_attentions, output_hidden_states, return_dict) 1041 layer_outputs = self._gradient_checkpointing_func( 1042 layer_module.__call__, 1043 hidden_states, 1044 ) 1045 else: -> 1046 layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) 1048 hidden_states = layer_outputs[0] 1050 if output_attentions: File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:939, in SamVisionLayer.forward(self, hidden_states, output_attentions) 936 height, width = hidden_states.shape[1], hidden_states.shape[2] 937 hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) --> 939 hidden_states, attn_weights = self.attn( 940 hidden_states=hidden_states, 941 output_attentions=output_attentions, 942 ) 943 # Reverse window partition 944 if self.window_size > 0: File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:842, in SamVisionAttention.forward(self, hidden_states, output_attentions) 839 attn_weights = (query * self.scale) @ key.transpose(-2, -1) 841 if self.use_rel_pos: --> 842 attn_weights = self.add_decomposed_rel_pos( 843 attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) 844 ) 846 attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) 848 attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:825, in SamVisionAttention.add_decomposed_rel_pos(self, attn, query, rel_pos_h, rel_pos_w, q_size, k_size) 823 attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width) 824 attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] --> 825 attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width) 826 return attn KeyboardInterrupt:
pt_train
pt_train (train_dataloader:torch.utils.data.dataloader.DataLoader, model:transformers.models.sam.modeling_sam.SamModel, optimizer:torch.optim.optimizer.Optimizer, device:Optional[str]='cpu', epoch_n:int=2)
# Short smoke-training run (2 epochs) using the pt_train helper defined above.
pt_train(
    train_dataloader=train_dataloader,
    model=model,
    optimizer=optimizer,
    device=device,
    epoch_n=2)
23it [08:41, 22.89s/it]
The Kernel crashed while executing code in the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.
validate
validate (model:transformers.models.sam.modeling_sam.SamModel, dataloader:torch.utils.data.dataloader.DataLoader, loss_fn:<mod ule'monai.losses'from'/opt/hostedtoolcache/Python/3.10.14/x64/l ib/python3.10/site-packages/monai/losses/__init__.py'>, device:str='cpu')
*Validate the model using a validation dataloader.
Parameters: - model: The PyTorch model to validate. - dataloader: DataLoader for validation data. - loss_fn: Loss function used for validation. - device: Device to run validation on (‘cuda’ or ‘cpu’).*
| Type | Default | Details | |
|---|---|---|---|
| model | SamModel | SAM model | |
| dataloader | DataLoader | Torch dataloader | |
| loss_fn | monai.losses | Monai loss function | |
| device | str | cpu | whether to use cpu or gpu |
train_and_validate
train_and_validate (model:transformers.models.sam.modeling_sam.SamModel, num_epochs:int, optimizer:torch.optim.optimizer.Optimizer, scheduler: <module'torch.optim.lr_scheduler'from'/opt/hostedtool cache/Python/3.10.14/x64/lib/python3.10/site- packages/torch/optim/lr_scheduler.py'>, train_dataloa der:torch.utils.data.dataloader.DataLoader, val_datal oader:torch.utils.data.dataloader.DataLoader, loss_fn :<module'monai.losses'from'/opt/hostedtoolcache/Pytho n/3.10.14/x64/lib/python3.10/site- packages/monai/losses/__init__.py'>, device:str='cpu')
Train and validate a model with the given parameters.
| Type | Default | Details | |
|---|---|---|---|
| model | SamModel | SAM model | |
| num_epochs | int | Number of epochs to train for | |
| optimizer | Optimizer | Optimizer to use | |
| scheduler | torch.optim.lr_scheduler | Learning rate scheduler | |
| train_dataloader | DataLoader | DataLoader for training data | |
| val_dataloader | DataLoader | DataLoader for validation data | |
| loss_fn | monai.losses | Loss function used for training | |
| device | str | cpu | Device to train on (‘cuda’ or ‘cpu’) |
#model = SamModel.from_pretrained('facebook/sam-vit-base')
optimizer = AdamW(
    model.mask_decoder.parameters(),
    lr=0.001,
    weight_decay=0.0001)
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=10,
    T_mult=2,
    eta_min=0.00001)
# FIX: sigmoid=True added -- the SAM decoder outputs raw logits, and the
# manual training loop above configured DiceCELoss with sigmoid=True; the
# default (sigmoid=False) would compute Dice on unbounded logits.
seg_loss = monai.losses.DiceCELoss(sigmoid=True)
device = 'cpu'
train_and_validate(
    model=model,
    num_epochs=2,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=seg_loss,
    device=device)
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[161], line 1 ----> 1 train_and_validate( 2 model=model, 3 num_epochs=2, 4 optimizer=optimizer, 5 scheduler=scheduler, 6 train_dataloader=train_dataloader, 7 val_dataloader=val_dataloader, 8 loss_fn=seg_loss, 9 device=device 10 ) Cell In[159], line 26, in train_and_validate(model, num_epochs, optimizer, scheduler, train_dataloader, val_dataloader, loss_fn, device) 20 progress_bar = tqdm( 21 train_dataloader, 22 desc=f'Epoch {epoch+1}/{num_epochs}', total=len(train_dataloader)) 24 for batch in progress_bar: 25 # Forward pass ---> 26 outputs = model(pixel_values=batch["pixel_values"].to(device), 27 input_boxes=batch["input_boxes"].to(device), 28 multimask_output=False) 30 # Compute loss 31 predicted_masks = outputs.pred_masks.squeeze(1) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1358, in SamModel.forward(self, pixel_values, input_points, input_labels, input_boxes, input_masks, image_embeddings, multimask_output, attention_similarity, target_embedding, output_attentions, output_hidden_states, return_dict, **kwargs) 1355 vision_hidden_states = None 1357 if pixel_values is not None: -> 1358 vision_outputs = self.vision_encoder( 1359 pixel_values, 1360 output_attentions=output_attentions, 1361 output_hidden_states=output_hidden_states, 1362 return_dict=return_dict, 1363 ) 1364 image_embeddings = vision_outputs[0] 1366 if output_hidden_states: File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:1046, in SamVisionEncoder.forward(self, pixel_values, output_attentions, output_hidden_states, return_dict) 1041 layer_outputs = self._gradient_checkpointing_func( 1042 layer_module.__call__, 1043 hidden_states, 1044 ) 1045 else: -> 1046 layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) 1048 hidden_states = layer_outputs[0] 1050 if output_attentions: File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:939, in SamVisionLayer.forward(self, hidden_states, output_attentions) 936 height, width = hidden_states.shape[1], hidden_states.shape[2] 937 hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) --> 939 hidden_states, attn_weights = self.attn( 940 hidden_states=hidden_states, 941 output_attentions=output_attentions, 942 ) 943 # Reverse window partition 944 if self.window_size > 0: File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/miniconda3/lib/python3.11/site-packages/transformers/models/sam/modeling_sam.py:846, in SamVisionAttention.forward(self, hidden_states, output_attentions) 841 if self.use_rel_pos: 842 attn_weights = self.add_decomposed_rel_pos( 843 attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) 844 ) --> 846 attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) 848 attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 850 attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) File ~/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py:1860, in softmax(input, dim, _stacklevel, dtype) 1858 ret = input.softmax(dim) 1859 else: -> 1860 ret = input.softmax(dim, dtype=dtype) 1861 return ret KeyboardInterrupt: