Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update README.md #1

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Byte-compiled / optimized / DLL files
.DS_Store

__pycache__/
*.py[cod]
*$py.class
Expand Down
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
# zooms_classifier
# ZooMS Classifier Documentation

## Model Training and Explainability Guide

> **Note:** Access to the project's Google Drive is required to use this Colab notebook.

[![Run on Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Ljos45IErs819W3ynY5-u9ach85y9J9G?usp=sharing)

### Local Usage Instructions

To run the notebook locally, either:

- Download the `.ipynb` file directly from Google Colab
- Obtain the notebook from the GitHub repository at [ZooMS Classifier on GitHub](https://github.com/mlcolab/zooms_classifier/blob/vadims_branch/ZooMS_1DCNN_model_explainability.ipynb)
12,648 changes: 12,648 additions & 0 deletions ZooMS_1DCNN_model_explainability.ipynb

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions windows_app/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# ZooMS Windows Classifier App


Please use pyinstaller to compile the app on a Windows machine.

```
MyApp/
|-- model/
| |-- model.pth # trained 1DCNN model (PyTorch)
|-- src/
| |-- main.py
| |-- model.py
| |-- file_ops.py
| |-- gui.py
|-- requirements.txt
```
Binary file added windows_app/model/model.pth
Binary file not shown.
3 changes: 3 additions & 0 deletions windows_app/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pandas >= 1.4.4
torch >= 1.10.2
pyinstaller >= 5.13.2
8 changes: 8 additions & 0 deletions windows_app/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cx_Freeze import setup, Executable

# cx_Freeze build configuration: freezes src/main.py into an executable.
# NOTE(review): the windows_app README recommends pyinstaller for building;
# confirm which build tool is actually intended.
setup(
name="AI_ZooMS",
version="0.0.1",
# NOTE(review): "homininis" looks like a typo for "hominins" - confirm before changing.
description="Find homininis with AI",
# Entry-point script, given relative to this setup.py.
executables=[Executable("src/main.py")]
)
45 changes: 45 additions & 0 deletions windows_app/src/file_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pandas as pd
import os
from datetime import datetime

def read_csv_files(directory):
    """Collect the CSV files located directly inside ``directory``.

    Args:
        directory: Path of the folder to scan (not searched recursively).

    Returns:
        A ``(file_paths, file_names)`` tuple of two parallel lists: the
        joined paths and the bare file names, sorted by file name.

    Raises:
        OSError: If ``directory`` cannot be listed (propagated from
            ``os.listdir``).
    """
    file_paths = []
    file_names = []
    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order of the results reproducible between runs and machines.
    for filename in sorted(os.listdir(directory)):
        # Match the extension case-insensitively so "SPECTRUM.CSV" is found.
        if not filename.lower().endswith('.csv'):
            continue
        filepath = os.path.join(directory, filename)
        # A sub-directory whose name ends in ".csv" is not a readable spectrum.
        if not os.path.isfile(filepath):
            continue
        file_paths.append(filepath)
        file_names.append(filename)
    return file_paths, file_names



def save_results(results, output_directory, file_names, targets):
    """Write the per-file class probabilities to a timestamped CSV file.

    The output has one row per input file with the columns ``File Name``,
    ``Most Probable Class`` and then one probability column per class.

    Args:
        results: Sequence of 2-D array-likes, each with one column per entry
            of the class list below (15 columns).
        output_directory: Folder that receives ``results_<timestamp>.csv``;
            it is created if it does not exist yet.
        file_names: One file name per output row.
        targets: One integer class index per output row, used to look up the
            name written to ``Most Probable Class``.
    """
    # Column names
    column_names = [
        'Canidae', 'Cervidae', 'CervidaeGazellaSaiga', 'Ovis', 'Equidae',
        'CrocutaPanthera', 'BisonYak', 'Capra', 'Ursidae', 'Vulpes vulpes',
        'Elephantidae', 'Others', 'Rhinocerotidae', 'Rangifer tarandus', 'Hominins'
    ]

    frames = [pd.DataFrame(result, columns=column_names) for result in results]
    if frames:
        concatenated_df = pd.concat(frames, ignore_index=True)
    else:
        # pd.concat raises ValueError on an empty list; fall back to a
        # header-only table so an empty run still yields a valid CSV.
        concatenated_df = pd.DataFrame(columns=column_names)

    # Inserting both columns at position 0 leaves 'File Name' first and
    # 'Most Probable Class' second, ahead of the probability columns.
    concatenated_df.insert(0, 'Most Probable Class', [column_names[i] for i in targets])
    concatenated_df.insert(0, 'File Name', file_names)

    # The timestamp keeps earlier result files from being overwritten.
    current_datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # to_csv does not create missing folders, so make sure the target exists.
    os.makedirs(output_directory, exist_ok=True)
    output_path = os.path.join(output_directory, f'results_{current_datetime}.csv')

    concatenated_df.to_csv(output_path, index=False)
92 changes: 92 additions & 0 deletions windows_app/src/gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, QFileDialog
from PyQt5.QtGui import QIcon, QFont

def create_gui(main_func):
    """Build the classifier window and run the Qt event loop.

    The window holds three path rows (model file, input directory, output
    directory), each with a browse button, plus a "Classify" button.

    Args:
        main_func: Callable invoked as
            ``main_func(model_path, input_dir, output_dir)`` with the current
            text of the three path fields when "Classify" is clicked.
    """
    app = QApplication([])

    window = QWidget()
    window.setWindowTitle("ML-Based Mass Spectra Species Identifier")
    window.setWindowIcon(QIcon('icon.png'))
    window.setFixedSize(700, 350)

    caption_font = QFont("Arial", 16, QFont.Bold)
    control_font = QFont("Arial", 14)

    root_layout = QVBoxLayout()
    root_layout.setContentsMargins(40, 40, 40, 40)

    def add_path_row(caption, placeholder, button_text):
        # One horizontal row: caption label, path field, browse button.
        row = QHBoxLayout()
        field = QLineEdit(placeholderText=placeholder)
        field.setFont(control_font)
        button = QPushButton(button_text)
        button.setFont(control_font)
        label = QLabel(caption)
        label.setFont(caption_font)
        for widget in (label, field, button):
            row.addWidget(widget)
        root_layout.addLayout(row)
        return field, button

    model_file_path, browse_model_btn = add_path_row(
        "Model File:", "Select your ML model file...", "Select Model"
    )
    input_directory, browse_input_btn = add_path_row(
        "Input Directory:", "Select a directory with CSV files...", "Browse"
    )
    output_directory, browse_output_btn = add_path_row(
        "Output Directory:", "Select a directory for results...", "Browse"
    )

    classify_btn = QPushButton("Classify")
    classify_btn.setFont(control_font)
    root_layout.addWidget(classify_btn)

    window.setStyleSheet("""
QWidget {
background-color: #fafafa;
font-size: 18px;
color: #333;
}
QPushButton {
background-color: #11a611; /* Green */
color: white;
border: none;
border-radius: 10px;
padding: 14px 28px;
}
QPushButton:pressed {
background-color: #005900; /* Darker green on click */
}
QLineEdit {
background-color: #fff;
border: 1px solid #ccc;
border-radius: 10px;
padding: 14px;
}
""")

    def pick_model_file():
        # getOpenFileName returns (path, selected_filter); only the path is shown.
        model_file_path.setText(QFileDialog.getOpenFileName()[0])

    def pick_input_directory():
        input_directory.setText(QFileDialog.getExistingDirectory())

    def pick_output_directory():
        output_directory.setText(QFileDialog.getExistingDirectory())

    def run_classification():
        main_func(model_file_path.text(), input_directory.text(), output_directory.text())

    browse_model_btn.clicked.connect(lambda: pick_model_file())
    browse_input_btn.clicked.connect(lambda: pick_input_directory())
    browse_output_btn.clicked.connect(lambda: pick_output_directory())
    classify_btn.clicked.connect(lambda: run_classification())

    window.setLayout(root_layout)
    window.show()
    app.exec_()

if __name__ == "__main__":
    # `main` is not defined in this module, so the original call raised
    # NameError when gui.py was run directly. Import the classification entry
    # point lazily here; a top-level import would be circular because main.py
    # itself imports create_gui from this module.
    from main import main

    create_gui(main)
74 changes: 74 additions & 0 deletions windows_app/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from file_ops import read_csv_files, save_results
from gui import create_gui
import torch
from PyQt5.QtWidgets import QMessageBox
import numpy as np
import pandas as pd
import os
from model import CNN1D

def mean_intensity(temp_df, bin_resolution=0.5):
    """Average the ``intensity`` column over fixed-width ``mass`` bins.

    Bin edges run from 899.9 up to (but excluding) 3500 in steps of
    ``bin_resolution``; rows whose mass falls outside the edges are ignored.

    Args:
        temp_df: DataFrame with ``mass`` and ``intensity`` columns. It is not
            modified.
        bin_resolution: Width of each mass bin.

    Returns:
        A 1-D NumPy array with one mean per bin, in bin order. Bins without
        any rows contain NaN, so the length depends only on
        ``bin_resolution`` and never on the data.
    """
    bins = np.arange(899.9, 3500, bin_resolution)
    # Keep the bin labels in a separate Series instead of adding a 'bin'
    # column, so the caller's DataFrame is left untouched.
    bin_labels = pd.cut(temp_df['mass'], bins=bins)
    # observed=False is spelled out: every bin must appear in the output even
    # when empty, and the implicit default for categorical groupers is
    # deprecated in recent pandas releases.
    return temp_df['intensity'].groupby(bin_labels, observed=False).mean().values

def normalize(tensor):
    """Standardize ``tensor`` to zero mean and unit standard deviation.

    NaN entries are overwritten with 0 in place (the argument is modified)
    before the statistics are computed. The float32 machine epsilon is added
    to the standard deviation to avoid dividing by zero.

    Returns:
        A new tensor ``(tensor - mean) / (std + eps)``.
    """
    tensor.masked_fill_(torch.isnan(tensor), 0)
    eps = torch.finfo(torch.float32).eps
    centered = tensor - tensor.mean()
    scale = tensor.std() + eps
    return centered / scale

def load_model(weight_path):
    """Load a fully pickled PyTorch model onto the CPU in evaluation mode.

    Args:
        weight_path: Path to a file written with ``torch.save(model, ...)``.

    Returns:
        The unpickled model after ``model.eval()`` has been called on it.

    SECURITY NOTE: a full-model checkpoint is a pickle, and unpickling can
    execute arbitrary code. Only load model files from a trusted source.
    """
    cpu = torch.device('cpu')
    try:
        # PyTorch 2.6 switched the default to weights_only=True, which refuses
        # to unpickle a whole nn.Module; request full unpickling explicitly.
        model = torch.load(weight_path, map_location=cpu, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only keyword.
        model = torch.load(weight_path, map_location=cpu)
    model.eval()
    return model

def main(model_file_path, input_directory, output_directory):
    """Classify every CSV spectrum in ``input_directory`` and save the results.

    Shows a message box and returns early when a path is missing or when the
    input directory contains no CSV files; otherwise loads the model, runs the
    predictions, writes the results file and reports completion.

    Args:
        model_file_path: Path of the saved model file.
        input_directory: Folder scanned for CSV files.
        output_directory: Folder that receives the results CSV.
    """
    if not (model_file_path and input_directory and output_directory):
        show_missing_paths_message()
        return

    # Only the paths are needed here; make_predictions derives the file names
    # from them again.
    file_paths, _ = read_csv_files(input_directory)
    if not file_paths:
        # Without this guard an empty folder crashes later, when the (empty)
        # list of per-file results is concatenated for saving.
        msg = QMessageBox()
        msg.setIcon(QMessageBox.Warning)
        msg.setWindowTitle("No Input Files")
        msg.setText("No CSV files were found in the selected input directory.")
        msg.exec_()
        return

    model = load_model(model_file_path)
    results, file_names, targets = make_predictions(file_paths, model)
    save_results(results, output_directory, file_names, targets)
    show_done_message()

def show_missing_paths_message():
    """Show a modal error box asking the user to fill in all three paths."""
    dialog = QMessageBox()
    dialog.setWindowTitle("Missing Path")
    dialog.setText("Please provide all required paths (Model, Input Directory, Output Directory).")
    dialog.setIcon(QMessageBox.Critical)
    dialog.exec_()

def show_done_message():
    """Show a modal information box reporting that classification finished."""
    dialog = QMessageBox()
    dialog.setWindowTitle("Process Completed")
    dialog.setText("The classification process is complete.")
    dialog.setIcon(QMessageBox.Information)
    dialog.exec_()

def make_predictions(file_paths, model):
    """Run ``model`` on every CSV file and collect class probabilities.

    Each file is read with pandas, binned with ``mean_intensity``,
    standardized with ``normalize`` and passed to the model as a tensor of
    shape ``(1, 1, n_bins)``.

    Args:
        file_paths: Paths of the CSV files to classify.
        model: Callable PyTorch model returning class scores of shape
            ``(1, n_classes)``.

    Returns:
        A ``(results, file_names, targets)`` tuple of three parallel lists:
        the softmax probabilities rounded to 3 decimals (one ``(1, n_classes)``
        NumPy array per file), the base file names, and the index of the most
        probable class.
    """
    results = []
    file_names = []
    targets = []
    # Inference only: disabling autograd avoids building a graph per file.
    with torch.no_grad():
        for file_path in file_paths:
            temp_df = pd.read_csv(file_path)
            file_names.append(os.path.basename(file_path))

            intensities = mean_intensity(temp_df)
            tensor_data = normalize(torch.tensor(intensities, dtype=torch.float32))

            # Two unsqueeze calls add the batch and channel dimensions.
            output = model(tensor_data.unsqueeze(0).unsqueeze(0))
            probabilities = torch.softmax(output, dim=1).detach().numpy()

            # Choose the class from the unrounded probabilities: rounding
            # first can turn two close values into a tie and report the
            # lower-index class instead of the truly most probable one.
            targets.append(np.argmax(probabilities))
            results.append(probabilities.round(3))

    return results, file_names, targets

# Script entry point: open the GUI and let its "Classify" button call main().
if __name__ == "__main__":
    create_gui(main_func=main)
43 changes: 43 additions & 0 deletions windows_app/src/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn

class CNN1D(nn.Module):
    """1-D CNN classifier: two conv + average-pool stages and two linear layers.

    Expects input of shape ``(batch, 1, input_size)`` and returns raw class
    scores of shape ``(batch, num_classes)``; no softmax is applied here.

    Args:
        input_size: Length of each input signal.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()

        kernel, pool_width = 5, 3

        self.conv1 = nn.Conv1d(1, 32, kernel_size=kernel, stride=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=kernel, stride=1)
        self.pool = nn.AvgPool1d(kernel_size=pool_width)

        # Each stage is an unpadded stride-1 convolution followed by pooling:
        # the length shrinks by (kernel - 1), then is floor-divided by 3.
        flat_length = input_size
        for _stage in range(2):
            flat_length = (flat_length - kernel + 1) // pool_width

        self.fc1 = nn.Linear(64 * flat_length, 128)
        self.dropout1 = nn.Dropout(0.25)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class scores for a batch ``x`` of shape ``(batch, 1, input_size)``."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(self.relu(conv(features)))

        # Flatten channels and positions into one feature vector per sample.
        flat = features.view(features.size(0), -1)

        hidden = self.dropout1(self.relu(self.fc1(flat)))
        return self.fc2(hidden)


def load_model(weight_path):
    """Load a fully pickled PyTorch model onto the CPU in evaluation mode.

    Args:
        weight_path: Path to a file written with ``torch.save(model, ...)``.

    Returns:
        The unpickled model after ``model.eval()`` has been called on it.

    SECURITY NOTE: a full-model checkpoint is a pickle, and unpickling can
    execute arbitrary code. Only load model files from a trusted source.
    """
    cpu = torch.device('cpu')
    try:
        # PyTorch 2.6 switched the default to weights_only=True, which refuses
        # to unpickle a whole nn.Module; request full unpickling explicitly.
        model = torch.load(weight_path, map_location=cpu, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only keyword.
        model = torch.load(weight_path, map_location=cpu)
    model.eval()
    return model


84 changes: 84 additions & 0 deletions windows_app/test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "55bd4cb8-4bf1-48f8-bcb3-11375f947e31",
"metadata": {},
"outputs": [],
"source": [
"import tkinter as tk\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ab248bcb-4c77-442b-a12d-9bb89965d029",
"metadata": {},
"outputs": [],
"source": [
"from cx_Freeze import setup, Executable\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bc3e8b53-753d-4d63-9f31-ae485850acbd",
"metadata": {},
"outputs": [],
"source": [
"import torch, pandas "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "74fb7e63-4b89-43ce-a98d-06871d43f416",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('1.12.1+cu116', '1.5.0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__, pandas.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20a31617-7550-424c-8119-fd76d66c1cfa",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}