Benchmark dask #3319

Open · wants to merge 7 commits into base: main
8 changes: 7 additions & 1 deletion benchmarks/asv.conf.json
@@ -2,6 +2,8 @@
// The version of the config file format. Do not change, unless
// you know what you are doing.
"version": 1,
"timeout": 1200,
"install_timeout": 1200,

// The name of the project being benchmarked
"project": "scanpy",
@@ -43,11 +45,12 @@
// If missing or the empty string, the tool will be automatically
// determined by looking for tools on the PATH environment
// variable.
"timeout": 1200, // Timeout for each benchmark in seconds
"environment_type": "conda",

// timeout in seconds for installing any dependencies in environment
// defaults to 10 min
//"install_timeout": 600,
"install_timeout": 1200,

// the base URL to show a commit for the project.
"show_commit_url": "https://github.com/scverse/scanpy/commit/",
@@ -86,6 +89,9 @@
"pooch": [""],
"scikit-image": [""],
// "scikit-misc": [""],
"scikit-learn": [""],
"pip+asv_runner": [""],
"dask": [""],
},

// Combinations of libraries/python versions can be excluded/included
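Both the global `timeout` and `install_timeout` are raised to 1200 s here. asv also honors a per-benchmark `timeout` attribute, so an especially slow case can get extra headroom without raising the global value further. A minimal sketch (the benchmark name is hypothetical):

```python
# Hedged sketch: asv reads a `timeout` attribute on a benchmark callable,
# overriding the global value from asv.conf.json for that benchmark only.
def time_full_pipeline_musmus_11m():
    ...  # hypothetical slow benchmark


time_full_pipeline_musmus_11m.timeout = 2400  # seconds; overrides the 1200 s default
```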
62 changes: 62 additions & 0 deletions benchmarks/benchmarks/_utils.py
@@ -117,6 +117,64 @@ def lung93k() -> AnnData:
return _lung93k().copy()


@cache
def _musmus_11m() -> AnnData:
    # Path to the Mus musculus CELLxGENE dataset; this path is specific to the
    # cluster where the benchmarks were developed, so adjust it locally.
    path = "/sc/arion/projects/psychAD/mikaela/scanpy/scanpy/benchmarks/data/MusMus_4M_cells_cellxgene.h5ad"
    adata = sc.read_h5ad(path)
    # Keep a counts layer before log-transforming X, matching the other dataset
    # loaders; the count benchmarks expect this layer to exist.
    adata.layers["counts"] = adata.X.astype(np.int32, copy=True)
sc.pp.log1p(adata)
return adata


def musmus_11m() -> AnnData:
return _musmus_11m().copy()


@cache
def _large_synthetic_dataset(
n_obs: int = 500_000, n_vars: int = 5_000, density: float = 0.01
) -> AnnData:
"""
Generate a synthetic dataset suitable for Dask testing.

Parameters:
n_obs: int
Number of observations (rows, typically cells).
n_vars: int
Number of variables (columns, typically genes).
density: float
Fraction of non-zero entries in the sparse matrix.

Returns:
AnnData
The synthetic dataset.
"""

X = sparse.random(
n_obs, n_vars, density=density, format="csr", dtype=np.float32, random_state=42
)
obs = {"obs_names": [f"cell_{i}" for i in range(n_obs)]}
var = {"var_names": [f"gene_{j}" for j in range(n_vars)]}
adata = anndata.AnnData(X=X, obs=obs, var=var)
adata.layers["counts"] = X.copy()
sc.pp.log1p(adata)
adata.var["mt"] = adata.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(
adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
)

return adata


def large_synthetic_dataset(
n_obs: int = 500_000, n_vars: int = 5_000, density: float = 0.01
) -> AnnData:
return _large_synthetic_dataset(n_obs, n_vars, density).copy()


def to_off_axis(x: np.ndarray | sparse.csr_matrix) -> np.ndarray | sparse.csc_matrix:
if isinstance(x, sparse.csr_matrix):
return x.tocsc()
@@ -138,6 +196,10 @@ def _get_dataset_raw(dataset: Dataset) -> tuple[AnnData, str | None]:
adata, batch_key = bmmc(400), "sample"
case "lung93k":
adata, batch_key = lung93k(), "PatientNumber"
case "large_synthetic":
adata, batch_key = large_synthetic_dataset(), None
case "musmus_11m":
adata, batch_key = musmus_11m(), None
case _:
msg = f"Unknown dataset {dataset}"
raise AssertionError(msg)
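For a quick smoke test of the new generator without the 500 000 × 5 000 default (roughly 25 million non-zeros at 1 % density), the helper can be called with smaller dimensions. A minimal sketch, assuming it is run from the top-level `benchmarks` directory so that `benchmarks._utils` resolves:

```python
from benchmarks._utils import large_synthetic_dataset  # assumed import path

# Small instance for a smoke test; the @cache decorator memoizes per
# (n_obs, n_vars, density) tuple, so each size is generated only once.
adata = large_synthetic_dataset(n_obs=10_000, n_vars=1_000, density=0.01)
assert adata.layers["counts"].nnz > 0  # counts layer preserved pre-log1p
assert "log1p" in adata.uns            # X was log-transformed by the helper
```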
6 changes: 4 additions & 2 deletions benchmarks/benchmarks/preprocessing_counts.py
@@ -32,7 +32,8 @@ def setup(dataset: Dataset, layer: KeyCount, *_):
# ASV suite

params: tuple[list[Dataset], list[KeyCount]] = (
["pbmc68k_reduced", "pbmc3k"],
["pbmc3k", "pbmc68k_reduced", "bmmc", "musmus_11m"],
# ["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k", "large_synthetic", "musmus_11m"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]
@@ -78,7 +79,8 @@ class FastSuite:
"""Suite for fast preprocessing operations."""

params: tuple[list[Dataset], list[KeyCount]] = (
["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k"],
# ["pbmc3k", "pbmc68k_reduced", "bmmc", "lung93k", "large_synthetic", "musmus_11m"],
["pbmc3k", "pbmc68k_reduced", "bmmc", "musmus_11m"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]
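For reference, asv expands `params` as a Cartesian product, so every benchmark in this module now runs once per (dataset, layer) combination. A small illustration (the benchmark name is a placeholder):

```python
from itertools import product

datasets = ["pbmc3k", "pbmc68k_reduced", "bmmc", "musmus_11m"]
layers = ["counts", "counts-off-axis"]

# 4 datasets × 2 layers = 8 variants per benchmark function,
# mirroring how asv instantiates the `params` grid above.
for dataset, layer in product(datasets, layers):
    print(f"time_<benchmark>({dataset!r}, {layer!r})")
```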
213 changes: 213 additions & 0 deletions benchmarks/benchmarks/preprocessing_counts_dask.py
@@ -0,0 +1,213 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import dask.array as da
from dask.distributed import Client, LocalCluster

import scanpy as sc

from ._utils import get_count_dataset

if TYPE_CHECKING:
from anndata import AnnData

from ._utils import Dataset, KeyCount

# Setup global variables
adata: AnnData
batch_key: str | None


def setup(dataset: Dataset, layer: KeyCount, *_):
"""Setup global variables before each benchmark."""
global adata, batch_key
adata, batch_key = get_count_dataset(dataset, layer=layer)
assert "log1p" not in adata.uns


def setup_dask_cluster():
"""Set up a local Dask cluster for benchmarking."""
    # NOTE: memory_limit is per worker, so this requests up to 5 × 60 GB in total.
    cluster = LocalCluster(
        n_workers=5, threads_per_worker=2, memory_limit="60GB", timeout="1200s"
    )
client = Client(cluster)
return client


# ASV suite
params: tuple[list[Dataset], list[KeyCount]] = (
["musmus_11m"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]


### Dask-Based Benchmarks ###
def time_filter_cells_dask(*_):
    client = setup_dask_cluster()
    try:
        # ~4 row-chunks per worker (len(client.nthreads()) is the worker count);
        # columns are kept whole so per-cell reductions stay within one chunk.
        optimal_chunks = (
            adata.X.shape[0] // (4 * len(client.nthreads())),
            adata.X.shape[1],
        )
        adata.X = da.from_array(adata.X, chunks=optimal_chunks).persist()
        sc.pp.filter_cells(adata, min_genes=100)
        assert adata.n_obs > 0  # ensure some cells are retained
    finally:
        client.close()


def peakmem_filter_cells_dask(*_):
client = setup_dask_cluster()
try:
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
        adata.X = da.from_array(adata.X, chunks=optimal_chunks)
sc.pp.filter_cells(adata, min_genes=100)
finally:
client.close()


def time_filter_genes_dask(*_):
client = setup_dask_cluster()
try:
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
        adata.X = da.from_array(adata.X, chunks=optimal_chunks).persist()
sc.pp.filter_genes(adata, min_cells=3)
finally:
client.close()


def peakmem_filter_genes_dask(*_):
client = setup_dask_cluster()
try:
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
        adata.X = da.from_array(adata.X, chunks=optimal_chunks)
sc.pp.filter_genes(adata, min_cells=3)
finally:
client.close()


### Dask Preprocessing Benchmark Suite ###


class FastSuite:
"""Suite for benchmarking preprocessing operations with Dask."""

params: tuple[list[Dataset], list[KeyCount]] = (
["musmus_11m"],
["counts", "counts-off-axis"],
)
param_names = ["dataset", "layer"]

def time_calculate_qc_metrics_dask(self, *_):
client = setup_dask_cluster()
try:
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
            adata.X = da.from_array(adata.X, chunks=optimal_chunks).persist()
sc.pp.calculate_qc_metrics(
adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
)
finally:
client.close()

def peakmem_calculate_qc_metrics_dask(self, *_):
client = setup_dask_cluster()
try:
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
            adata.X = da.from_array(adata.X, chunks=optimal_chunks)
sc.pp.calculate_qc_metrics(
adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
)
finally:
client.close()

    def time_normalize_total_dask(self, *_):
        client = setup_dask_cluster()
        try:
            optimal_chunks = (
                adata.X.shape[0] // (4 * len(client.nthreads())),
                adata.X.shape[1],
            )
            adata.X = da.from_array(adata.X, chunks=optimal_chunks).persist()
            # Benchmark the scanpy API itself, mirroring the peakmem twin below;
            # a bare persist() returns immediately and would time only submission.
            sc.pp.normalize_total(adata, target_sum=1e4)
        finally:
            client.close()

def peakmem_normalize_total_dask(self, *_):
client = setup_dask_cluster()
try:
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
            adata.X = da.from_array(adata.X, chunks=optimal_chunks)
sc.pp.normalize_total(adata, target_sum=1e4)
finally:
client.close()

def time_log1p_dask(self, *_):
client = setup_dask_cluster()
try:
adata.uns.pop("log1p", None)
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
            adata.X = da.from_array(adata.X, chunks=optimal_chunks).persist()
sc.pp.log1p(adata)
finally:
client.close()

def peakmem_log1p_dask(self, *_):
client = setup_dask_cluster()
try:
adata.uns.pop("log1p", None)
optimal_chunks = (
adata.X.shape[0] // (4 * len(client.nthreads())),
adata.X.shape[1],
)
            adata.X = da.from_array(adata.X, chunks=optimal_chunks)
sc.pp.log1p(adata)
finally:
client.close()
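A note on the chunking heuristic repeated in each benchmark above: `len(client.nthreads())` is the number of workers, so the rows are split into roughly four chunks per worker while the column axis is never split. Worked numbers, assuming an illustrative 11M × 30k shape for the Mus musculus dataset:

```python
n_obs, n_vars = 11_000_000, 30_000  # assumed shape, for illustration only
n_workers = 5                       # len(client.nthreads()) for the cluster above

rows_per_chunk = n_obs // (4 * n_workers)  # 11_000_000 // 20 == 550_000
optimal_chunks = (rows_per_chunk, n_vars)  # column axis stays whole

# ~20 row-chunks total -> ~4 per worker: enough parallelism to balance load
# without flooding the scheduler with tiny tasks.
print(optimal_chunks)  # (550000, 30000)
```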