diff --git a/workflows/projects/Mixtures/cool_tasks.smk b/workflows/projects/Mixtures/cool_tasks.smk
index 9488edab..3abd6804 100644
--- a/workflows/projects/Mixtures/cool_tasks.smk
+++ b/workflows/projects/Mixtures/cool_tasks.smk
@@ -32,7 +32,25 @@ ESTIMATOR_NAMES = {
     "Hist-10": "Histogram",
     "CCA": "CCA",
 }
+ESTIMATOR_COLORS = {
+    "MINE": '#377eb8',
+    "InfoNCE": '#ff7f00',
+    "KSG-10": '#4daf4a',
+    "Hist-10": '#f781bf',
+    "CCA": '#a65628',
+}
+
+ESTIMATOR_MARKERS = {
+    "MINE": 'o',
+    "InfoNCE": 'v',
+    "KSG-10": '^',
+    "Hist-10": 'D',
+    "CCA": 'X',
+}
+
 assert set(ESTIMATOR_NAMES.keys()) == set(ESTIMATORS.keys())
+assert set(ESTIMATOR_COLORS.keys()) == set(ESTIMATORS.keys())
+assert set(ESTIMATOR_MARKERS.keys()) == set(ESTIMATORS.keys())
 
 _SAMPLE_ESTIMATE: int = 200_000
 
@@ -73,7 +91,7 @@ rule all:
         'cool_tasks.pdf',
         'results.csv',
         'cool_tasks-results.pdf',
-        'profiles.pdf'
+        # 'profiles.pdf'
 
 rule plot_distributions:
     output: "cool_tasks.pdf"
@@ -122,21 +140,21 @@ rule plot_distributions:
         fig.savefig(str(output), dpi=300)
 
-rule plot_pmi_profiles:
-    output: "profiles.pdf"
-    run:
-        fig, axs = subplots_from_axsize(1, 4, axsize=(4, 3))
-        dists = [x_dist, ai_dist, fence_base_dist, balls_mixt]
-        tasks_official = ['X', 'AI', 'Waves', 'Galaxy']
-        for dist, task_name, ax in zip(dists, tasks_official, axs):
-            import jax
-            key = jax.random.PRNGKey(1024)
-            pmi_values = bmm.pmi_profile(key=key, dist=dist, n=100_000)
-            bins = np.linspace(-5, 5, 101)
-            ax.hist(pmi_values, bins=bins, density=True, alpha=0.5)
-            ax.set_xlabel(task_name)
-        axs[0].set_ylabel("Density")
-        fig.savefig(str(output))
+# rule plot_pmi_profiles:
+#     output: "profiles.pdf"
+#     run:
+#         fig, axs = subplots_from_axsize(1, 4, axsize=(4, 3))
+#         dists = [x_dist, ai_dist, fence_base_dist, balls_mixt]
+#         tasks_official = ['X', 'AI', 'Waves', 'Galaxy']
+#         for dist, task_name, ax in zip(dists, tasks_official, axs):
+#             import jax
+#             key = jax.random.PRNGKey(1024)
+#             pmi_values = bmm.pmi_profile(key=key, dist=dist, n=100_000)
+#             bins = np.linspace(-5, 5, 101)
+#             ax.hist(pmi_values, bins=bins, density=True, alpha=0.5)
+#             ax.set_xlabel(task_name)
+#         axs[0].set_ylabel("Density")
+#         fig.savefig(str(output))
 
 
 rule plot_results:
@@ -155,7 +173,10 @@ rule plot_results:
                 data_est['task_id'].apply(lambda e: tasks.index(e)) + 0.05 * np.random.normal(size=len(data_est)),
                 data_est['mi_estimate'],
                 label=ESTIMATOR_NAMES[estimator_id],
-                alpha=0.4, s=3**2,
+                alpha=0.4, s=5**2,
+                marker=ESTIMATOR_MARKERS[estimator_id],
+                c=ESTIMATOR_COLORS[estimator_id],
+                edgecolor="none",
             )
 
             for task_id, data_task in data_5k.groupby('task_id'):
