-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port clade_data_utils function to get clade list for forecasting
There are a few changes (for example, sorting by date and clade before slicing on the "n" clades that create the final list). This changeset also adds test cases for various permutations of threshold, threshold_weeks, and the maximum number of clades allowed in the list being returned.
- Loading branch information
Showing
7 changed files
with
212 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""Get a list of SARS-CoV-2 clades.""" | ||
|
||
import os | ||
import time | ||
from datetime import timedelta | ||
|
||
import polars as pl | ||
import structlog | ||
from cloudpathlib import AnyPath | ||
|
||
from virus_clade_utils.util.config import Config | ||
from virus_clade_utils.util.sequence import ( | ||
download_covid_genome_metadata, | ||
filter_covid_genome_metadata, | ||
get_clade_counts, | ||
get_covid_genome_metadata, | ||
) | ||
|
||
logger = structlog.get_logger() | ||
|
||
|
||
def get_clades(clade_counts: pl.LazyFrame, threshold: float, threshold_weeks: int, max_clades: int) -> list[str]:
    """Get a list of clades to forecast.

    Parameters
    ----------
    clade_counts : polars.LazyFrame
        Clade counts. Only the ``date``, ``clade`` and ``count`` columns are
        read here.
    threshold : float
        A clade is a candidate when, in at least one week of the look-back
        window, its share of that week's total count is strictly greater
        than ``threshold``.
    threshold_weeks : int
        Number of whole weeks to look back, counted from the Sunday that
        starts the week containing the most recent date. Rows on or after
        that cutoff are kept, so the (possibly partial) most recent week is
        part of the window as well.
    max_clades : int
        Maximum number of clades to return.

    Returns
    -------
    list of str
        Candidate clades. When more than ``max_clades`` clades qualify, the
        candidates with the largest summed counts over the window are kept
        (ties broken alphabetically); otherwise they are returned in
        alphabetical order.
    """
    start = time.perf_counter()

    max_day = clade_counts.select(pl.max("date")).collect().item()
    if max_day is None:
        # no rows (or no non-null dates): there is nothing to rank
        logger.warning("no clade counts available, returning empty clade list")
        return []

    # Weekly windows below start on Sunday, but date.weekday() counts from
    # Monday (Monday == 0 ... Sunday == 6). Shift by one so the offset is the
    # number of days since the most recent Sunday; otherwise the cutoff lands
    # on a Monday and drags in a partial extra week.
    days_since_sunday = (max_day.weekday() + 1) % 7
    threshold_sundays_ago = max_day - timedelta(days=days_since_sunday + 7 * threshold_weeks)

    # sum counts per clade per Sunday-based week, from the cutoff onwards
    lf = (
        clade_counts.filter(pl.col("date") >= threshold_sundays_ago)
        .sort("date")
        .group_by_dynamic("date", every="1w", start_by="sunday", group_by="clade")
        .agg(pl.col("count").sum())
    )

    # create a separate frame with the total counts per week
    total_counts = lf.group_by("date").agg(pl.col("count").sum().alias("total_count"))

    # join with the per-clade data to get each clade's weekly proportion
    prop_dat = lf.join(total_counts, on="date").with_columns(
        (pl.col("count") / pl.col("total_count")).alias("proportion")
    )

    # clades that crossed the threshold in at least one week of the window;
    # sorted because unique() does not promise a stable order
    high_prev_variants = (
        prop_dat.filter(pl.col("proportion") > threshold).select("clade").unique().sort("clade").collect()
    )

    # if more than the specified number of clades cross the threshold,
    # take the ones with the largest counts over the window
    # (if there's a tie, take the first clade alphabetically)
    if len(high_prev_variants) > max_clades:
        candidates = high_prev_variants.get_column("clade").to_list()
        high_prev_variants = (
            # rank only clades that actually crossed the threshold, so a
            # high-volume clade that never qualified cannot displace one that did
            prop_dat.filter(pl.col("clade").is_in(candidates))
            .group_by("clade")
            .agg(pl.col("count").sum())
            .sort("count", "clade", descending=[True, False])
            .collect()
        )

    variants = high_prev_variants.get_column("clade").to_list()[:max_clades]

    end = time.perf_counter()
    elapsed = end - start
    logger.info("generated clade list", elapsed=elapsed)

    return variants
|
||
|
||
# FIXME: provide ability to instantiate Config for the get_clade_list function and get the data_path from there | ||
def main(
    genome_metadata_path: AnyPath = Config.nextstrain_latest_genome_metadata,
    data_dir: AnyPath = AnyPath(".").home() / "covid_variant",
    threshold: float = 0.01,
    threshold_weeks: int = 3,
    max_clades: int = 9,
) -> list[str]:
    """
    Determine list of clades to model

    Creates ``data_dir`` if needed, hands the metadata location to
    ``download_covid_genome_metadata``, then reads, filters and counts the
    result before passing the counts to ``get_clades``.

    Parameters
    ----------
    genome_metadata_path : AnyPath
        Path to location of the most recent genome metadata file published by Nextstrain
    data_dir : AnyPath
        Path to the location where the genome metadata file is saved after download.
    threshold : float
        Proportion cutoff forwarded to ``get_clades``.
    threshold_weeks : int
        The number of weeks that we look back to identify clades.
    max_clades : int
        The maximum number of clades to include in the list.

    Returns
    -------
    list of strings
    """
    # make sure the download target exists before fetching anything
    os.makedirs(data_dir, exist_ok=True)

    # the download helper returns the path that is read in the next step
    local_metadata = download_covid_genome_metadata(genome_metadata_path, data_dir)

    # metadata -> filtered metadata -> clade counts
    filtered_metadata = filter_covid_genome_metadata(get_covid_genome_metadata(local_metadata))
    clade_counts = get_clade_counts(filtered_metadata)

    return get_clades(clade_counts, threshold, threshold_weeks, max_clades)
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
genbank_accession genbank_accession_rev unwanted_column date host country division clade_nextstrain location another unwanted column | ||
abc abc.1 i ❤️ wombats 2024-09-01 Homo sapiens USA Massachusetts AA.ZZ Vulcan hummus a tune | ||
abc abc.1 i ❤️ wombats 2024-09-01 Homo sapiens USA Massachusetts AA.ZZ Vulcan hummus a tune | ||
def def.1 i ❤️ wombats 2024-09-01 Homo sapiens USA Massachusetts AA.ZZ Earth hummus a tune | ||
ghi ghi.4 i ❤️ wombats 2024-09-01 Homo sapiens USA Utah BB Cardassia hummus a tune | ||
jkl jkl.1 i ❤️ wombats 2024-09-01 Homo sapiens USA Utah CC Bajor hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-01 Homo sapiens Canada Alberta DD Vulcan hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-01 marmots USA Massachusetts DD Vulcan hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-01 Homo sapiens USA Puerto Rico DD Reisa hummus a tune | ||
abc abc.1 i ❤️ wombats 2024-09-08 Homo sapiens USA Massachusetts EE Vulcan hummus a tune | ||
abc abc.1 i ❤️ wombats 2024-09-08 Homo sapiens USA Massachusetts EE Vulcan hummus a tune | ||
def def.1 i ❤️ wombats 2024-09-08 Homo sapiens USA Massachusetts DD Earth hummus a tune | ||
ghi ghi.4 i ❤️ wombats 2024-09-08 Homo sapiens USA Utah AA Cardassia hummus a tune | ||
jkl jkl.1 i ❤️ wombats 2024-09-08 Homo sapiens USA Utah AA.ZZ Bajor hummus a tune | ||
abc abc.1 i ❤️ wombats 2024-09-15 Homo sapiens USA Massachusetts AA Vulcan hummus a tune | ||
abc abc.1 i ❤️ wombats 2024-09-15 Homo sapiens USA Massachusetts AA Vulcan hummus a tune | ||
def def.1 i ❤️ wombats 2024-09-15 Homo sapiens USA Massachusetts AA Earth hummus a tune | ||
ghi ghi.4 i ❤️ wombats 2024-09-15 Homo sapiens USA Utah BB Cardassia hummus a tune | ||
jkl jkl.1 i ❤️ wombats 2024-09-15 Homo sapiens USA Utah CC Bajor hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-15 Homo sapiens Canada Mississippi DD Earth hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-15 marmots USA Massachusetts DD Cardassia hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-15 Homo sapiens USA Puerto Rico DD Bajor hummus a tune | ||
abcd abcd.1 i ❤️ wombats 2024-09-22 Homo sapiens USA Massachusetts FF Vulcan hummus a tune | ||
abc abc.1 i ❤️ wombats 2024-09-22 Homo sapiens USA Massachusetts AA Vulcan hummus a tune | ||
def def.1 i ❤️ wombats 2024-09-22 Homo sapiens USA Massachusetts AA Earth hummus a tune | ||
ghi ghi.4 i ❤️ wombats 2024-09-22 Homo sapiens USA Utah BB Cardassia hummus a tune | ||
jkl jkl.1 i ❤️ wombats 2024-09-22 Homo sapiens USA Utah CC Bajor hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-22 Homo sapiens Canada Mississippi FF Earth hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-22 marmots USA Massachusetts FF Cardassia hummus a tune | ||
mno mno.1 i ❤️ wombats 2024-09-22 Homo sapiens USA Puerto Rico FF Bajor hummus a tune |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from pathlib import Path | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
from virus_clade_utils.get_clade_list import main | ||
|
||
|
||
@pytest.fixture
def test_file_path() -> Path:
    """
    Return the ``data`` directory that sits one level above this test module's directory.
    """
    tests_root = Path(__file__).parents[1]
    return tests_root / "data"
|
||
|
||
@pytest.mark.parametrize(
    "threshold, weeks, max_clades, expected_list",
    [
        (0.1, 3, 9, ["AA", "AA.ZZ", "BB", "CC", "DD", "EE", "FF"]),
        (0.3, 3, 9, ["AA", "AA.ZZ", "EE"]),
        (0.1, 2, 9, ["AA.ZZ", "AA", "BB", "CC", "DD", "EE", "FF"]),
        (0.1, 1, 9, ["AA", "BB", "CC", "FF"]),
        (0.3, 1, 9, ["AA"]),
        (0.1, 3, 4, ["AA", "AA.ZZ", "BB", "CC"]),
        (0.3, 3, 2, ["AA", "AA.ZZ"]),
        (0.1, 2, 3, ["AA", "BB", "CC"]),
        (0.1, 1, 3, ["AA", "BB", "CC"]),
        (1, 3, 9, []),
    ],
)
def test_clade_list(test_file_path, tmp_path, threshold, weeks, max_clades, expected_list):
    """Run ``main`` against the bundled metadata file for each threshold / weeks / max_clades combination."""
    metadata_file = test_file_path / "test_metadata.tsv"

    # stand-in for the download step: it just hands back the local test file
    download_stub = MagicMock(return_value=metadata_file, name="genome_metadata_download_mock")
    patch_target = "virus_clade_utils.get_clade_list.download_covid_genome_metadata"

    with patch(patch_target, new=download_stub):
        clades = main(metadata_file, tmp_path, threshold, weeks, max_clades)

    # membership only: the order of the returned clades is not asserted
    assert set(clades) == set(expected_list)