This repository hosts the version of the code used for the publication "Deep Reinforcement Learning for Multi-class Imbalanced Training".
We have tested this implementation using:
- Python version 3.6.9 and TensorFlow version 2.6.2 on a Linux machine.
- Python version 3.9.2 and TensorFlow version 2.11.0 on a macOS (Big Sur) machine.
To use this branch, run the following commands:
conda create -n ImbalancedLearningEnv python==3.7
conda activate ImbalancedLearningEnv
git clone https://github.com/yangjenny/ImbalancedLearningRL.git
cd ImbalancedLearningRL
pip install -e .
To run the code:
python ImbalancedLearningRL/run.py
(The UCI Adult dataset is loaded automatically for training.)
This example uses the UCI Adult dataset, in which the task is to classify income into two classes: <=50K and >50K. Additional details about the dataset, including the full list of attributes, can be found here.
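As a minimal sketch of this two-class label setup, assuming the raw UCI Adult files are read as strings, the income column can be mapped to binary targets and the degree of class imbalance quantified. The function names `encode_income` and `imbalance_ratio` are hypothetical and not part of this repository, and treating `>50K` as the positive class (1) is likewise an assumption.

```python
from collections import Counter
from typing import Iterable, List


def encode_income(labels: Iterable[str]) -> List[int]:
    """Hypothetical helper: map Adult income strings to 0 (<=50K) or 1 (>50K)."""
    encoded = []
    for raw in labels:
        # The raw UCI files pad values with a leading space, and the
        # adult.test file appends a trailing period (e.g. " >50K.").
        label = raw.strip().rstrip(".")
        if label == ">50K":
            encoded.append(1)  # assumed positive class
        elif label == "<=50K":
            encoded.append(0)
        else:
            raise ValueError(f"unexpected income label: {raw!r}")
    return encoded


def imbalance_ratio(y: List[int]) -> float:
    """Hypothetical helper: size of the largest class divided by the smallest."""
    counts = Counter(y)
    return max(counts.values()) / min(counts.values())


if __name__ == "__main__":
    y = encode_income([" <=50K", " >50K", "<=50K.", ">50K.", " <=50K"])
    print(y, imbalance_ratio(y))
```

A ratio well above 1 indicates that the majority class dominates, which is the setting the imbalanced-training approach targets.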
After training, performance metrics (AUROC, NPV, PPV, recall, and specificity) and raw prediction results are saved as CSV files in the path. An example run and its expected output can be found in example/training_example.ipynb.
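For reference, the metrics listed above follow standard binary-classification definitions. The sketch below computes them in plain Python from true labels and predicted scores; the function name `binary_metrics` and the 0.5 decision threshold are assumptions, and the repository's own code may compute these values differently (for example, with a library such as scikit-learn).

```python
from typing import Dict, Sequence


def binary_metrics(
    y_true: Sequence[int], y_score: Sequence[float], threshold: float = 0.5
) -> Dict[str, float]:
    """Hypothetical helper: AUROC, NPV, PPV, recall and specificity."""
    # Turn scores into hard predictions (the 0.5 threshold is an assumption).
    y_pred = [1 if s >= threshold else 0 for s in y_score]

    # Confusion-matrix counts.
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    def safe_div(num: float, den: float) -> float:
        # Return NaN instead of raising when a class or prediction is absent.
        return num / den if den else float("nan")

    # AUROC as the probability that a random positive outscores a random
    # negative, with ties counted as half (Mann-Whitney formulation).
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg
    )

    return {
        "auroc": safe_div(wins, len(pos) * len(neg)),
        "npv": safe_div(tn, tn + fn),
        "ppv": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "specificity": safe_div(tn, tn + fp),
    }


if __name__ == "__main__":
    print(binary_metrics([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

The pairwise AUROC is quadratic in the number of samples, so a rank-based or library implementation is preferable for large test sets.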
If you find our work useful, please consider citing:
Yang, J., El-Bouri, R., O'Donoghue, O., Lachapelle, A. S., Soltan, A. A., & Clifton, D. A. (2022). Deep Reinforcement Learning for Multi-class Imbalanced Training. arXiv preprint arXiv:2205.12070.