From 41826cd42cd1ee23b83add9bff3a7ec68320b8b4 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Mon, 27 Jan 2025 10:30:41 -0800 Subject: [PATCH] Add `get_config` to `BruteForceRetrieval` layer. For consistency with other layers. Note that this does not serialize candidates. --- .../src/layers/retrieval/brute_force_retrieval.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/keras_rs/src/layers/retrieval/brute_force_retrieval.py b/keras_rs/src/layers/retrieval/brute_force_retrieval.py index 658a8fb..4a0c5b5 100644 --- a/keras_rs/src/layers/retrieval/brute_force_retrieval.py +++ b/keras_rs/src/layers/retrieval/brute_force_retrieval.py @@ -22,6 +22,10 @@ class BruteForceRetrieval(keras.layers.Layer): The identifiers for the candidates can be specified as a tensor. If not provided, the IDs used are simply the candidate indices. + Note that the serialization of this layer does not preserve the candidates + and only saves the `k` and `return_scores` arguments. One has to call + `update_candidates` after deserializing the layers. + Args: candidate_embeddings: The candidate embeddings. If `None`, candidates must be provided using `update_candidates` before @@ -183,3 +187,13 @@ def compute_score( return keras.ops.matmul( query_embedding, keras.ops.transpose(candidate_embedding) ) + + def get_config(self) -> dict[str, Any]: + config: dict[str, Any] = super().get_config() + config.update( + { + "k": self.k, + "return_scores": self.compute_score, + } + ) + return config