Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmann1123 committed Jan 29, 2024
1 parent 2d45a6b commit 2b5c628
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions xr_fresh/feature_calculator_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,12 @@
# potential_predictability
# spearman_correlation
# varweighted_mean_period
# ratio_value_number_to_time_series_length


def _get_jax_backend():
# Get a list of available GPU devices
gpu_devices = [device for device in jax.devices() if device.device_kind == "GPU"]
print("jax running on : ", gpu_devices)
# If there are GPU devices, use 'gpu' as the backend; otherwise, use 'cpu'
return "gpu" if gpu_devices else "cpu"
from jax.lib import xla_bridge


# Set JAX to use the determined backend
jax_backend = _get_jax_backend()
jax.config.update("jax_platform_name", jax_backend)
print(f"Jax is running on: {xla_bridge.get_backend().platform}")


# Define a function to apply strftime('%j') to each element
Expand All @@ -74,8 +67,8 @@ def _get_day_of_year(dt):

def _check_valid_array(obj):
# Check if the object is a NumPy or JAX array or list
if not isinstance(obj, (np.ndarray, jnp.DeviceArray, list)):
raise TypeError("Object must be a NumPy, JAX array or list.")
if not isinstance(obj, (np.ndarray, list)): # jnp.DeviceArray,
raise TypeError("Object must be a NumPy array or list.")

# convert lists to numpy array
if isinstance(obj, list):
Expand Down Expand Up @@ -363,7 +356,7 @@ def _count_longest_consecutive(values):
for value in values:
if value:
current_count += 1
max_count = max(max_count, current_count)
max_count = jnp.nanmax(max_count, current_count)
else:
current_count = 0

Expand Down Expand Up @@ -605,7 +598,7 @@ class ols_slope_intercept(gw.TimeModule):
def __init__(self, returns="slope"):
super(ols_slope_intercept, self).__init__()

allowed_values = ["slope", "intercept", "rsquared", "all"]
allowed_values = ["slope", "intercept", "rsquared"]
self.returns = returns

if self.returns not in allowed_values:
Expand All @@ -622,7 +615,7 @@ def calculate(self, array):
return intercept.squeeze()
elif self.returns == "rsquared":
array, SSR = jnp.apply_along_axis(_lstsq, axis=0, arr=array)
x_mean = jnp.nanmean(array, axis=0)
# x_mean = jnp.nanmean(array, axis=0)
y = jnp.arange(0, array.shape[0])
TSS = jnp.nansum((y - jnp.nanmean(y)) ** 2)

Expand Down Expand Up @@ -680,14 +673,15 @@ def __init__(self, r=2):
self.r = r

def calculate(self, array):
return (
out = (
jnp.nansum(
jnp.abs(array - jnp.nanmean(array, axis=0))
> self.r * jnp.nanstd(array, axis=0),
axis=0,
)
/ len(array)
).squeeze()
return jnp.where(jnp.isnan(out), 0, out)


class skewness(gw.TimeModule):
Expand Down Expand Up @@ -763,10 +757,11 @@ def __init__(self, r=0.1):
self.r = r

def calculate(self, array):
return (
out = (
jnp.abs(jnp.nanmean(array, axis=0) - jnp.nanmedian(array, axis=0))
< (self.r * (jnp.nanmax(array, axis=0) - jnp.nanmin(array, axis=0)))
).squeeze()
return jnp.where(jnp.isnan(out), 0, out)


class ts_complexity_cid_ce(gw.TimeModule):
Expand Down Expand Up @@ -862,7 +857,9 @@ def __init__(self):
super(variance_larger_than_standard_deviation, self).__init__()

def calculate(self, x):
return (jnp.nanvar(x, axis=0) > jnp.nanstd(x, axis=0)).astype(np.int8).squeeze()
out = (jnp.nanvar(x, axis=0) > jnp.nanstd(x, axis=0)).astype(np.int8).squeeze()

return jnp.where(jnp.isnan(out), 0, out)


function_mapping = {
Expand Down

0 comments on commit 2b5c628

Please sign in to comment.