
Commit

fixes
Mmasoud1 committed Oct 8, 2024
1 parent 76e2d79 commit fc67e95
Showing 28 changed files with 1,317 additions and 615 deletions.
Binary file modified app/code/executor/__pycache__/meshnet_executor.cpython-38.pyc
235 changes: 173 additions & 62 deletions app/code/executor/meshnet_executor.py
@@ -9,17 +9,24 @@
from .meshnet import enMesh_checkpoint
from .loader import Scanloader # Import the Scanloader for MRI data
from .dist import GenericLogger # Import GenericLogger
from .dice import faster_dice # Import Dice score calculation
import torch.cuda.amp as amp
from torch.utils.checkpoint import checkpoint # For layer checkpointing
from .paths import get_data_directory_path, get_output_directory_path
import logging

# Setup logging
logging.basicConfig(level=logging.DEBUG)

class MeshNetExecutor(Executor):
def __init__(self):
super().__init__()
# Model Initialization
config_file_path = os.path.join(os.path.dirname(__file__), "modelAE.json")
self.model = enMesh_checkpoint(in_channels=1, n_classes=3, channels=1, config_file=config_file_path)

# Check if GPU availabel
self.model = enMesh_checkpoint(in_channels=1, n_classes=3, channels=5, config_file=config_file_path)

# Check if GPU available
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)

@@ -29,20 +36,20 @@ def __init__(self):

# Optimizer and criterion setup
self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.001)
# Adam could also be used here
self.criterion = torch.nn.CrossEntropyLoss()

# amp for mixed precision to overcome memory limitation for now
self.scaler = amp.GradScaler()

# Data Loader with min batch size to save memory
db_file_path = os.path.join(os.path.dirname(__file__), "mindboggle.db")
self.data_loader = Scanloader(db_file=db_file_path, label_type='GWlabels', num_cubes=1)
self.trainloader, self.validloader, self.testloader = self.data_loader.get_loaders(batch_size=1)
# AMP for mixed precision to overcome memory limitations
self.scaler = torch.amp.GradScaler() # The device argument defaults to 'cuda'

# Log files can be found, for example, under MeshDist_nvflare/simulator_workspace/simulate_job/app_site-1 and app_site-2
self.logger = GenericLogger(log_file_path='meshnet_executor.log')
self.current_iteration = 0
self.current_epoch = 0

# Epochs and aggregation interval
self.total_epochs = 2 # Set the total number of epochs
# self.aggregation_interval = 1 # Aggregate every N epochs (adjust as needed)

self.dice_threshold = 0.9 # Set the Dice score threshold

def execute(
self,
@@ -51,78 +58,165 @@ def execute(
fl_ctx: FLContext,
abort_signal: Signal,
) -> Shareable:
# Get the correct data directory path
db_file_path = os.path.join(get_data_directory_path(fl_ctx), "mindboggle.db")


# Initialize Data Loader with dynamic path
self.data_loader = Scanloader(db_file=db_file_path, label_type='GWlabels', num_cubes=1)
self.trainloader, self.validloader, self.testloader = self.data_loader.get_loaders(batch_size=1)
self.site_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)


if task_name == "train_and_get_gradients":
self.logger.log_message(f"{self.site_name} - train_and_get_gradients called")
gradients = self.train_and_get_gradients()
outgoing_shareable = Shareable()
outgoing_shareable["gradients"] = gradients
return outgoing_shareable

elif task_name == "accept_aggregated_gradients":
self.logger.log_message(f"{self.site_name} - accept_aggregated_gradients called")
aggregated_gradients = shareable["aggregated_gradients"]
self.apply_gradients(aggregated_gradients)
self.apply_gradients(aggregated_gradients, fl_ctx)
return Shareable()

def train_and_get_gradients_old(self):
self.model.train()
image, label = self.get_next_train_batch()
image, label = image.to(self.device), label.to(self.device)
# def train_and_get_gradients_old(self):
# for epoch in range(self.total_epochs):

# self.logger.log_message(f"Starting Epoch {epoch}/{self.total_epochs}, Aggregation Interval: {self.aggregation_interval}")
# self.model.train()

# # Initialize accumulators for the loss and gradients
# total_loss = 0.0
# gradient_accumulator = [torch.zeros_like(param).to(self.device) for param in self.model.parameters()]

# # Training loop for one epoch (full pass through the dataset)
# for batch_id, (image, label) in enumerate(self.trainloader):
# image, label = image.to(self.device), label.to(self.device)
# self.optimizer.zero_grad()

# Ensure input requires gradients
image.requires_grad = True
# # Mixed precision and checkpointing
# with torch.amp.autocast(device_type='cuda'):
# output = torch.utils.checkpoint.checkpoint(self.model, image, use_reentrant=False)
# label = label.squeeze(1)
# loss = self.criterion(output, label.long())

self.optimizer.zero_grad()
# total_loss += loss.item()

# Mixed precision and checkpointing
with amp.autocast():
# Ensure checkpoint works with requires_grad
output = self.model(image) # To avoid using checkpointing for now
# # Scale loss and backward pass
# self.scaler.scale(loss).backward()

# Fix shape dim and ensure label is in long type
label = label.squeeze(1)
loss = self.criterion(output, label.long())
# # Accumulate gradients
# for i, param in enumerate(self.model.parameters()):
# if param.grad is not None:
# gradient_accumulator[i] += param.grad.clone()

# Scale the loss before backward pass with amp
self.scaler.scale(loss).backward()
# self.scaler.step(self.optimizer)
# self.scaler.update()

# Update the optimizer with scaled gradients
self.scaler.step(self.optimizer)
self.scaler.update()
# torch.cuda.empty_cache()

# Log loss
self.logger.log_message(f"Iteration {self.current_iteration}: Loss = {loss.item()}")
# # Log the average loss per epoch
# average_loss = total_loss / len(self.trainloader)
# dice_score = self.calculate_dice(self.trainloader)
# self.logger.log_message(f"Site {self.site_name} - Epoch {epoch}: Loss = {average_loss}, Dice = {dice_score}")

# Extract gradients
gradients = [param.grad.clone().cpu().numpy() for param in self.model.parameters() if param.grad is not None]
self.current_iteration += 1
# # Call aggregation based on your set aggregation_interval
# if (epoch + 1) % self.aggregation_interval == 0:
# # Perform model aggregation here
# return [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None]

# Clear GPU memory cache to free memory
torch.cuda.empty_cache()
# return []

# def train_and_get_gradients_new(self):
# for epoch in range(self.total_epochs):
# # self.logger.log_message(f"Starting Epoch {epoch+1}/{self.total_epochs}, Aggregation Interval: {self.aggregation_interval}")
# self.model.train()

# # Initialize accumulators for the loss and gradients
# total_loss = 0.0
# gradient_accumulator = [torch.zeros_like(param).to(self.device) for param in self.model.parameters()]

# # Training loop for one epoch (full pass through the dataset)
# for batch_id, (image, label) in enumerate(self.trainloader):
# image, label = image.to(self.device), label.to(self.device)
# self.optimizer.zero_grad()

# # Mixed precision and checkpointing
# with torch.amp.autocast(device_type='cuda'):
# output = torch.utils.checkpoint.checkpoint(self.model, image, use_reentrant=False)
# label = label.squeeze(1)
# loss = self.criterion(output, label.long())

# total_loss += loss.item()

# # Scale loss and backward pass
# self.scaler.scale(loss).backward()

# # Accumulate gradients
# for i, param in enumerate(self.model.parameters()):
# if param.grad is not None:
# gradient_accumulator[i] += param.grad.clone()

# self.scaler.step(self.optimizer)
# self.scaler.update()

# torch.cuda.empty_cache()

# # Log the average loss per epoch
# average_loss = total_loss / len(self.trainloader)
# dice_score = self.calculate_dice(self.trainloader)
# self.logger.log_message(f"Site {self.site_name} - Epoch {epoch+1}: Loss = {average_loss}, Dice = {dice_score}")

# # Check if it's time to perform aggregation
# if (epoch + 1) % self.aggregation_interval == 0:
# # Return the gradients after completing the specified aggregation interval
# self.logger.log_message(f"Performing aggregation after epoch {epoch+1}")
# return [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None]

# return []


# def setup_site_logger(self):
# site_id = os.getenv('FL_SITE_ID', 'site_unknown') # Use environment variable or other means to set site ID
# log_dir = f'logs/{site_id}'
# os.makedirs(log_dir, exist_ok=True)
# log_filename = os.path.join(log_dir, f'training_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')

# logging.basicConfig(
# filename=log_filename,
# level=logging.INFO,
# format='%(asctime)s - %(levelname)s - %(message)s',
# datefmt='%Y-%m-%d %H:%M:%S',
# )

# self.logger = logging.getLogger()
# self.logger.info("Logging started")

return gradients

def train_and_get_gradients(self):
self.model.train()

# Initialize accumulators for the loss and gradients
total_loss = 0.0
gradient_accumulator = [torch.zeros_like(param).to(self.device) for param in self.model.parameters()]

# Training loop for one epoch (full pass through the dataset)
for batch_id, (image, label) in enumerate(self.trainloader):
image, label = image.to(self.device), label.to(self.device)
self.optimizer.zero_grad()

# Mixed precision and checkpointing
with amp.autocast():
output = self.model(image)
with torch.amp.autocast(device_type='cuda'):
output = torch.utils.checkpoint.checkpoint(self.model, image, use_reentrant=False)
label = label.squeeze(1)
loss = self.criterion(output, label.long())

# Accumulate loss
total_loss += loss.item()

# Scale the loss and backward pass
# Scale loss and backward pass
self.scaler.scale(loss).backward()

# Accumulate gradients
@@ -136,29 +230,35 @@ def train_and_get_gradients(self):

# Clear GPU cache
torch.cuda.empty_cache()
# Log the average loss per epoch

# Log the average loss and Dice score per epoch
average_loss = total_loss / len(self.trainloader)
self.logger.log_message(f"Epoch {self.current_iteration}: Loss = {average_loss}")

# Extract accumulated gradients
gradients = [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None]

# Increment the iteration count
self.current_iteration += 1
dice_score = self.calculate_dice(self.trainloader)
self.logger.log_message(f"{self.site_name} - Epoch {self.current_epoch}: Loss = {average_loss}, Dice = {dice_score}")

return gradients

# Return the gradients after completing the specified aggregation interval
self.logger.log_message(f"{self.site_name} Performing aggregation after epoch {self.current_epoch}")
gradients = [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None]


# # Increment the iteration count
#-- self.current_epoch += 1

return gradients

def get_next_train_batch(self):
# Get the next batch of data from the trainloader
for batch_id, (image, label) in enumerate(self.trainloader):
if batch_id == self.current_iteration % len(self.trainloader):
return image, label
def calculate_dice(self, loader):
dice_total = 0.0
for image, label in loader:
image, label = image.to(self.device), label.to(self.device)
with torch.no_grad():
output = self.model(image)
output_label = torch.argmax(output, dim=1)
dice_score = faster_dice(output_label, label.squeeze(1), labels=[0, 1, 2])
dice_total += dice_score.mean().item()
return dice_total / len(loader)

def apply_gradients(self, aggregated_gradients):
def apply_gradients(self, aggregated_gradients, fl_ctx):
# Apply aggregated gradients to the model parameters
with torch.no_grad():
for param, grad in zip(self.model.parameters(), aggregated_gradients):
@@ -169,4 +269,15 @@ def apply_gradients(self, aggregated_gradients):
torch.cuda.empty_cache()

# Log the gradient application step
self.logger.log_message("Aggregated gradients applied to the model.")
self.logger.log_message(f"{self.site_name} Aggregated gradients applied to the model.")

# Get the output directory path
output_dir = get_output_directory_path(fl_ctx)

# Save the model
model_save_path = os.path.join(output_dir, f"model_epoch_{self.current_epoch}.pth")
torch.save(self.model.state_dict(), model_save_path)

# Log the model saving step
self.logger.log_message(f"Model saved at {model_save_path}")
self.current_epoch += 1
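
For reference, below is a minimal standalone sketch (not part of the commit) of the gradient round-trip this executor participates in: each site returns its accumulated gradients as numpy arrays under the "gradients" key of the Shareable, the server averages them across sites, and each site then receives the result under "aggregated_gradients" and passes it to apply_gradients(). The helper name aggregate_site_gradients and the toy shapes are hypothetical; in the actual workflow the averaging is performed by the NVFlare aggregator component, not by this file.

import numpy as np

def aggregate_site_gradients(site_gradients):
    # site_gradients: one entry per site, each a list of numpy arrays
    # (one array per model parameter), as returned by train_and_get_gradients().
    n_sites = len(site_gradients)
    # Element-wise mean across sites for every parameter tensor.
    return [sum(per_param) / n_sites for per_param in zip(*site_gradients)]

# Toy example with two sites and a two-parameter "model".
site_a = [np.ones((3, 3)), np.zeros(5)]
site_b = [np.full((3, 3), 3.0), np.ones(5)]
aggregated = aggregate_site_gradients([site_a, site_b])
# Each site would then receive `aggregated` in the "aggregated_gradients" key
# of the incoming Shareable and apply it via apply_gradients(aggregated, fl_ctx).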