
Initial SMS-AF and CT-AF push
jmcasebeer committed Oct 4, 2023
1 parent 56c4665 commit a61aa9d
Showing 41 changed files with 5,521 additions and 116 deletions.
2 changes: 1 addition & 1 deletion metaaf/callbacks.py
@@ -194,7 +194,7 @@ def on_val_batch_end(self, out, aux, data_batch, cur_batch, cur_epoch):
        base_name = os.path.join(epoch_dir, f"{self.num_logged}")
        sf.write(f"{base_name}_out.wav", np.array(out[batch_idx, :, 0]), self.fs)

-        for (k, v) in data_batch["signals"].items():
+        for k, v in data_batch["signals"].items():
            sf.write(f"{base_name}_{k}.wav", np.array(v[batch_idx, :, 0]), self.fs)

        batch_idx += 1
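For reference, the hunk above writes one wav file per named signal in the validation batch. A self-contained sketch of the same pattern; the signal names, shapes, sample rate, and base_name below are hypothetical, not taken from this commit:

import numpy as np
import soundfile as sf

# hypothetical (batch, time, channels) arrays keyed by signal name
signals = {"u": np.zeros((1, 16000, 1)), "d": np.zeros((1, 16000, 1))}
base_name, fs, batch_idx = "0", 16000, 0
for k, v in signals.items():
    # writes 0_u.wav and 0_d.wav alongside the model output wav
    sf.write(f"{base_name}_{k}.wav", np.array(v[batch_idx, :, 0]), fs)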
19 changes: 0 additions & 19 deletions metaaf/complex_groupnorm.py

This file was deleted.

17 changes: 16 additions & 1 deletion metaaf/complex_gru.py
@@ -28,6 +28,7 @@
    complex_sigmoid,
    complex_tanh,
)
+from metaaf.complex_norm import CLNorm
import types
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
@@ -56,6 +57,7 @@ def __init__(
        w_i_init: Optional[hk.initializers.Initializer] = None,
        w_h_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
+        use_norm=False,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
@@ -65,6 +67,12 @@ def __init__(
        self.b_init = b_init or complex_zeros
        self.sig = complex_sigmoid

+        self.use_norm = use_norm
+        if self.use_norm:
+            self.i_norm = CLNorm(axis=-1, create_scale=True, create_offset=True)
+            self.zrh_norm = CLNorm(axis=-1, create_scale=True, create_offset=True)
+            self.ah_norm = CLNorm(axis=-1, create_scale=True, create_offset=True)
+
    def __call__(self, inputs, state):
        if inputs.ndim not in (1, 2):
            raise ValueError("GRU input must be rank-1 or rank-2.")
@@ -82,15 +90,22 @@ def __call__(self, inputs, state):
        b_z, b_a = jnp.split(b, indices_or_sections=[2 * hidden_size], axis=0)

        gates_x = jnp.matmul(inputs, w_i)
+        if self.use_norm:
+            gates_x = self.i_norm(gates_x)

        zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * hidden_size], axis=-1)
        zr_h = jnp.matmul(state, w_h_z)
+        if self.use_norm:
+            zr_h = self.zrh_norm(zr_h)

        zr = zr_x + zr_h + jnp.broadcast_to(b_z, zr_h.shape)
        z, r = jnp.split(self.sig(zr), indices_or_sections=2, axis=-1)

        a_h = jnp.matmul(r * state, w_h_a)

+        if self.use_norm:
+            a_h = self.ah_norm(a_h)
+
        a = complex_tanh(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape))

        next_state = (1 - z) * state + z * a
@@ -120,7 +135,7 @@ def make_deep_initial_state(params, **kwargs):
    n_layers = kwargs["n_layers"]

    def single_layer_initial_state():
-        state = jnp.zeros([h_size], dtype=np.dtype("complex64"))
+        state = jnp.zeros([h_size], params.dtype)  # dtype=np.dtype("complex64"))
        state = add_batch(state, b_size)
        return state
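Taken together, the hunks above thread an optional layer normalization through the complex GRU: one CLNorm on the input projection, one on the z/r recurrent projection, and one on the candidate projection. A minimal sketch of exercising the new use_norm flag, assuming the cell defined in complex_gru.py is named CGRU, takes hidden_size as its first argument, and follows the hk.RNNCore convention of returning (output, next_state); those details are assumptions, not confirmed by this diff:

import haiku as hk
import jax
import jax.numpy as jnp
from metaaf.complex_gru import CGRU  # assumed class name

def step(x, h):
    # use_norm=True routes the gate pre-activations through the CLNorm modules
    cell = CGRU(hidden_size=32, use_norm=True)
    return cell(x, h)

f = hk.without_apply_rng(hk.transform(step))
x = jnp.ones((4, 32), dtype=jnp.complex64)   # (batch, features) input
h = jnp.zeros((4, 32), dtype=jnp.complex64)  # initial hidden state
params = f.init(jax.random.PRNGKey(0), x, h)
out, h_next = f.apply(params, x, h)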
78 changes: 78 additions & 0 deletions metaaf/complex_norm.py
@@ -0,0 +1,78 @@
import jax.numpy as jnp
import haiku as hk


class CGN(hk.Module):
    def __init__(
        self, groups=6, create_scale=True, create_offset=True, eps=1e-5, name=None
    ):
        super().__init__(name=name)
        self.gn_real = hk.GroupNorm(
            groups=groups,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
        )
        self.gn_imag = hk.GroupNorm(
            groups=groups,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
        )

    def __call__(self, x):
        x_real = jnp.real(x)
        x_imag = jnp.imag(x)

        x_real_n = self.gn_real(x_real)
        x_imag_n = self.gn_imag(x_imag)

        return (x_real_n + 1j * x_imag_n) / jnp.sqrt(2)


class CLNorm(hk.Module):
    def __init__(
        self,
        axis,
        create_scale,
        create_offset,
        eps=1e-05,
        scale_init=None,
        offset_init=None,
        use_fast_variance=False,
        name=None,
        param_axis=None,
    ):
        super().__init__(name=name)
        self.n_real = hk.LayerNorm(
            axis=axis,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
            scale_init=scale_init,
            offset_init=offset_init,
            use_fast_variance=use_fast_variance,
            name=name,
            param_axis=param_axis,
        )

        self.n_imag = hk.LayerNorm(
            axis=axis,
            create_scale=create_scale,
            create_offset=create_offset,
            eps=eps,
            scale_init=scale_init,
            offset_init=offset_init,
            use_fast_variance=use_fast_variance,
            name=name,
            param_axis=param_axis,
        )

    def __call__(self, x):
        x_real = jnp.real(x)
        x_imag = jnp.imag(x)

        x_real_n = self.n_real(x_real)
        x_imag_n = self.n_imag(x_imag)

        return (x_real_n + 1j * x_imag_n) / jnp.sqrt(2)
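Both modules normalize the real and imaginary parts independently and rescale by 1/sqrt(2), so a complex activation with unit-variance real and imaginary parts comes out with roughly unit complex variance; CGN is the hk.GroupNorm counterpart of CLNorm. A minimal sketch of applying CLNorm to a complex activation (shapes and values are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp
from metaaf.complex_norm import CLNorm

def norm_fn(x):
    return CLNorm(axis=-1, create_scale=True, create_offset=True)(x)

f = hk.without_apply_rng(hk.transform(norm_fn))
x = (3 + 4j) * jnp.ones((2, 8), dtype=jnp.complex64)  # complex activations
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, x)  # real and imaginary parts layer-normalized separately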
23 changes: 20 additions & 3 deletions metaaf/complex_utils.py
@@ -3,16 +3,33 @@
import haiku as hk


-def complex_zeros(shape, _):
-    return jnp.zeros(shape, dtype=jnp.complex64)
+def complex_zeros(shape, dtype):
+    return jnp.zeros(shape, dtype=dtype) + 1j * jnp.zeros(shape, dtype=dtype)


# see https://openreview.net/attachment?id=H1T2hmZAb&name=pdf
def complex_variance_scaling(shape, dtype):
    real = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)
    imag = hk.initializers.VarianceScaling()(shape, dtype=jnp.float32)

-    mag = jnp.sqrt(real ** 2 + imag ** 2)
+    mag = jnp.sqrt(real**2 + imag**2)
    angle = hk.initializers.RandomUniform(minval=-jnp.pi, maxval=jnp.pi)(
        shape, dtype=jnp.float32
    )

    return mag * jnp.exp(1j * angle)


+def complex_xavier(shape, dtype):
+    real = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")(
+        shape, dtype=jnp.float32
+    )
+    imag = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")(
+        shape, dtype=jnp.float32
+    )
+
+    eps = 1e-11
+    mag = jnp.sqrt(real**2 + imag**2 + eps)
+    angle = hk.initializers.RandomUniform(minval=-jnp.pi, maxval=jnp.pi)(
+        shape, dtype=jnp.float32
+    )
(diff truncated here; the remaining changed files did not load and are not shown)
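The initializers in complex_utils.py follow the Haiku initializer signature init(shape, dtype), so they drop into hk.get_parameter. A minimal sketch with a hypothetical ComplexLinear module; the module name and sizes are illustrative assumptions, not code from this commit:

import haiku as hk
import jax
import jax.numpy as jnp
from metaaf.complex_utils import complex_zeros, complex_variance_scaling

class ComplexLinear(hk.Module):  # hypothetical module, for illustration only
    def __call__(self, x):
        w = hk.get_parameter(
            "w", [x.shape[-1], 16], jnp.complex64, init=complex_variance_scaling
        )
        b = hk.get_parameter("b", [16], jnp.complex64, init=complex_zeros)
        return x @ w + b

f = hk.without_apply_rng(hk.transform(lambda x: ComplexLinear()(x)))
x = jnp.ones((4, 8), dtype=jnp.complex64)
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, x)  # (4, 16) complex output

complex_variance_scaling draws real and imaginary Gaussians, converts them to a magnitude, and samples the phase uniformly from [-pi, pi), following the scheme in the paper linked in the comment above.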
