From 1ba1c75c5e0a283ca57aae0e71a883bfc5594e56 Mon Sep 17 00:00:00 2001 From: Louis de Bruijn Date: Wed, 22 Jun 2022 10:19:46 +0200 Subject: [PATCH 1/2] new warning messages --- spark_matcher/matching_base/matching_base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spark_matcher/matching_base/matching_base.py b/spark_matcher/matching_base/matching_base.py index 0d1e518..8b2b1e3 100644 --- a/spark_matcher/matching_base/matching_base.py +++ b/spark_matcher/matching_base/matching_base.py @@ -28,15 +28,18 @@ def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[Tab blocking_recall: float = 1.0, n_perfect_train_matches=1, n_train_samples: int = 100_000, ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0): self.spark_session = spark_session - self.table_checkpointer = table_checkpointer if not self.table_checkpointer and checkpoint_dir: self.table_checkpointer = ParquetCheckPointer(self.spark_session, checkpoint_dir, "checkpoint_deduplicator") + elif table_checkpointer and not checkpoint_dir: + self.table_checkpointer = table_checkpointer else: warnings.warn( 'Either `table_checkpointer` or `checkpoint_dir` should be provided. This instance can only be used ' 'when loading a previously saved instance.') - if col_names: + if col_names and field_info: + raise ValueError("Either `col_names` or `field_info` should be provided.") + elif col_names: self.col_names = col_names self.field_info = {col_name: [token_set_ratio, token_sort_ratio] for col_name in self.col_names} From b95e2b5544816a9244723468d9eb1d57aa2bc74d Mon Sep 17 00:00:00 2001 From: Louis de Bruijn Date: Tue, 28 Jun 2022 14:31:50 +0200 Subject: [PATCH 2/2] code review changes --- spark_matcher/matching_base/matching_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark_matcher/matching_base/matching_base.py b/spark_matcher/matching_base/matching_base.py index 8b2b1e3..5ea1fa1 100644 --- a/spark_matcher/matching_base/matching_base.py +++ b/spark_matcher/matching_base/matching_base.py @@ -30,7 +30,7 @@ def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[Tab self.spark_session = spark_session if not self.table_checkpointer and checkpoint_dir: self.table_checkpointer = ParquetCheckPointer(self.spark_session, checkpoint_dir, "checkpoint_deduplicator") - elif table_checkpointer and not checkpoint_dir: + elif table_checkpointer: self.table_checkpointer = table_checkpointer else: warnings.warn(