Hungarian Algorithm Infinite Loop with NaN Input #1179

Open
m-sdr opened this issue Jan 21, 2025 · 4 comments

m-sdr commented Jan 21, 2025

The Hungarian algorithm implementation in optax.assignment.hungarian_algorithm enters an infinite loop when the input cost matrix contains only NaN (Not a Number) values.

Reproducible Example:

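```python
import optax
import jax.numpy as jnp

# 3x3 cost matrix in which every entry is NaN (NaN times any number is NaN).
cost_matrix_nan = jnp.nan * jnp.array([
    [1, 2, 4.],
    [4, 5, 6],
    [7, 8, 9]
])

# Per this report, this call never returns.
row_indices, col_indices = optax.assignment.hungarian_algorithm(cost_matrix_nan)
```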
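
Why would NaN input hang the solver at all? Below is a minimal pure-Python sketch of one plausible mechanism. It is an illustration under assumptions, not a trace of optax's implementation, and the helper names `reduce_rows` and `count_zeros` are hypothetical. The textbook Hungarian method subtracts row (and column) minima and then works with the zero entries that result. Any arithmetic involving NaN yields NaN, and `NaN == 0` is False, so an all-NaN matrix never acquires a zero, and a loop that waits for enough zeros to cover can make no progress.

```python
import math

nan = float("nan")


def reduce_rows(cost):
    """Subtract each row's minimum from that row (first Hungarian step)."""
    reduced = []
    for row in cost:
        row_min = min(row)  # for an all-NaN row this is simply NaN
        reduced.append([c - row_min for c in row])
    return reduced


def count_zeros(cost):
    """Count zero entries, which the Hungarian method covers and assigns."""
    return sum(1 for row in cost for c in row if c == 0)


finite = [[1.0, 2.0, 4.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
all_nan = [[nan] * 3 for _ in range(3)]

print(count_zeros(reduce_rows(finite)))   # 3: one zero per row
print(count_zeros(reduce_rows(all_nan)))  # 0: NaN - NaN is NaN, and NaN == 0 is False
print(all(math.isnan(c) for row in reduce_rows(all_nan) for c in row))  # True
```

Note that Python's built-in `min` over an all-NaN row returns NaN rather than raising, so nothing in this arithmetic signals an error; the problem only shows up as a loop that never makes progress.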

mblondel (Collaborator) commented

@carlosgmartin, we should check whether #1140 fixes this issue.

carlosgmartin (Contributor) commented Jan 22, 2025

It's not clear to me that this is a bug, since (to my mind) it doesn't really make sense to ask for the solution of an assignment problem whose cost matrix consists of, or even contains, NaNs. What would the correct output be?

(One can always apply nan_to_num to the input matrix to get rid of its nans before passing it to hungarian_algorithm, since the output won't matter anyway.)

mblondel (Collaborator) commented

Still, I think we should not let the algorithm enter an infinite loop. Ideally we would report errors by returning a status variable (succeeded or failed), but that would change the function signature.

> What would be the correct output?

Maybe the identity permutation?

An alternative would be to add the nan_to_num call inside hungarian_algorithm instead.

carlosgmartin (Contributor) commented

Actually, it looks like the newer version of the algorithm in #1140 doesn't loop forever on NaNs, so this should be fixed once that is merged.
