fixes
Mmasoud1 committed Sep 23, 2024
1 parent 3ec4f00 commit 8410053
Showing 14 changed files with 635 additions and 404 deletions.
Binary file modified app/code/executor/__pycache__/meshnet_executor.cpython-38.pyc
Binary file not shown.
69 changes: 46 additions & 23 deletions app/code/executor/meshnet_executor.py
@@ -10,31 +10,38 @@
from .loader import Scanloader # Import the Scanloader for MRI data
from .dist import GenericLogger # Import GenericLogger
import torch.cuda.amp as amp
from torch.utils.checkpoint import checkpoint # For layer checkpointing

class MeshNetExecutor(Executor):
def __init__(self):
super().__init__()
# Initialize the MeshNet model
# Model Initialization: The MeshNet model is initialized with input channels, number of classes, and the configuration file (modelAE.json).
# Construct the absolute path to the modelAE.json file
# Model Initialization
config_file_path = os.path.join(os.path.dirname(__file__), "modelAE.json")
self.model = enMesh_checkpoint(in_channels=1, n_classes=3, channels=5, config_file=config_file_path)
self.model = enMesh_checkpoint(in_channels=1, n_classes=3, channels=1, config_file=config_file_path)

# self.model = MeshNet(in_channels=1, n_classes=3, channels=32, config_file="modelAE.json")
# Check if a GPU is available
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

# Ensure model parameters require gradients
for param in self.model.parameters():
param.requires_grad = True

# Optimizer and criterion setup
self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.001)
# Adam could also be used here
self.criterion = torch.nn.CrossEntropyLoss()

# Initialize data loader (assuming mindboggle.db is the database file)
# We load the MRI data using the Scanloader class from loader.py, which reads data from an SQLite database.
# Use AMP mixed precision to work around memory limitations for now
self.scaler = amp.GradScaler()

# Data loader with a minimal batch size to save memory
db_file_path = os.path.join(os.path.dirname(__file__), "mindboggle.db")
self.data_loader = Scanloader(db_file=db_file_path, label_type='GWlabels', num_cubes=1)
self.trainloader, self.validloader, self.testloader = self.data_loader.get_loaders()

# Initializes the logger to write logs to a file named meshnet_executor.log.
self.logger = GenericLogger(log_file_path='meshnet_executor.log')
self.trainloader, self.validloader, self.testloader = self.data_loader.get_loaders(batch_size=1)

# Log files can be found under, e.g., MeshDist_nvflare/simulator_workspace/simulate_job/app_site-1 and app_site-2
self.logger = GenericLogger(log_file_path='meshnet_executor.log')
self.current_iteration = 0

