Skip to content

Commit

Permalink
bin and configs update
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangsizhu1201 authored Jan 31, 2025
1 parent dab5978 commit bde7144
Show file tree
Hide file tree
Showing 22 changed files with 629 additions and 388 deletions.
6 changes: 4 additions & 2 deletions bin/create_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def main():
"scRNA_barcodes_UMI_thresholds.png" : "svg/barplot.svg",
"guides_UMI_thresholds.png": "svg/barplot.svg",
"guides_hist_num_sgRNA.png": "svg/barplot.svg",
"network_plot.png": "svg/network.svg",
"volcano_plot.png": "svg/volcano.svg",
"sceptre_network_plot.png": "svg/network.svg",
"perturbo_network_plot.png": "svg/network.svg",
"sceptre_volcano_plot.png": "svg/volcano.svg",
"perturbo_volcano_plot.png": "svg/volcano.svg",
"guides_per_cell_histogram.png": "svg/barplot.svg",
"cells_per_guide_histogram.png": "svg/barplot.svg",
"cells_per_hto_barplot.png": "svg/barplot.svg",
Expand Down
6 changes: 4 additions & 2 deletions bin/create_dashboard_HASHING.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def main():
"scRNA_barcodes_UMI_thresholds.png" : "svg/barplot.svg",
"guides_UMI_thresholds.png": "svg/barplot.svg",
"guides_hist_num_sgRNA.png": "svg/barplot.svg",
"network_plot.png": "svg/network.svg",
"volcano_plot.png": "svg/volcano.svg",
"sceptre_network_plot.png": "svg/network.svg",
"perturbo_network_plot.png": "svg/network.svg",
"sceptre_volcano_plot.png": "svg/volcano.svg",
"perturbo_volcano_plot.png": "svg/volcano.svg",
"guides_per_cell_histogram.png": "svg/barplot.svg",
"cells_per_guide_histogram.png": "svg/barplot.svg",
"cells_per_hto_barplot.png": "svg/barplot.svg",
Expand Down
39 changes: 33 additions & 6 deletions bin/create_dashboard_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,42 @@ def create_dashboard_df(guide_fq_tbl, mudata_path, gene_ann_path, filtered_ann_p
mean_cells_per_guide = np.mean(cells_per_guide)

iv_highlight = f"Mean guides per cell: {human_format(mean_guides_per_cell)}, Mean cells per guide: {human_format(mean_cells_per_guide)}"
inf_img_df = new_block('Inference', '', 'Visualization', iv_highlight, True,
image = ['evaluation_output/network_plot.png', 'evaluation_output/volcano_plot.png'],
image_description= ['Gene interaction networks of selected genes.', 'Volcano Plot.'])

# Collect existing network plots
network_plots = []
network_descs = []

if os.path.exists('evaluation_output/sceptre_network_plot.png'):
network_plots.append('evaluation_output/sceptre_network_plot.png')
network_descs.append('Sceptre network plot')

if os.path.exists('evaluation_output/perturbo_network_plot.png'):
network_plots.append('evaluation_output/perturbo_network_plot.png')
network_descs.append('Perturbo network plot')

# Collect existing volcano plots
volcano_plots = []
volcano_descs = []

if os.path.exists('evaluation_output/sceptre_volcano_plot.png'):
volcano_plots.append('evaluation_output/sceptre_volcano_plot.png')
volcano_descs.append('Sceptre volcano plot')

if os.path.exists('evaluation_output/perturbo_volcano_plot.png'):
volcano_plots.append('evaluation_output/perturbo_volcano_plot.png')
volcano_descs.append('Perturbo volcano plot')

# Combine network + volcano
all_plots = network_plots + volcano_plots
all_descs = network_descs + volcano_descs

inf_img_df = new_block('Inference', '', 'Visualization', iv_highlight, True, image=all_plots, image_description=all_descs)

### check guide seqspec check df
guide_check_df = new_block("Guide", '', 'Fastq Overview', '', False, table = guide_fq_table,
table_description='Summary of Sequence Index: A summary of the positions where the Guide starts are mapped on the reads (Use to inspect or calibrate the position where the guide is supposed to be found in your SeqSpec File)',
image = ['guide_seqSpec_plots/seqSpec_check_plots.png'],
image_description= ['The frequency of each nucleotides along the Read 1 (Use to inspect the expected read parts with their expected signature) and Read 2 (Use to inspect the expected read parts with their expected signature)'])
table_description='Summary of Sequence Index: A summary of the positions where the Guide starts are mapped on the reads (Use to inspect or calibrate the position where the guide is supposed to be found in your SeqSpec File)',
image = ['guide_seqSpec_plots/seqSpec_check_plots.png'],
image_description= ['The frequency of each nucleotides along the Read 1 (Use to inspect the expected read parts with their expected signature) and Read 2 (Use to inspect the expected read parts with their expected signature)'])

