Skip to content

Commit

Permalink
Add get_config to BruteForceRetrieval layer.
Browse files Browse the repository at this point in the history
For consistency with other layers. Note that this does not serialize candidates.
  • Loading branch information
hertschuh committed Jan 27, 2025
1 parent 163707b commit 41826cd
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions keras_rs/src/layers/retrieval/brute_force_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class BruteForceRetrieval(keras.layers.Layer):
The identifiers for the candidates can be specified as a tensor. If not
provided, the IDs used are simply the candidate indices.
Note that the serialization of this layer does not preserve the candidates
and only saves the `k` and `return_scores` arguments. One has to call
`update_candidates` after deserializing the layers.
Args:
candidate_embeddings: The candidate embeddings. If `None`,
candidates must be provided using `update_candidates` before
Expand Down Expand Up @@ -183,3 +187,13 @@ def compute_score(
return keras.ops.matmul(
query_embedding, keras.ops.transpose(candidate_embedding)
)

def get_config(self) -> dict[str, Any]:
config: dict[str, Any] = super().get_config()
config.update(
{
"k": self.k,
"return_scores": self.compute_score,
}
)
return config

0 comments on commit 41826cd

Please sign in to comment.