From 01308e9c68ba3ee05b2050989dbd893cf6c16ca7 Mon Sep 17 00:00:00 2001 From: vincent d warmerdam Date: Mon, 22 Jul 2024 14:03:05 +0000 Subject: [PATCH] release --- Untitled.ipynb | 52 +++++++++++++++++++++++++++++++++++++++- bulk/__init__.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ setup.py | 4 ++-- 3 files changed, 115 insertions(+), 3 deletions(-) diff --git a/Untitled.ipynb b/Untitled.ipynb index cf80a59..d3f9663 100644 --- a/Untitled.ipynb +++ b/Untitled.ipynb @@ -53,7 +53,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b2f8a2aa96d84f7992698bc529823d47", + "model_id": "85a84c8facf74b79887c29265bfad96e", "version_major": 2, "version_minor": 0 }, @@ -134,6 +134,56 @@ "explorer.show()" ] }, + { + "cell_type": "code", + "execution_count": 8, + "id": "55668ebd-6887-4352-a967-63ad26759bc7", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.decomposition import PCA\n", + "\n", + "pca = PCA(n_components=2)\n", + "text_emb_pipeline = make_pipeline(\n", + " enc, pca\n", + ")\n", + "\n", + "# Calculate embeddings \n", + "X_tfm_pca = pca.fit_transform(X)\n", + "\n", + "# Write to disk. Note! Text column must be named \"text\"\n", + "df = pd.DataFrame({\"text\": sentences})\n", + "df['x'] = X_tfm_pca[:, 0]\n", + "df['y'] = X_tfm_pca[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "14626a98-3bc7-48c1-af09-ca90cd1a1add", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "938d1ff742114fb6b0d1754458bfa784", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(VBox(children=(Text(value='', description='String:', placeholder='Type something'), HBox(childr…" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "explorer = BaseTextExplorer(df, encoder=enc, X=X)\n", + "explorer.show()" + ] + }, { "cell_type": "code", "execution_count": 29, diff --git a/bulk/__init__.py b/bulk/__init__.py index e69de29..c63d54b 100644 --- a/bulk/__init__.py +++ b/bulk/__init__.py @@ -0,0 +1,62 @@ +import jscatter +from ipywidgets import HBox, VBox, HTML, Layout, Button, Text +from IPython.display import display +from sklearn.metrics.pairwise import cosine_similarity + +class BaseTextExplorer: + """ + Early preview of Jupyter Widget explorer. + """ + def __init__(self, dataf, X=None, encoder=None): + self.dataf = dataf + self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500) + self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px')) + self.sample_btn = Button(description='resample') + self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])]) + self.X = X + self.encoder = encoder + + if self.encoder and (self.X is not None): + self.text_input = Text(value='', placeholder='Type something', description='String:') + self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])]) + + def update_text(change): + X_tfm = encoder.transform([self.text_input.value]) + dists = cosine_similarity(X, X_tfm).reshape(1, -1) + self.dists = dists + norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min()) + self.scatter.color(by=norm_dists[0]) + self.scatter.size(by=norm_dists[0]) + + self.text_input.observe(update_text) + + self.scatter.widget.observe(lambda d: self.update(), ['selection']) + self.sample_btn.on_click(lambda d: self.update()) + + def show(self): + return self.elem + + def update(self): + if len(self.scatter.selection()) > 10: + texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"] + else: + texts = self.dataf.iloc[self.scatter.selection()]["text"] + self.html.value = ''.join([f'

{t}

' for t in texts]) + + def observe(self, func): + self.scatter.widget.observe(func, ['selection']) + + @property + def selected_idx(self): + return self.scatter.selection() + + @property + def selected_texts(self): + return list(self.dataf.iloc[self.selection_idx]["text"]) + + @property + def selected_dataframe(self): + return self.dataf.iloc[self.selection_idx] + + def _repr_html_(self): + return display(self.elem) diff --git a/setup.py b/setup.py index d448036..a256648 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,9 @@ setup( name="bulk", - version="0.3.1", + version="0.3.2", packages=find_packages(), - install_requires=["radicli>=0.0.8,<0.1.0", "bokeh>=2.4.3,<3.0.0", "pandas>=1.0.0", "wasabi>=0.9.1", "numpy<2", "jupyter-scatter"], + install_requires=["radicli>=0.0.8,<0.1.0", "bokeh>=2.4.3,<3.0.0", "pandas>=1.0.0", "wasabi>=0.9.1", "numpy<2", "jupyter-scatter", "scikit-learn"], extras_require={ "dev": ["pytest-playwright==0.3.0"], },