
Commit: Benchmark v2 (#155)
Co-authored-by: Paweł Czyż <[email protected]>
grfrederic and pawel-czyz authored Apr 18, 2024
1 parent 1dd009f commit 44fc87e
Showing 14 changed files with 588 additions and 31 deletions.
.flake8 (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [flake8]
 max-line-length = 99
-ignore = W503
+ignore = W503,E202,E251
 exclude =
     .git,
     __pycache__,
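For context: E202 flags whitespace before a closing bracket and E251 flags spaces around a keyword equals; on newer Python tokenizers these checks also fire inside f-strings such as f"C{i + 1}", which is presumably why they are ignored alongside the f-string reformatting later in this commit.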
scripts/Beyond_Normal/figures/student/correction.py (6 changes: 3 additions & 3 deletions)
@@ -11,9 +11,9 @@
 from typing import cast

 import matplotlib.pyplot as plt
+from subplots_from_axsize import subplots_from_axsize

-import bmi.api as bmi
-from bmi.plot_utils.subplots_from_axsize import subplots_from_axsize
+import bmi


 def correction(df: int, m: int, n: int) -> float:
@@ -43,7 +43,7 @@ def main() -> None:
     for i, (m, n) in enumerate(mns):
         values = [correction(df=df, m=m, n=n) for df in nus]
         # ax.scatter(nus, values, s=4, marker=MARKER_LIST[i], c=f"C{i+1}")
-        ax.plot(nus, values, c=f"C{i+1}", label=f"$m={m}$,\t$n={n}$")
+        ax.plot(nus, values, c=f"C{i + 1}", label=f"$m={m}$,\t$n={n}$")

     ax.spines[["right", "top"]].set_visible(False)
     ax.set_xlabel("Degrees of freedom")
src/bmi/benchmark/task.py (14 changes: 10 additions & 4 deletions)
@@ -29,13 +29,19 @@ def __init__(
         task_params: Optional[dict] = None,
     ):
         self.sampler = sampler
-        self.metadata = TaskMetadata(
-            task_id=task_id,
-            task_name=task_name,
+        self._task_id = task_id
+        self._task_name = task_name
+        self._task_params = task_params
+
+    @property
+    def metadata(self) -> TaskMetadata:
+        return TaskMetadata(
+            task_id=self._task_id,
+            task_name=self._task_name,
             dim_x=self.sampler.dim_x,
             dim_y=self.sampler.dim_y,
             mi_true=self.sampler.mutual_information(),
-            task_params=task_params or dict(),
+            task_params=self._task_params or dict(),
         )

     @property
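With metadata now a property, the TaskMetadata object, and in particular the potentially expensive sampler.mutual_information() call behind mi_true, is built on access rather than eagerly in __init__. A minimal sketch of the new behaviour; the SplitMultinormal sampler and its arguments are illustrative assumptions, not part of this commit:

import bmi.samplers as samplers
from bmi.benchmark.task import Task

# Any sampler exposing dim_x, dim_y and mutual_information() works here.
sampler = samplers.SplitMultinormal(
    dim_x=1, dim_y=1, covariance=samplers.canonical_correlation([0.8])
)
task = Task(sampler=sampler, task_id="demo", task_name="Demo")  # cheap: no MI computed yet
meta = task.metadata  # TaskMetadata (and mi_true) is computed here, on each access
print(meta.mi_true)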
src/bmi/benchmark/tasks/__init__.py (15 changes: 15 additions & 0 deletions)
@@ -12,6 +12,15 @@
     task_multinormal_lvm,
 )

+from bmi.benchmark.tasks.mixtures import (
+    task_x,
+    task_ai,
+    task_waves,
+    task_galaxy,
+    task_concentric_multinormal,
+    task_multinormal_sparse_w_inliers,
+)
+
 # isort: on
 from bmi.benchmark.tasks.normal_cdf import transform_normal_cdf_task
 from bmi.benchmark.tasks.rotate import transform_rotate_task
@@ -44,6 +53,12 @@
     "task_multinormal_dense",
     "task_multinormal_sparse",
     "task_multinormal_2pair",
+    "task_x",
+    "task_ai",
+    "task_waves",
+    "task_galaxy",
+    "task_concentric_multinormal",
+    "task_multinormal_sparse_w_inliers",
     "task_student_dense",
     "task_student_sparse",
     "task_student_2pair",
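All six mixture tasks are re-exported from bmi.benchmark.tasks, so downstream code can import them without touching the mixtures module directly. A quick smoke test of the new re-exports (a sketch, assuming the package is installed):

from bmi.benchmark.tasks import (
    task_x,
    task_ai,
    task_waves,
    task_galaxy,
    task_concentric_multinormal,
    task_multinormal_sparse_w_inliers,
)

print(task_x().metadata.task_id)  # "1v1-X-0.9" with the default correlation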
src/bmi/benchmark/tasks/mixtures.py (new file: 269 additions & 0 deletions)
@@ -0,0 +1,269 @@
import jax.numpy as jnp
import numpy as np

import bmi.samplers as samplers
import bmi.transforms as transforms
from bmi.benchmark.task import Task
from bmi.samplers import fine

_MC_MI_ESTIMATE_SAMPLE = 100_000


def task_x(
    gaussian_correlation=0.9,
    mi_estimate_sample=_MC_MI_ESTIMATE_SAMPLE,
) -> Task:
    """The X distribution."""

    # Two equally weighted bivariate Gaussians with correlations
    # +gaussian_correlation and -gaussian_correlation; overlaid,
    # their densities form an X shape.
    dist = fine.mixture(
        proportions=jnp.array([0.5, 0.5]),
        components=[
            fine.MultivariateNormalDistribution(
                covariance=samplers.canonical_correlation([x * gaussian_correlation]),
                mean=jnp.zeros(2),
                dim_x=1,
                dim_y=1,
            )
            for x in [-1, 1]
        ],
    )
    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

    return Task(
        sampler=sampler,
        task_id=f"1v1-X-{gaussian_correlation}",
        task_name="X 1 × 1",
        task_params={
            "gaussian_correlation": gaussian_correlation,
        },
    )


def task_ai(
    mi_estimate_sample=_MC_MI_ESTIMATE_SAMPLE,
) -> Task:
    """The AI distribution."""

    corr = 0.95
    var_x = 0.04

    dist = fine.mixture(
        proportions=jnp.full(6, fill_value=1 / 6),
        components=[
            # I components
            fine.MultivariateNormalDistribution(
                dim_x=1,
                dim_y=1,
                mean=jnp.array([1.0, 0.0]),
                covariance=np.diag([0.01, 0.2]),
            ),
            fine.MultivariateNormalDistribution(
                dim_x=1,
                dim_y=1,
                mean=jnp.array([1.0, 1]),
                covariance=np.diag([0.05, 0.001]),
            ),
            fine.MultivariateNormalDistribution(
                dim_x=1,
                dim_y=1,
                mean=jnp.array([1.0, -1]),
                covariance=np.diag([0.05, 0.001]),
            ),
            # A components
            fine.MultivariateNormalDistribution(
                dim_x=1,
                dim_y=1,
                mean=jnp.array([-0.8, -0.2]),
                covariance=np.diag([0.03, 0.001]),
            ),
            fine.MultivariateNormalDistribution(
                dim_x=1,
                dim_y=1,
                mean=jnp.array([-1.2, 0.0]),
                covariance=jnp.array(
                    [[var_x, jnp.sqrt(var_x * 0.2) * corr], [jnp.sqrt(var_x * 0.2) * corr, 0.2]]
                ),
            ),
            fine.MultivariateNormalDistribution(
                dim_x=1,
                dim_y=1,
                mean=jnp.array([-0.4, 0.0]),
                covariance=jnp.array(
                    [[var_x, -jnp.sqrt(var_x * 0.2) * corr], [-jnp.sqrt(var_x * 0.2) * corr, 0.2]]
                ),
            ),
        ],
    )
    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

    return Task(
        sampler=sampler,
        task_id="1v1-AI",
        task_name="AI 1 × 1",
    )


def task_galaxy(
    speed=0.5,
    distance=3.0,
    mi_estimate_sample=_MC_MI_ESTIMATE_SAMPLE,
) -> Task:
    """The Galaxy distribution."""

    balls_mixt = fine.mixture(
        proportions=jnp.array([0.5, 0.5]),
        components=[
            fine.MultivariateNormalDistribution(
                covariance=samplers.canonical_correlation([0.0], additional_y=1),
                mean=jnp.array([x, x, x]) * distance / 2,
                dim_x=2,
                dim_y=1,
            )
            for x in [-1, 1]
        ],
    )

    base_sampler = fine.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
    # Skew-symmetric generator of 2D rotations for the Spiral transform,
    # which twists the X coordinates at a rate controlled by `speed`.
    a = jnp.array([[0, -1], [1, 0]])
    spiral = transforms.Spiral(a, speed=speed)

    sampler = samplers.TransformedSampler(base_sampler, transform_x=spiral)

    return Task(
        sampler=sampler,
        task_id=f"2v1-galaxy-{speed}-{distance}",
        task_name="Galaxy 2 × 1",
        task_params={
            "speed": speed,
            "distance": distance,
        },
    )


def task_waves(
    n_components=12,
    wave_amplitude=5.0,
    wave_frequency=3.0,
    mi_estimate_sample=_MC_MI_ESTIMATE_SAMPLE,
) -> Task:
    """The Waves distribution."""

    assert n_components > 0

    base_dist = fine.mixture(
        proportions=jnp.ones(n_components) / n_components,
        components=[
            fine.MultivariateNormalDistribution(
                covariance=jnp.diag(jnp.array([0.1, 1.0, 0.1])),
                mean=jnp.array([x, 0, x % 4]) * 1.5,
                dim_x=2,
                dim_y=1,
            )
            for x in range(n_components)
        ],
    )
    base_sampler = fine.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
    # Shift each X point sinusoidally along its first axis (driven by the
    # second coordinate), then rescale and recentre.
    aux_sampler = samplers.TransformedSampler(
        base_sampler,
        transform_x=lambda x: x
        + jnp.array([wave_amplitude, 0.0]) * jnp.sin(wave_frequency * x[1]),
    )
    sampler = samplers.TransformedSampler(
        aux_sampler, transform_x=lambda x: jnp.array([0.1 * x[0] - 0.8, 0.5 * x[1]])
    )

    return Task(
        sampler=sampler,
        task_id=f"2v1-waves-{n_components}-{wave_amplitude}-{wave_frequency}",
        task_name="Waves 2 × 1",
        task_params={
            "n_components": n_components,
            "wave_amplitude": wave_amplitude,
            "wave_frequency": wave_frequency,
        },
    )


def task_concentric_multinormal(
    dim_x,
    n_components=3,
    mi_estimate_sample=_MC_MI_ESTIMATE_SAMPLE,
) -> Task:
    """Isotropic Gaussians with varying standard deviation."""

    assert n_components > 0

    dist = fine.mixture(
        proportions=jnp.ones(n_components) / n_components,
        components=[
            # Component i: X is isotropic with variance i**2 around the
            # origin, while Y is tightly concentrated at the value i.
            fine.MultivariateNormalDistribution(
                covariance=jnp.diag(jnp.array(dim_x * [i**2] + [0.0001])),
                mean=jnp.array(dim_x * [0.0] + [1.0 * i]),
                dim_x=dim_x,
                dim_y=1,
            )
            for i in range(1, 1 + n_components)
        ],
    )
    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

    return Task(
        sampler=sampler,
        task_id=f"{dim_x}v1-concentric_gaussians-{n_components}",
        task_name=f"Concentric {dim_x} × 1",
        task_params={
            "n_components": n_components,
        },
    )


def task_multinormal_sparse_w_inliers(
    dim_x,
    dim_y,
    n_interacting: int = 2,
    strength: float = 2.0,
    inlier_fraction: float = 0.2,
    mi_estimate_sample=_MC_MI_ESTIMATE_SAMPLE,
) -> Task:
    """Sparse multinormal distribution contaminated with a fraction of
    'inlier' noise drawn from the product of its marginals, i.e. points
    with X and Y independent."""

    assert 0.0 <= inlier_fraction <= 1.0

    params = samplers.GaussianLVMParametrization(
        dim_x=dim_x,
        dim_y=dim_y,
        n_interacting=n_interacting,
        alpha=0.0,
        lambd=strength,
        beta_x=0.0,
        eta_x=strength,
    )

    signal_dist = fine.MultivariateNormalDistribution(
        dim_x=dim_x,
        dim_y=dim_y,
        covariance=params.correlation,
    )

    # Product of the marginals: same X and Y distributions, but independent.
    noise_dist = fine.ProductDistribution(
        dist_x=signal_dist.dist_x,
        dist_y=signal_dist.dist_y,
    )

    dist = fine.mixture(
        proportions=jnp.array([1 - inlier_fraction, inlier_fraction]),
        components=[signal_dist, noise_dist],
    )

    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)

    task_id = f"mult-sparse-w-inliers-{dim_x}-{dim_y}-{n_interacting}-{strength}-{inlier_fraction}"
    return Task(
        sampler=sampler,
        task_id=task_id,
        task_name=f"Multinormal {dim_x} × {dim_y} (sparse, {inlier_fraction:.0%} inliers)",
        task_params={
            "n_interacting": n_interacting,
            "strength": strength,
            "inlier_fraction": inlier_fraction,
        },
    )
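Every helper above wraps its distribution in a FineSampler, so the ground-truth MI reported in metadata.mi_true is itself a Monte Carlo estimate over mi_estimate_sample points rather than a closed-form value. A usage sketch; the sample(n_points, seed) call assumes the Task interface used elsewhere in this repository:

from bmi.benchmark.tasks.mixtures import task_galaxy, task_waves

for task in [task_galaxy(), task_waves()]:
    xs, ys = task.sample(5_000, seed=0)  # assumed Task.sample(n_points, seed) signature
    print(task.metadata.task_id, xs.shape, ys.shape, task.metadata.mi_true)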
src/bmi/samplers/_matrix_utils.py (18 changes: 10 additions & 8 deletions)
@@ -1,7 +1,7 @@
 """Utilities for creating dispersion matrices."""

 import dataclasses
-from typing import Optional
+from typing import List, Optional, Union

 import numpy as np

@@ -73,7 +73,9 @@ def parametrised_correlation_matrix(
     return corr_matrix


-def canonical_correlation(rho: np.ndarray, additional_y: int = 0) -> np.ndarray:
+def canonical_correlation(
+    rho: Union[np.ndarray, List[float]], additional_y: int = 0
+) -> np.ndarray:
     """Constructs a covariance matrix given by canonical correlations.

     Namely,
@@ -327,17 +329,17 @@ def correlation(self) -> np.ndarray:
     def latent_variable_labels(self) -> list[str]:
         return (
             ["$U_\\mathrm{all}$", "$U_X$", "$U_Y$"]
-            + [f"$Z_{i+1}$" for i in range(self.n_interacting)]
-            + [f"$E_{i+1}$" for i in range(self.dim_x)]
-            + [f"$F_{i+1}$" for i in range(self.dim_y)]
-            + [f"$V_{i+1}$" for i in range(self.n_interacting, self.dim_x)]
+            + [f"$Z_{i + 1}$" for i in range(self.n_interacting)]
+            + [f"$E_{i + 1}$" for i in range(self.dim_x)]
+            + [f"$F_{i + 1}$" for i in range(self.dim_y)]
+            + [f"$V_{i + 1}$" for i in range(self.n_interacting, self.dim_x)]
             + [f"$W_{i + 1}$" for i in range(self.n_interacting, self.dim_y)]
         )

     @property
     def xy_labels(self) -> list[str]:
-        return [f"$X_{i+1}$" for i in range(self.dim_x)] + [
-            f"$Y_{j+1}$" for j in range(self.dim_y)
+        return [f"$X_{i + 1}$" for i in range(self.dim_x)] + [
+            f"$Y_{j + 1}$" for j in range(self.dim_y)
         ]
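The widened annotation matches how canonical_correlation is called from the new mixtures.py, which passes plain Python lists such as [0.0]. A small sketch of the two call styles the signature now covers:

import numpy as np

import bmi.samplers as samplers

# Both call styles produce the same 2 x 2 correlation matrix.
cov_from_list = samplers.canonical_correlation([0.8])  # plain list, as in mixtures.py
cov_from_array = samplers.canonical_correlation(np.array([0.8]))  # original ndarray style
assert np.allclose(cov_from_list, cov_from_array)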