return guide_check_df, cell_stats, gene_stats, rna_img_df, guide_img_df, gi_df, gs_img_df, inf_img_df

Expand Down
45 changes: 36 additions & 9 deletions bin/create_dashboard_df_HASHING.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,36 @@ def create_dashboard_df(guide_fq_tbl, hashing_fq_tbl, mudata_path, gene_ann_path
mean_cells_per_guide = np.mean(cells_per_guide)

iv_highlight = f"Mean guides per cell: {human_format(mean_guides_per_cell)}, Mean cells per guide: {human_format(mean_cells_per_guide)}"
inf_img_df = new_block('Inference', '', 'Visualization', iv_highlight, True,
image = ['evaluation_output/network_plot.png', 'evaluation_output/volcano_plot.png'],
image_description= ['Gene interaction networks of selected genes.', 'Volcano Plot.'])

# Collect existing network plots
network_plots = []
network_descs = []

if os.path.exists('evaluation_output/sceptre_network_plot.png'):
network_plots.append('evaluation_output/sceptre_network_plot.png')
network_descs.append('Sceptre network plot')

if os.path.exists('evaluation_output/perturbo_network_plot.png'):
network_plots.append('evaluation_output/perturbo_network_plot.png')
network_descs.append('Perturbo network plot')

# Collect existing volcano plots
volcano_plots = []
volcano_descs = []

if os.path.exists('evaluation_output/sceptre_volcano_plot.png'):
volcano_plots.append('evaluation_output/sceptre_volcano_plot.png')
volcano_descs.append('Sceptre volcano plot')

if os.path.exists('evaluation_output/perturbo_volcano_plot.png'):
volcano_plots.append('evaluation_output/perturbo_volcano_plot.png')
volcano_descs.append('Perturbo volcano plot')

# Combine network + volcano
all_plots = network_plots + volcano_plots
all_descs = network_descs + volcano_descs

inf_img_df = new_block('Inference', '', 'Visualization', iv_highlight, True, image=all_plots, image_description=all_descs)

### Create hashing demultiplex df
non_multiplet_count = hashing_demux.obs[hashing_demux.obs['hto_type_split'] != 'multiplets'].shape[0]
Expand All @@ -198,15 +225,15 @@ def create_dashboard_df(guide_fq_tbl, hashing_fq_tbl, mudata_path, gene_ann_path

### check guide seqspec check df
guide_check_df = new_block("Guide", '', 'Fastq Overview', '', False, table = guide_fq_table,
table_description='Summary of Sequence Index: A summary of the positions where the Guide starts are mapped on the reads (Use to inspect or calibrate the position where the guide is supposed to be found in your SeqSpec File)',
image = ['guide_seqSpec_plots/seqSpec_check_plots.png'],
image_description= ['The frequency of each nucleotides along the Read 1 (Use to inspect the expected read parts with their expected signature) and Read 2 (Use to inspect the expected read parts with their expected signature)'])
table_description='Summary of Sequence Index: A summary of the positions where the Guide starts are mapped on the reads (Use to inspect or calibrate the position where the guide is supposed to be found in your SeqSpec File)',
image = ['guide_seqSpec_plots/seqSpec_check_plots.png'],
image_description= ['The frequency of each nucleotides along the Read 1 (Use to inspect the expected read parts with their expected signature) and Read 2 (Use to inspect the expected read parts with their expected signature)'])

### check hashing seqspec check df
hashing_check_df = new_block("Hashing", '', 'Fastq Overview', '', False, table = hashing_fq_table,
table_description='Summary of Sequence Index: A summary of the positions where the Hashtag starts are mapped on the reads (Use to inspect or calibrate the position where the hashtag is supposed to be found in your SeqSpec File)',
image = ['hashing_seqSpec_plots/seqSpec_check_plots.png'],
image_description= ['The frequency of each nucleotides along the Read 1 (Use to inspect the expected read parts with their expected signature )and Read 2 (Use to inspect the expected read parts with their expected signature)'])
table_description='Summary of Sequence Index: A summary of the positions where the Hashtag starts are mapped on the reads (Use to inspect or calibrate the position where the hashtag is supposed to be found in your SeqSpec File)',
image = ['hashing_seqSpec_plots/seqSpec_check_plots.png'],
image_description= ['The frequency of each nucleotides along the Read 1 (Use to inspect the expected read parts with their expected signature )and Read 2 (Use to inspect the expected read parts with their expected signature)'])


return guide_check_df, hashing_check_df, cell_stats, gene_stats, rna_img_df, guide_img_df, gi_df, gs_img_df, inf_img_df, hs_demux_df
Expand Down
169 changes: 122 additions & 47 deletions bin/create_dashboard_plots_HASHING.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import anndata as ad
import numpy as np
from umap import UMAP
from sklearn.decomposition import PCA
import math

def plot_umi_threshold(mudata, save_dir):
Expand Down Expand Up @@ -195,90 +196,164 @@ def plot_umap_HTOs(mudata, save_dir):
axes = np.array([axes])

for i, batch in enumerate(batches):
ax = axes[i]
batch_mask = demux.obs['batch'] == batch
batch_data = demux.X[batch_mask]
num = batch_data.shape[0]
umap = UMAP(n_neighbors=150, min_dist=0.2, n_components=2, spread=1.5,random_state=42)
umap_result = umap.fit_transform(batch_data)
demux.obs.loc[batch_mask, ['UMAP1', 'UMAP2']] = umap_result

ax = axes[i]
sc = ax.scatter(demux.obs.loc[batch_mask, 'UMAP1'],
demux.obs.loc[batch_mask, 'UMAP2'],
c=demux.obs.loc[batch_mask, 'hto_color'],
alpha=0.7, s=1)

ax.set_title(f'{batch}, number of cells: {num}')
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')

# Limit PCA components for this batch
n_samples = batch_data.shape[0]
n_features = batch_data.shape[1]
# At least 2, but not more than min(n_samples, n_features)
n_pca_components = max(2, min(n_samples, n_features) - 1)

# Re-initialize PCA and UMAP for this batch
pca = PCA(n_components=n_pca_components)
umap_model = UMAP(
n_neighbors=150,
min_dist=0.2,
n_components=2,
spread=1.5,
random_state=42
)

# Fit PCA and then fit UMAP on the PCA result
pca_result = pca.fit_transform(batch_data)
umap_result = umap_model.fit_transform(pca_result)

# Store coordinates in obs (for demonstration)
demux.obs.loc[batch_mask, 'UMAP1'] = umap_result[:, 0]
demux.obs.loc[batch_mask, 'UMAP2'] = umap_result[:, 1]

# -- 5) Scatter plot
ax.scatter(
demux.obs.loc[batch_mask, 'UMAP1'],
demux.obs.loc[batch_mask, 'UMAP2'],
c=demux.obs.loc[batch_mask, 'hto_color'],
alpha=0.7,
s=1
)

# Title, labels, and text about variance explained
num_cells_batch = batch_data.shape[0]
ax.set_title(f"{batch}, number of cells: {num_cells_batch}")
ax.set_xlabel("UMAP1")
ax.set_ylabel("UMAP2")