def execute(
@@ -46,38 +53,52 @@ def execute(
) -> Shareable:

if task_name == "train_and_get_gradients":
# Perform local training and return gradients
# This function trains the model on a single batch of data and returns the gradients.
gradients = self.train_and_get_gradients()
outgoing_shareable = Shareable()
outgoing_shareable["gradients"] = gradients
return outgoing_shareable

elif task_name == "accept_aggregated_gradients":
# Accept aggregated gradients and apply them to the model
aggregated_gradients = shareable["aggregated_gradients"]

# Applies the aggregated gradients from the central node to update the model.
self.apply_gradients(aggregated_gradients)
return Shareable()

def train_and_get_gradients(self):
# Perform one iteration of training and return the gradients
self.model.train()
image, label = self.get_next_train_batch()
image, label = image.to(self.device), label.to(self.device)

# Ensure input requires gradients
image.requires_grad = True

self.optimizer.zero_grad()
output = self.model(image)
loss = self.criterion(output, label)
loss.backward()

# Log loss and training information
# Mixed precision and checkpointing
with amp.autocast():
# Checkpointing requires inputs with requires_grad; use a direct forward pass for now
output = self.model(image)

# Fix shape dim and ensure label is in long type
label = label.squeeze(1)
loss = self.criterion(output, label.long())

# Scale the loss before backward pass with amp
self.scaler.scale(loss).backward()

# Update the optimizer with scaled gradients
self.scaler.step(self.optimizer)
self.scaler.update()

# Log loss
self.logger.log_message(f"Iteration {self.current_iteration}: Loss = {loss.item()}")

# Extract gradients
gradients = [param.grad.clone().cpu().numpy() for param in self.model.parameters()]
gradients = [param.grad.clone().cpu().numpy() for param in self.model.parameters() if param.grad is not None]
self.current_iteration += 1

# Clear GPU memory cache to free memory
torch.cuda.empty_cache()

return gradients

def get_next_train_batch(self):
@@ -93,6 +114,8 @@ def apply_gradients(self, aggregated_gradients):
param.grad = torch.tensor(grad).to(self.device)
self.optimizer.step()

# Clear GPU memory cache after applying gradients
torch.cuda.empty_cache()

# Log the gradient application step
self.logger.log_message("Aggregated gradients applied to the model.")
104 changes: 104 additions & 0 deletions app/code/executor/meshnet_executor_backup.py
@@ -0,0 +1,104 @@
import torch
import os
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from .meshnet import MeshNet # Import the MeshNet model
from .meshnet import enMesh_checkpoint
from .loader import Scanloader # Import the Scanloader for MRI data
from .dist import GenericLogger # Import GenericLogger
import torch.cuda.amp as amp

class MeshNetExecutor(Executor):
def __init__(self):
super().__init__()
# Initialize the MeshNet model
# Model Initialization: The MeshNet model is initialized with input channels, number of classes, and the configuration file (modelAE.json).
# Construct the absolute path to the modelAE.json file
config_file_path = os.path.join(os.path.dirname(__file__), "modelAE.json")
self.model = enMesh_checkpoint(in_channels=1, n_classes=3, channels=1, config_file=config_file_path)

# self.model = MeshNet(in_channels=1, n_classes=3, channels=32, config_file="modelAE.json")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.001)
self.criterion = torch.nn.CrossEntropyLoss()

# Initialize data loader (assuming mindboggle.db is the database file)
# We load the MRI data using the Scanloader class from loader.py, which reads data from an SQLite database.
db_file_path = os.path.join(os.path.dirname(__file__), "mindboggle.db")
self.data_loader = Scanloader(db_file=db_file_path, label_type='GWlabels', num_cubes=1)
self.trainloader, self.validloader, self.testloader = self.data_loader.get_loaders()

# Initializes the logger to write logs to a file named meshnet_executor.log.
self.logger = GenericLogger(log_file_path='meshnet_executor.log')

self.current_iteration = 0

def execute(
self,
task_name: str,
shareable: Shareable,
fl_ctx: FLContext,
abort_signal: Signal,
) -> Shareable:

if task_name == "train_and_get_gradients":
# Perform local training and return gradients
# This function trains the model on a single batch of data and returns the gradients.
gradients = self.train_and_get_gradients()
outgoing_shareable = Shareable()
outgoing_shareable["gradients"] = gradients
return outgoing_shareable

elif task_name == "accept_aggregated_gradients":
# Accept aggregated gradients and apply them to the model
aggregated_gradients = shareable["aggregated_gradients"]

# Applies the aggregated gradients from the central node to update the model.
self.apply_gradients(aggregated_gradients)
return Shareable()

def train_and_get_gradients(self):
# Perform one iteration of training and return the gradients
self.model.train()
image, label = self.get_next_train_batch()
image, label = image.to(self.device), label.to(self.device)

self.optimizer.zero_grad()
output = self.model(image)

# Fix the shape mismatch
label = label.squeeze(1) # Remove the extra channel dimension
loss = self.criterion(output, label.long())
# label.long() casts the labels to Long, as required by CrossEntropyLoss
loss.backward()

# Log loss and training information
self.logger.log_message(f"Iteration {self.current_iteration}: Loss = {loss.item()}")

# Extract gradients
gradients = [param.grad.clone().cpu().numpy() for param in self.model.parameters()]
self.current_iteration += 1
return gradients



def get_next_train_batch(self):
# Get the next batch of data from the trainloader
for batch_id, (image, label) in enumerate(self.trainloader):
if batch_id == self.current_iteration % len(self.trainloader):
return image, label

def apply_gradients(self, aggregated_gradients):
# Apply aggregated gradients to the model parameters
with torch.no_grad():
for param, grad in zip(self.model.parameters(), aggregated_gradients):
param.grad = torch.tensor(grad).to(self.device)
self.optimizer.step()


# Log the gradient application step
self.logger.log_message("Aggregated gradients applied to the model.")
2 changes: 1 addition & 1 deletion app/config/config_fed_server.json
@@ -21,4 +21,4 @@
}
}
]
}
}
2 changes: 1 addition & 1 deletion jobs/job/app/config/config_fed_server.json
@@ -21,4 +21,4 @@
}
}
]
}
}
@@ -21,4 +21,4 @@
}
}
]
}
}
@@ -21,4 +21,4 @@
}
}
]
}
}
