Fluent scatter plot #492
ecomodeller
started this conversation in
Ideas
Replies: 1 comment
-
Adding more features... Example implementation:
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
from typing import Any, Optional, Self, Sequence, Union
def quantiles_xy(
    x: np.ndarray,
    y: np.ndarray,
    quantiles: Optional[Union[int, Sequence[float]]] = None,
):
    """Return matching quantiles of *x* and *y* for a Q-Q plot.

    Parameters
    ----------
    x, y
        Samples to compare. Only ``len(x)`` is used to pick the default
        number of quantiles.
    quantiles
        * ``None``: use 10, 100 or 1000 evenly spaced probabilities,
          depending on whether ``len(x)`` is below 300, below 3000, or larger.
        * ``int``: use that many evenly spaced probabilities on ``[0, 1]``.
        * anything else: treated as explicit probabilities (list, tuple,
          NumPy array, or a scalar) and converted to a float array.

    Returns
    -------
    tuple
        ``(np.quantile(x, q), np.quantile(y, q))`` for the chosen ``q``.
    """
    if quantiles is None:
        n_obs = len(x)
        if n_obs >= 3000:
            n_quantiles = 1000
        elif n_obs >= 300:
            n_quantiles = 100
        else:
            n_quantiles = 10
        q = np.linspace(0, 1, num=n_quantiles)
    elif isinstance(quantiles, (int, np.integer)):
        q = np.linspace(0, 1, num=quantiles)
    else:
        # Explicit probabilities. Handling every remaining case here (instead
        # of only typing.Sequence instances) keeps NumPy arrays from falling
        # through to an undefined quantile count.
        q = np.asarray(quantiles, dtype=float)
    return np.quantile(x, q=q), np.quantile(y, q=q)
class Comparer:
def __init__(self,x: np.ndarray, y: np.ndarray) -> None:
self.x = x
self.y = y
def plot(self,**kwargs) -> Self:
self.fig, self.ax = plt.subplots(**kwargs)
return self
def scatter(self,alpha=1.0) -> Self:
self.ax.scatter(self.x, self.y,alpha=alpha)
return self
def qq(self,quantiles=None) -> Self:
# Perform a QQ plot
xq = xq, yq = quantiles_xy(self.x, self.y, quantiles)
self.ax.plot(
xq,
yq,
"bo-"
)
return self
def reg_line(self, equation=True) -> Self:
x_data = self.ax.collections[0].get_offsets()[:, 0]
y_data = self.ax.collections[0].get_offsets()[:, 1]
slope, intercept = np.polyfit(x_data, y_data, 1)
self.ax.plot(x_data, slope * x_data + intercept, color='red', linestyle='--')
if equation:
self.ax.legend([f"y = {slope:.2f}x + {intercept:.2f}"])
return self
def default(self) -> Self:
return (self.scatter()
.qq()
.reg_line())
def skill_table(self, metrics:Sequence[str]=("n","bias","rmse")) -> Self:
METRICS = {
"rmse": {"f": np.sqrt(np.mean((self.x - self.y) ** 2)), "type": "float"},
"bias" : {"f": np.mean(self.x - self.y), "type": "float"},
"n" : {"f": len(self.x), "type": "int"}
}
textstr = ""
for metric in metrics:
match METRICS[metric]["type"]:
case "float":
textstr += f"{metric}: {METRICS[metric]['f']:.2f}\n"
case "int":
textstr += f"{metric}: {METRICS[metric]['f']}\n"
case _:
raise ValueError(f"Unknown type: {METRICS[metric]['type']}")
self.ax.text(0.05, 0.95, textstr, transform=self.ax.transAxes, fontsize=14,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
return self
def show(self) -> None:
plt.show()
# Example usage: y is a noisy linear function of x, so the regression line
# should come out close to y = 2x.
x = np.random.normal(0, 1, 1000)
y = 2 * x + np.random.normal(0, 1, 1000)
# Each component is its own chained call, so any line can be commented out.
# NOTE(review): the trailing semicolon presumably suppresses the cell echo in
# a notebook; call .show() at the end of the chain when running as a script.
(
    Comparer(x, y)
    .plot(figsize=(6, 6))
    .scatter(alpha=0.5)
    .qq([0.05, 0.5, 0.75, 0.95])
    .reg_line(equation=True)
    .skill_table(("n", "bias"))
);
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
The number of arguments to the
scatter
is pretty long. (Please disregard how the plots look; this thread is only meant to discuss syntax.)
I propose an alternative using a fluent interface where arguments to each part,
qq
, reg_line,
etc. would go in separate methods. This makes it easy to comment out some parts.
This allows easier selection of components of the advanced scatter plot.
It should still be easy to create the default plot with several components, e.g.
Beta Was this translation helpful? Give feedback.
All reactions