# -- 6) Remove any empty subplots if batch count < rows*cols
for j in range(i + 1, len(axes)):
fig.delaxes(axes[j])

handles = [plt.Line2D([0], [0], marker='o', color=color, markersize=12, linestyle='None') for color in colors]
labels = unique_hto_types
fig.legend(handles, labels, loc='center left', bbox_to_anchor=(0.9, 0.5), title='HTO Type', fontsize='large', title_fontsize='x-large')

#plt.tight_layout(rect=[0, 0, 0.9, 1])
# -- 7) Add a legend on the right
handles = [plt.Line2D([0], [0], marker='o', color=hto_color_map[k],
markersize=8, linestyle='None')
for k in unique_hto_types]
fig.legend(
handles,
[str(k) for k in unique_hto_types],
loc='center left',
bbox_to_anchor=(0.9, 0.5),
title='HTO Type'
)

plt.tight_layout(rect=[0, 0, 0.85, 1])

plot_path = os.path.join(save_dir, 'umap_hto.png')
plt.savefig(plot_path, dpi=300)
plt.close()

def plot_umap_HTOs_singlets(mudata, save_dir):
## remove multiplets

# -- 1) Subset to remove multiplets
demux = mudata['hashing']
demux_s = demux[demux.obs['hto_type_split'] != "multiplets"]
demux_s = demux[demux.obs['hto_type_split'] != "multiplets"].copy()

# -- 2) Set up colors for each unique HTO type
unique_hto_types = demux_s.obs['hto_type_split'].cat.categories
color_palette = plt.get_cmap('tab20')
colors = [color_palette(i) for i in range(len(unique_hto_types))]
hto_color_map = dict(zip(unique_hto_types, colors))
demux_s.obs['hto_color'] = [hto_color_map[hto_type] for hto_type in demux_s.obs['hto_type_split']]

# -- 3) Prepare subplots (rows x columns)
batches = demux_s.obs['batch'].unique()

num_batches = len(batches)
columns = 2 if num_batches > 1 else 1
cols = 2 if num_batches > 1 else 1
rows = math.ceil(num_batches / 2)
fig, axes = plt.subplots(rows, columns, figsize=(18, 8 * rows))

fig, axes = plt.subplots(rows, cols, figsize=(18, 8 * rows))
if isinstance(axes, np.ndarray):
axes = axes.flatten()
else:
axes = np.array([axes])

# -- 4) For each batch, do PCA + UMAP
for i, batch in enumerate(batches):
ax = axes[i]
batch_mask = demux_s.obs['batch'] == batch
batch_data = demux_s.X[batch_mask]
num = batch_data.shape[0]

umap = UMAP(n_neighbors=150, min_dist=0.2, n_components=2, spread=1.5,random_state=42)

umap_result = umap.fit_transform(batch_data)
demux_s.obs.loc[batch_mask, ['UMAP1', 'UMAP2']] = umap_result

ax = axes[i]
sc = ax.scatter(demux_s.obs.loc[batch_mask, 'UMAP1'],
demux_s.obs.loc[batch_mask, 'UMAP2'],
c=demux_s.obs.loc[batch_mask, 'hto_color'],
alpha=0.7, s=1)

ax.set_title(f'{batch}, number of cells: {num}')
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')

# Limit PCA components for this batch
n_samples = batch_data.shape[0]
n_features = batch_data.shape[1]
# At least 2, but not more than min(n_samples, n_features)
n_pca_components = max(2, min(n_samples, n_features) - 1)

# Re-initialize PCA and UMAP for this batch
pca = PCA(n_components=n_pca_components)
umap_model = UMAP(
n_neighbors=150,
min_dist=0.2,
n_components=2,
spread=1.5,
random_state=42
)

# Fit PCA and then fit UMAP on the PCA result
pca_result = pca.fit_transform(batch_data)
umap_result = umap_model.fit_transform(pca_result)

# Store coordinates in obs (for demonstration)
demux_s.obs.loc[batch_mask, 'UMAP1'] = umap_result[:, 0]
demux_s.obs.loc[batch_mask, 'UMAP2'] = umap_result[:, 1]

