Faster filtering of sparse matrix #3465

Open · wants to merge 4 commits into main
92 changes: 82 additions & 10 deletions src/scanpy/preprocessing/_simple.py
@@ -12,6 +12,7 @@

import numba
import numpy as np
import scipy as sp
from anndata import AnnData
from pandas.api.types import CategoricalDtype
from scipy.sparse import csc_matrix, csr_matrix, issparse
@@ -55,6 +56,36 @@
A = TypeVar("A", bound=np.ndarray | _CSMatrix | DaskArray)


@njit()
def get_rows_to_keep_1(indptr, data):
    # Per-row sums of a CSR matrix: data[indptr[i] : indptr[i + 1]] holds row i.
    sums = np.zeros(len(indptr) - 1, dtype=data.dtype)
    for i in range(len(sums)):
        sums[i] = np.sum(data[indptr[i] : indptr[i + 1]])
    return sums
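As a quick sanity check (not part of this PR; `Xr` and `row_sums` are illustrative names), the per-row sums should match scipy's built-in reduction on a small CSR matrix:

import numpy as np
import scipy.sparse as sps

Xr = sps.random(6, 5, density=0.3, format="csr", dtype=np.float32)
row_sums = get_rows_to_keep_1(Xr.indptr, Xr.data)
assert np.allclose(row_sums, np.asarray(Xr.sum(axis=1)).ravel())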



def get_rows_to_keep(indptr, dtype):
    # Number of stored (nonzero) entries per row of a CSR matrix.
    lens = (indptr[1:] - indptr[:-1]).astype(dtype)
    return lens


@njit()
def get_cols_to_keep(indices, data, colcount, nthr, Flag):
    # Reduce over the flat indices/data arrays of a CSR matrix per column
    # (for CSC input, `indices` holds row indices, so this reduces per row).
    # Flag=True counts nonzero entries; Flag=False sums the values.
    # Each thread accumulates into its own row of `counts`, avoiding races.
    counts = np.zeros((nthr, colcount), dtype=data.dtype)
    for i in numba.prange(nthr):
        start = i * indices.shape[0] // nthr
        end = (i + 1) * indices.shape[0] // nthr
        for j in range(start, end):
            if data[j] != 0:  # skip explicit zeros
                if Flag:
                    counts[i, indices[j]] += 1
                else:
                    counts[i, indices[j]] += data[j]
    return np.sum(counts, axis=0, dtype=data.dtype)
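Likewise, the stored-entry counters should agree with scipy's getnnz; a minimal sketch assuming the helpers above are in scope (`Xc` and the count names are illustrative):

import numba
import numpy as np
import scipy.sparse as sps

Xc = sps.random(6, 5, density=0.3, format="csr", dtype=np.float32)

# per-row nonzero counts from the indptr differences
nnz_per_row = get_rows_to_keep(Xc.indptr, Xc.data.dtype)
assert np.array_equal(nnz_per_row.astype(np.int64), Xc.getnnz(axis=1))

# per-column nonzero counts from the threaded kernel (Flag=True)
nnz_per_col = get_cols_to_keep(
    Xc.indices, Xc.data, Xc.shape[1], numba.get_num_threads(), Flag=True
)
assert np.array_equal(nnz_per_col.astype(np.int64), Xc.getnnz(axis=0))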

@old_positionals(
"min_counts", "min_genes", "max_counts", "max_genes", "inplace", "copy"
)
@@ -173,11 +204,31 @@
X = data # proceed with processing the data matrix
min_number = min_counts if min_genes is None else min_genes
max_number = max_counts if max_genes is None else max_genes
if isinstance(X, sp.sparse.csr_matrix):
    if min_genes is None and max_genes is None:
        # counts per cell: sum the stored values of each row
        number_per_cell = get_rows_to_keep_1(X.indptr, X.data)
    else:
        # genes per cell: count the stored entries of each row
        number_per_cell = get_rows_to_keep(X.indptr, X.data.dtype)
elif isinstance(X, sp.sparse.csc_matrix):
    # For CSC input, X.indices holds row (cell) indices, so the
    # per-column kernel reduces per cell with colcount = n_obs.
    nrows = X.shape[0]
    nthr = numba.get_num_threads()
    if min_genes is None and max_genes is None:
        number_per_cell = get_cols_to_keep(X.indices, X.data, nrows, nthr, Flag=False)
    else:
        number_per_cell = get_cols_to_keep(X.indices, X.data, nrows, nthr, Flag=True)
else:
    number_per_cell = axis_sum(
        X if min_genes is None and max_genes is None else X > 0, axis=1
    )

if issparse(X) and isinstance(number_per_cell, np.matrix):
    # axis_sum on scipy sparse returns np.matrix; the numba fast
    # paths above already return flat ndarrays.
    number_per_cell = number_per_cell.A1
if min_number is not None:
cell_subset = number_per_cell >= min_number
if max_number is not None:
@@ -291,11 +342,33 @@
X = data # proceed with processing the data matrix
min_number = min_counts if min_cells is None else min_cells
max_number = max_counts if max_cells is None else max_cells

if isinstance(X, sp.sparse.csr_matrix):
    # For CSR input, X.indices holds column (gene) indices.
    ncols = X.shape[1]
    nthr = numba.get_num_threads()
    if min_cells is None and max_cells is None:
        # counts per gene: sum the stored values of each column
        number_per_gene = get_cols_to_keep(X.indices, X.data, ncols, nthr, Flag=False)
    else:
        # cells per gene: count the stored entries of each column
        number_per_gene = get_cols_to_keep(X.indices, X.data, ncols, nthr, Flag=True)
elif isinstance(X, sp.sparse.csc_matrix):
    if min_cells is None and max_cells is None:
        number_per_gene = get_rows_to_keep_1(X.indptr, X.data)
    else:
        number_per_gene = get_rows_to_keep(X.indptr, X.data.dtype)
else:
    number_per_gene = axis_sum(
        X if min_cells is None and max_cells is None else X > 0, axis=0
    )
if issparse(X) and isinstance(number_per_gene, np.matrix):
    # axis_sum on scipy sparse returns np.matrix; the numba fast
    # paths above already return flat ndarrays.
    number_per_gene = number_per_gene.A1

if min_number is not None:
gene_subset = number_per_gene >= min_number
if max_number is not None:
Expand Down Expand Up @@ -379,7 +452,6 @@
@log1p.register(np.ndarray)
def log1p_array(X: np.ndarray, *, base: Number | None = None, copy: bool = False):
# Can force arrays to be np.ndarrays, but would be useful to not
if copy:
X = X.astype(float) if not np.issubdtype(X.dtype, np.floating) else X.copy()
elif not (np.issubdtype(X.dtype, np.floating) or np.issubdtype(X.dtype, complex)):
@@ -1129,7 +1201,7 @@


# TODO: can/should this be parallelized?
@njit() # noqa: TID251
def _downsample_array(
col: np.ndarray,
target: int,
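For context, a typical call that would exercise these fast paths goes through the standard scanpy API; the synthetic AnnData below is illustrative, not part of the PR:

import numpy as np
import scanpy as sc
import scipy.sparse as sps
from anndata import AnnData

adata = AnnData(sps.random(1000, 500, density=0.1, format="csr", dtype=np.float32))
sc.pp.filter_cells(adata, min_genes=5)  # CSR fast path: per-cell nonzero counts
sc.pp.filter_genes(adata, min_cells=3)  # CSR fast path: per-gene nonzero counts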