
clean and fix typo for camera-ready version
hoangcaobao committed Sep 4, 2024
1 parent e2fea72 commit 2b8a5e4
Showing 15 changed files with 45 additions and 331 deletions.
2 changes: 1 addition & 1 deletion Data/DataInit.py
@@ -12,4 +12,4 @@ def data_init(cfg_proj, mci_subject, nl_subject, dic_id2feature, df_labels, seed

    x_train, y_train, g_train, x_test, y_test, g_test = get_feature_from_id(id_train, id_test, dic_id2feature, df_labels)

-    return x_train, y_train, g_train, x_test, y_test, g_test
+    return x_train, y_train, g_train, x_test, y_test, g_test
2 changes: 1 addition & 1 deletion Data/DataPreProcessing.py
@@ -7,7 +7,7 @@ def data_pre_processing(cfg_proj, cfg_m, x_train_raw, y_train, g_train, x_test_r
    x_train_raw = x_train_raw[:, :-3]
    x_test_raw = x_test_raw[:, :-3]

-    if cfg_proj.solver not in ["whiting_confounder_solver", "whiting_confounderS_solver"]:
+    if cfg_proj.solver not in ["confounder_harmonization_solver"]:
        g_train, g_test = [g[0] for g in g_train], [g[0] for g in g_test]

    scaler = StandardScaler()
2 changes: 1 addition & 1 deletion Moldes/model.py → Models/model.py
@@ -209,4 +209,4 @@ def __getitem__(self, idx):
    def kept(self, idx_kept):
        self.X = self.X[idx_kept]
        self.Y = self.Y[idx_kept]
-        return self
+        return self
7 changes: 5 additions & 2 deletions README.md
@@ -2,6 +2,9 @@
Official code for paper: "Subject Harmonization of Digital Biomarkers: Improved Detection of Mild Cognitive Impairment from Language Markers", Bao Hoang, Yijiang Pang, Hiroko H. Dodge, and Jiayu Zhou, PSB 2024.

## Overview

+![](pipeline.png)

Mild cognitive impairment (MCI) represents the early stage of dementia, including Alzheimer's disease (AD), and plays a crucial role in developing therapeutic interventions and treatment. Early detection of MCI offers opportunities for early intervention and significantly benefits cohort enrichment for clinical trials. Imaging markers and in vivo markers in plasma and cerebrospinal fluid have high detection performance, yet their prohibitive costs and intrusiveness demand more affordable and accessible alternatives. Recent advances in digital biomarkers, especially language markers, have shown great potential, where variables informative to MCI are derived from linguistic and/or speech data and later used for predictive modeling. A major challenge in modeling language markers comes from the variability of how each person speaks. As the cohort size for language studies is usually
@@ -25,11 +28,11 @@ Here we provide several demos of using harmonization commands. Remember to use yo

- **Deep harmonization - subject (Proposed method):**

-  - Run ```python main.py --solver whiting_solver```
+  - Run ```python main.py --solver subject_harmonization_solver```

- **Deep harmonization - confounder:**

-  - Run ```python main.py --solver whiting_confounder_solver```
+  - Run ```python main.py --solver confounder_harmonization_solver```

  - You can change the confounder variable via ``config.training.confounder_var`` in ``configs/cfg.py`` (see the sketch below)
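For reference, a minimal sketch of what the relevant part of ``configs/cfg.py`` might look like. This is an assumption rather than the repository's actual file: the project imports ``ml_collections`` elsewhere, so a ``ConfigDict`` is plausible, and every field except ``training.confounder_var`` is an illustrative placeholder.

```python
# Hypothetical sketch of configs/cfg.py -- not the repo's actual file.
# Only training.confounder_var is documented in the README; the other
# fields and the value "age" are illustrative placeholders.
import ml_collections

def get_config():
    cfg = ml_collections.ConfigDict()
    cfg.training = ml_collections.ConfigDict()
    cfg.training.batch_size = 32           # placeholder; solvers read cfg_m.training.batch_size
    cfg.training.confounder_var = "age"    # confounder used by confounder_harmonization_solver
    return cfg
```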

11 changes: 4 additions & 7 deletions Solvers/Baseline_confounder_solver.py
@@ -2,8 +2,7 @@
from Solvers.Solver_Base import Solver_Base
import torch
from torch.utils.data import DataLoader
-from Moldes.model import MSE_pytorch, CustomDataset, LR_pytorch
-import torch.nn.functional as F
+from Models.model import MSE_pytorch, CustomDataset, LR_pytorch

class Baseline_confounder_solver(Solver_Base):

@@ -12,11 +11,12 @@ def __init__(self, cfg_proj, cfg_m, name = "baseline"):

    def run(self, x_train, y_train, g_train, x_test, y_test, g_test, seed):
        # Set seed
-        # self.set_random_seed(seed)
+        self.set_random_seed(seed)

        # train for confounder classifier
        epochs = 50
        X_confounder, X_test_confounder = x_train[:, -3:], x_test[:, -3:]
+
        for i in range(x_train.shape[-1] - 3):
            Y = x_train[:, i]
            dataloader_train_c = DataLoader(CustomDataset(X_confounder, Y, g_train), batch_size = self.cfg_m.training.batch_size, drop_last=True, shuffle = True)
@@ -52,7 +52,6 @@ def run(self, x_train, y_train, g_train, x_test, y_test, g_test, seed):

        return auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj

-
    def basic_train_confounder(self, model, dataloader_train, criterion, optimizer, lr_scheduler, epochs):
        loss_train_trace = []
        for epoch in range(epochs):
@@ -70,6 +69,4 @@ def basic_train_confounder(self, model, dataloader_train, criterion, optimizer,
                optimizer.step()
                lr_scheduler.step()
            loss_train_trace.append(np.mean(loss_epoch))
-        return model, loss_train_trace
-
-
+        return model, loss_train_trace
12 changes: 1 addition & 11 deletions Solvers/Solver_Base.py
@@ -1,22 +1,13 @@
from tqdm import tqdm
import random
import torch
import numpy as np
import logging
import torch.nn.functional as F
from sklearn.metrics import (
    precision_recall_fscore_support as prf,
    accuracy_score,
    roc_auc_score,
)
import os
import torch.nn as nn
import time
import ml_collections
from sklearn import metrics
import pandas as pd
import math


class Solver_Base:

@@ -186,8 +177,7 @@ def set_random_seed(self, seed):
        if torch.cuda.device_count() > 1: torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
-        # torch.backends.cudnn.deterministic = True
-        # torch.backends.cudnn.benchmark = False

        self.seed_current = seed


18 changes: 7 additions & 11 deletions Solvers/Solver_loader.py
@@ -2,23 +2,19 @@

#solver load
from Solvers.Standard_solver import Standard_solver
-from Solvers.whiting_solver import whiting_solver
-from Solvers.whiting_confounder_solver import whiting_confounder_solver
-from Solvers.whiting_confounderS_solver import whiting_confounderS_solver
+from Solvers.subject_harmonization_solver import subject_harmonization_solver
+from Solvers.confounder_harmonization_solver import confounder_harmonization_solver
from Solvers.Baseline_confounder_solver import Baseline_confounder_solver

#solver_loader = lambda cfg_proj, cfg_m : getattr(sys.modules[__name__], cfg_proj.solver)(cfg_proj, cfg_m)

def solver_loader(cfg_proj, cfg_m):
    if cfg_proj.solver == "Standard_solver":
        s = Standard_solver(cfg_proj, cfg_m)
-    elif cfg_proj.solver == "whiting_solver":
-        s = whiting_solver(cfg_proj, cfg_m)
+    elif cfg_proj.solver == "subject_harmonization_solver":
+        s = subject_harmonization_solver(cfg_proj, cfg_m)
    elif cfg_proj.solver == "Baseline_confounder_solver":
        s = Baseline_confounder_solver(cfg_proj, cfg_m)
-    elif cfg_proj.solver == "whiting_confounder_solver":
-        s = whiting_confounder_solver(cfg_proj, cfg_m)
-    elif cfg_proj.solver == "whiting_confounderS_solver":
-        s = whiting_confounderS_solver(cfg_proj, cfg_m)
-    return s
-
+    elif cfg_proj.solver == "confounder_harmonization_solver":
+        s = confounder_harmonization_solver(cfg_proj, cfg_m)
+    return s
10 changes: 3 additions & 7 deletions Solvers/Standard_solver.py
@@ -1,9 +1,7 @@
import numpy as np
from Solvers.Solver_Base import Solver_Base
import torch
from torch.utils.data import DataLoader
-from Moldes.model import MLP_pytorch, CustomDataset
-import torch.nn.functional as F
+from Models.model import MLP_pytorch, CustomDataset

class Standard_solver(Solver_Base):

@@ -12,7 +10,7 @@ def __init__(self, cfg_proj, cfg_m, name = "Std"):

    def run(self, x_train, y_train, g_train, x_test, y_test, g_test, seed):
        # Set seed
-        # self.set_random_seed(seed)
+        self.set_random_seed(seed)

        # Initialize
        dataloader_train = DataLoader(CustomDataset(x_train, y_train, g_train), batch_size = self.cfg_m.training.batch_size, drop_last=True, shuffle = True)
@@ -28,6 +26,4 @@ def run(self, x_train, y_train, g_train, x_test, y_test, g_test, seed):
        # Evaluation
        auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj = self.eval_func(model, x_test, y_test, g_test)

-        return auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj
-
-
+        return auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj
Solvers/whiting_confounder_solver.py → Solvers/confounder_harmonization_solver.py
@@ -1,14 +1,11 @@
-from tkinter import E
import numpy as np
from Solvers.Solver_Base import Solver_Base
import torch
from torch.utils.data import DataLoader
-from Moldes.model import MLP_pytorch, MLP_whiting, CustomDataset
+from Models.model import MLP_whiting, CustomDataset
import torch.nn.functional as F
import pickle
from sklearn import metrics
import math


def generalization(values, values_f, num_class):
    v_mean = np.mean(values)
@@ -29,7 +26,7 @@ def generalization(values, values_f, num_class):
    return values_g, values_f_g


-class whiting_confounder_solver(Solver_Base):
+class confounder_harmonization_solver(Solver_Base):

    def __init__(self, cfg_proj, cfg_m, name = "white_c"):
        Solver_Base.__init__(self, cfg_proj, cfg_m, name)
@@ -263,6 +260,4 @@ def eval_func(self, model, x_test, y_test, g_test, g_test_sbj):
        spec_sbj = tn_sbj/(fp_sbj+tn_sbj)

        self.save_results_each_run(auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj)
-        return auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj
-
-
+        return auc, f1, sens, spec, auc_sbj, f1_sbj, sens_sbj, spec_sbj
Solvers/whiting_solver.py → Solvers/subject_harmonization_solver.py
@@ -2,11 +2,11 @@
from Solvers.Solver_Base import Solver_Base
import torch
from torch.utils.data import DataLoader
-from Moldes.model import MLP_pytorch, MLP_whiting, CustomDataset
+from Models.model import MLP_whiting, CustomDataset
import torch.nn.functional as F
import pickle

-class whiting_solver(Solver_Base):
+class subject_harmonization_solver(Solver_Base):

    def __init__(self, cfg_proj, cfg_m, name = "white"):
        Solver_Base.__init__(self, cfg_proj, cfg_m, name)
@@ -166,4 +166,4 @@ def predict_proba(self, model, X, flag_prob = True):
        pred = pred if torch.is_tensor(pred) else pred[1]
        pred = torch.nn.functional.softmax(pred, dim = 1)

-        return pred.detach().cpu().numpy()
+        return pred.detach().cpu().numpy()