Skip to content

Commit

Permalink
release
Browse files Browse the repository at this point in the history
  • Loading branch information
koaning committed Jul 22, 2024
1 parent 8200d98 commit 01308e9
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 3 deletions.
52 changes: 51 additions & 1 deletion Untitled.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b2f8a2aa96d84f7992698bc529823d47",
"model_id": "85a84c8facf74b79887c29265bfad96e",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -134,6 +134,56 @@
"explorer.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "55668ebd-6887-4352-a967-63ad26759bc7",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.decomposition import PCA\n",
"\n",
"pca = PCA(n_components=2)\n",
"text_emb_pipeline = make_pipeline(\n",
" enc, pca\n",
")\n",
"\n",
"# Calculate embeddings \n",
"X_tfm_pca = pca.fit_transform(X)\n",
"\n",
"# Write to disk. Note! Text column must be named \"text\"\n",
"df = pd.DataFrame({\"text\": sentences})\n",
"df['x'] = X_tfm_pca[:, 0]\n",
"df['y'] = X_tfm_pca[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "14626a98-3bc7-48c1-af09-ca90cd1a1add",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "938d1ff742114fb6b0d1754458bfa784",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(VBox(children=(Text(value='', description='String:', placeholder='Type something'), HBox(childr…"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explorer = BaseTextExplorer(df, encoder=enc, X=X)\n",
"explorer.show()"
]
},
{
"cell_type": "code",
"execution_count": 29,
Expand Down
62 changes: 62 additions & 0 deletions bulk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import jscatter
from ipywidgets import HBox, VBox, HTML, Layout, Button, Text
from IPython.display import display
from sklearn.metrics.pairwise import cosine_similarity

class BaseTextExplorer:
"""
Early preview of Jupyter Widget explorer.
"""
def __init__(self, dataf, X=None, encoder=None):
self.dataf = dataf
self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
self.sample_btn = Button(description='resample')
self.elem = HBox([self.scatter.show(), VBox([self.sample_btn, self.html])])
self.X = X
self.encoder = encoder

if self.encoder and (self.X is not None):
self.text_input = Text(value='', placeholder='Type something', description='String:')
self.elem = HBox([VBox([self.text_input, self.scatter.show()]), VBox([self.sample_btn, self.html])])

def update_text(change):
X_tfm = encoder.transform([self.text_input.value])
dists = cosine_similarity(X, X_tfm).reshape(1, -1)
self.dists = dists
norm_dists = 0.01 + (dists - dists.min())/(0.1 + dists.max() - dists.min())
self.scatter.color(by=norm_dists[0])
self.scatter.size(by=norm_dists[0])

self.text_input.observe(update_text)

self.scatter.widget.observe(lambda d: self.update(), ['selection'])
self.sample_btn.on_click(lambda d: self.update())

def show(self):
return self.elem

def update(self):
if len(self.scatter.selection()) > 10:
texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"]
else:
texts = self.dataf.iloc[self.scatter.selection()]["text"]
self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts])

def observe(self, func):
self.scatter.widget.observe(func, ['selection'])

@property
def selected_idx(self):
return self.scatter.selection()

@property
def selected_texts(self):
return list(self.dataf.iloc[self.selection_idx]["text"])

@property
def selected_dataframe(self):
return self.dataf.iloc[self.selection_idx]

def _repr_html_(self):
return display(self.elem)
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

setup(
name="bulk",
version="0.3.1",
version="0.3.2",
packages=find_packages(),
install_requires=["radicli>=0.0.8,<0.1.0", "bokeh>=2.4.3,<3.0.0", "pandas>=1.0.0", "wasabi>=0.9.1", "numpy<2", "jupyter-scatter"],
install_requires=["radicli>=0.0.8,<0.1.0", "bokeh>=2.4.3,<3.0.0", "pandas>=1.0.0", "wasabi>=0.9.1", "numpy<2", "jupyter-scatter", "scikit-learn"],
extras_require={
"dev": ["pytest-playwright==0.3.0"],
},
Expand Down

0 comments on commit 01308e9

Please sign in to comment.