diff --git a/jsl/lds/kalman_filter.py b/jsl/lds/kalman_filter.py index 81c6ad0..4982ec3 100644 --- a/jsl/lds/kalman_filter.py +++ b/jsl/lds/kalman_filter.py @@ -2,7 +2,7 @@ # Author: Gerardo Durán-Martín (@gerdm), Aleyna Kara(@karalleyna) from jax import config -config.update('jax_default_matmul_precision', 'float32') +config.update("jax_default_matmul_precision", "float32") import chex import jax.numpy as jnp @@ -158,7 +158,9 @@ def kalman_step(state, obs, params): Sigma_cond = A @ Sigma @ A.T + Q # \mu_{t |t-1} and xn|{n-1} - mu_cond = A @ mu + + mu_cond = A @ mu + mu_cond = mu_cond + params.get_state_offset_of(t) Ct = params.get_obs_mat_of(t) R = params.get_observation_noise_of(t) @@ -166,7 +168,9 @@ def kalman_step(state, obs, params): St = Ct @ Sigma_cond @ Ct.T + R Kt = solve(St, Ct @ Sigma_cond, sym_pos=True).T - mu = mu_cond + Kt @ (obs - Ct @ mu_cond) + innovation = Ct @ mu_cond + innovation = innovation + params.get_obs_offset_of(t) + mu = mu_cond + Kt @ (obs - innovation) # More stable solution is (I − KtCt)Σt|t−1(I − KtCt)T + KtRtKTt tmp = (I - Kt @ Ct)