# -- 5) Scatter plot
ax.scatter(
demux_s.obs.loc[batch_mask, 'UMAP1'],
demux_s.obs.loc[batch_mask, 'UMAP2'],
c=demux_s.obs.loc[batch_mask, 'hto_color'],
alpha=0.7,
s=1
)

# Title, labels, and text about variance explained
num_cells_batch = batch_data.shape[0]
ax.set_title(f"{batch}, number of cells: {num_cells_batch}")
ax.set_xlabel("UMAP1")
ax.set_ylabel("UMAP2")

# -- 6) Remove any empty subplots if batch count < rows*cols
for j in range(i + 1, len(axes)):
fig.delaxes(axes[j])

handles = [plt.Line2D([0], [0], marker='o', color=color, markersize=12, linestyle='None') for color in colors]
labels = unique_hto_types
fig.legend(handles, labels, loc='center left', bbox_to_anchor=(0.9, 0.5), title='HTO Type', fontsize='large', title_fontsize='x-large')

#plt.tight_layout(rect=[0, 0, 0.9, 1])
# -- 7) Add a legend on the right
handles = [plt.Line2D([0], [0], marker='o', color=hto_color_map[k],
markersize=8, linestyle='None')
for k in unique_hto_types]
fig.legend(
handles,
[str(k) for k in unique_hto_types],
loc='center left',
bbox_to_anchor=(0.9, 0.5),
title='HTO Type'
)

plt.tight_layout(rect=[0, 0, 0.85, 1])

# -- 8) Save the figure
plot_path = os.path.join(save_dir, 'umap_hto_singlets.png')
plt.savefig(plot_path, dpi=300)
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.close()


def main():
parser = argparse.ArgumentParser(description="Generate various plots from MuData")
parser.add_argument('--mudata', required=True, help='Path to the mudata object')
Expand Down
11 changes: 5 additions & 6 deletions bin/create_mdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@
import pandas as pd
import numpy as np
from muon import MuData
from gtfparse import read_gtf
import matplotlib.pyplot as plt
import os
from GTFProcessing import GTFProcessing

def main(adata_rna, adata_guide, guide_metadata, gtf, moi):
# Load the data
guide_metadata = pd.read_csv(guide_metadata, sep='\t')
adata_rna = ad.read_h5ad(adata_rna)
adata_guide = ad.read_h5ad(adata_guide)
gtf = GTFProcessing(gtf)
df_gtf = gtf.get_gtf_df()
df_gtf = read_gtf(gtf).to_pandas()

# add targeting_chr, start, end to the targeting elements
gene_map_df = df_gtf.groupby('gene_name')[['chr', 'start', 'end']].first()
gene_map_df = df_gtf.groupby('gene_name').first()[['seqname', 'start', 'end']]
guide_metadata = guide_metadata.merge(gene_map_df, how='left', left_on='targeting', right_index=True)
guide_metadata = guide_metadata.rename(columns={'chr': 'intended_target_chr', 'start': 'intended_target_start', 'end': 'intended_target_end'})
guide_metadata = guide_metadata.rename(columns={'seqname': 'intended_target_chr', 'start': 'intended_target_start', 'end': 'intended_target_end'})

## change in adata_guide
# adding var for guide
Expand Down Expand Up @@ -66,7 +65,7 @@ def main(adata_rna, adata_guide, guide_metadata, gtf, moi):
df_gtf_copy = df_gtf.copy()
df_gtf_copy.set_index('gene_id2', inplace=True)
# adding gene_start, gene_end, gene_chr
adata_rna.var = adata_rna.var.join(df_gtf_copy[['chr', 'start', 'end']].rename(columns={'chr': 'gene_chr',
adata_rna.var = adata_rna.var.join(df_gtf_copy[['seqname', 'start', 'end']].rename(columns={'seqname': 'gene_chr',
'start': 'gene_start',
'end': 'gene_end'}))

Expand Down
Loading

0 comments on commit bde7144

Please sign in to comment.