diff --git a/workflows/projects/Mixtures/how_good_integration_is.smk b/workflows/projects/Mixtures/how_good_integration_is.smk
index 0948cb98..61b84c95 100644
--- a/workflows/projects/Mixtures/how_good_integration_is.smk
+++ b/workflows/projects/Mixtures/how_good_integration_is.smk
@@ -51,11 +51,19 @@ ESTIMATORS: dict[str, Callable] = {
 }
 
 ESTIMATOR_COLORS = {
-    "InfoNCE": "magenta",
-    "DV": "red",
-    "NWJ": "limegreen",
-    "MC": "mediumblue",
+    "InfoNCE": '#ff7f00',
+    "DV": '#984ea3',
+    "NWJ": "#999999",
+    "MC": "#dede00",
 }
+ESTIMATOR_MARKERS = {
+    "InfoNCE": 'v',
+    "DV": 'D',
+    "NWJ": "X",
+    "MC": ".",
+}
+
+
 
 four_balls = bmm.mixture(
     proportions=jnp.array([0.3, 0.3, 0.2, 0.2]),
@@ -210,9 +218,6 @@ def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path, alpha: float
     with open(ground_truth_path) as fh:
         ground_truth = json.load(fh)
 
-    # Add ground-truth information
-    x_axis =[df["n_points"].min(), df["n_points"].max()]
-    ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":")
     # ax.fill_between(
     #     x_axis,
     #     [ground_truth["mi_mean"] - ground_truth["mi_std"]] * 2,
@@ -232,10 +237,16 @@ def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path, alpha: float
 
         color = ESTIMATOR_COLORS[estimator]
 
-        ax.plot(points, mean, color=color, label=estimator)
+        ax.plot(points, mean, color=color)
+        ax.scatter(points, mean, color=color, marker=ESTIMATOR_MARKERS[estimator], label=estimator)
         ax.fill_between(points, mean - std, mean + std, alpha=alpha, color=color)
 
+    # Add ground-truth information
+    x_axis =[df["n_points"].min(), df["n_points"].max()]
+    ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":")
+
+
 
 rule plot_performance_all:
     input:
         simple_ground_truth="Four_Balls/ground_truth.json",
diff --git a/workflows/projects/Mixtures/outliers.smk b/workflows/projects/Mixtures/outliers.smk
index 10e74f8c..6053f9d1 100644
--- a/workflows/projects/Mixtures/outliers.smk
+++ b/workflows/projects/Mixtures/outliers.smk
@@ -112,14 +112,21 @@ for variance in VARIANCES:
 
 UNSCALED_TASKS = {**MIXING_TASKS, **VARIANCE_TASKS}
 
 
-
 ESTIMATOR_COLORS = {
-    "InfoNCE": "magenta",
-    "MINE": "red",
-    "KSG": "green",
-    "CCA": "purple",
+    "InfoNCE": '#ff7f00',
+    "MINE": '#377eb8',
+    "KSG": '#4daf4a',
+    "CCA": '#a65628',
+}
+
+ESTIMATOR_MARKERS = {
+    "InfoNCE": 'v',
+    "MINE": '.',
+    "KSG": '^',
+    "CCA": 'X',
 }
+
 ESTIMATORS = {
     "KSG": bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(10,)),
     "CCA": bmi.estimators.CCAMutualInformationEstimator(),
@@ -162,7 +169,8 @@ def plot_data(ax: plt.Axes, data: pd.DataFrame, key: str = "mixing", use_legend:
         subset = grouped[grouped['estimator_id'] == estimator]
         color = ESTIMATOR_COLORS[estimator]
 
-        ax.plot(subset[key], subset['mean'], color=color, label=estimator)
+        ax.plot(subset[key], subset['mean'], color=color)
+        ax.scatter(subset[key], subset['mean'], color=color, marker=ESTIMATOR_MARKERS[estimator], label=estimator)
         ax.fill_between(subset[key], subset['mean'] - subset['std'], subset['mean'] + subset['std'], alpha=0.3, color=color)
 
     if use_legend: