Skip to content

Commit

Permalink
enabled plotting of categorical annotation in pl.paga
Browse files Browse the repository at this point in the history
  • Loading branch information
falexwolf committed Jun 8, 2018
1 parent f48476a commit f3b769e
Showing 1 changed file with 30 additions and 43 deletions.
73 changes: 30 additions & 43 deletions scanpy/plotting/tools/paga.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,9 @@ def paga(
Do not draw edges for weights below this threshold. Set to 0 if you want
all edges. Discarding low-connectivity edges helps in getting a much
clearer picture of the graph.
color : gene name or iterable of `int` or `float` or color strings, optional (default: `None`)
The node colors. Besides lists, uniform colors this also automatically
plots the degree of the abstracted graph when passing {'degree_dashed',
'degree_solid'}.
color : gene name or obs. annotation, optional (default: `None`)
The node colors. Also plots the degree of the abstracted graph when
passing {'degree_dashed', 'degree_solid'}.
labels : `None`, `str`, `list`, `dict`, optional (default: `None`)
The node labels. If `None`, this defaults to the group labels stored in
the categorical for which :func:`~scanpy.api.tl.paga` has been computed.
Expand Down Expand Up @@ -250,14 +249,6 @@ def paga(
dashed_edges : `str` or `None`, optional (default: `None`)
Key for `.uns['paga']` that specifies the matrix that stores the edges
to be drawn dashed grey. If `None`, no dashed edges are drawn.
threshold_solid : `float` or `None`, optional (default: `threshold`)
Do not draw edges for weights below this threshold. Set to `None` if you
want all edges.
threshold_dashed : `float` or `None`, optional (default: `threshold`)
Do not draw edges for weights below this threshold. Set to `None` if you
want all edges.
fontsize : `int` (default: `None`)
Font size for node labels.
fontsize : `int` (default: `None`)
Font size for node labels.
text_kwds : keywords for text
Expand Down Expand Up @@ -527,14 +518,26 @@ def _paga_graph(
colors = x_color

# plot continuous annotation
if isinstance(colors, str) and colors in adata.obs:
if (isinstance(colors, str) and colors in adata.obs
and not is_categorical_dtype(adata.obs[colors])):
x_color = []
cats = adata.obs[groups_key].cat.categories
for icat, cat in enumerate(cats):
subset = (cat == adata.obs[groups_key]).values
x_color.append(adata.obs.loc[subset, colors].mean())
colors = x_color

# plot categorical annotation
if (isinstance(colors, str) and colors in adata.obs and
is_categorical_dtype(adata.obs[colors])):
from ... import utils as sc_utils
asso_names, asso_matrix = sc_utils.compute_association_matrix_of_groups(
adata, prediction=groups_key, reference=colors, normalization='reference')
utils.add_colors_for_categorical_sample_annotation(adata, colors)
asso_colors = sc_utils.get_associated_colors_of_groups(
adata.uns[colors + '_colors'], asso_matrix)
colors = asso_colors

if len(colors) < len(node_labels):
print(node_labels, colors)
raise ValueError(
Expand Down Expand Up @@ -750,6 +753,10 @@ def _paga_graph(
size=fontsize, fontweight=fontweight, **text_kwds)
# else pie chart plot
else:
# start with this dummy plot... otherwise strange behavior
sct = ax.scatter(
pos_array[:, 0], pos_array[:, 1],
c='white', edgecolors='face', s=groups_sizes, cmap=cmap)
trans = ax.transData.transform
bbox = ax.get_position().get_points()
ax_x_min = bbox[0, 0]
Expand All @@ -759,17 +766,19 @@ def _paga_graph(
ax_len_x = ax_x_max - ax_x_min
ax_len_y = ax_y_max - ax_y_min
trans2 = ax.transAxes.inverted().transform
force_labels_to_front = True # TODO: solve this differently!
pie_axs = []
for count, n in enumerate(nx_g_solid.nodes()):
pie_size = groups_sizes[count] / base_scale_scatter
xx, yy = trans(pos[n]) # data coordinates
xa, ya = trans2((xx, yy)) # axis coordinates
x1, y1 = trans(pos[n]) # data coordinates
xa, ya = trans2((x1, y1)) # axis coordinates
xa = ax_x_min + (xa - pie_size/2) * ax_len_x
ya = ax_y_min + (ya - pie_size/2) * ax_len_y
# clip, the fruchterman layout sometimes places below figure
if ya < 0: ya = 0
if xa < 0: xa = 0
a = ax.axes([xa, ya, pie_size * ax_len_x, pie_size * ax_len_y])
pie_axs.append(pl.axes([xa, ya, pie_size * ax_len_x, pie_size * ax_len_y], frameon=False))
pie_axs[count].set_xticks([])
pie_axs[count].set_yticks([])
if not isinstance(colors[count], dict):
raise ValueError('{} is neither a dict of valid matplotlib colors '
'nor a valid matplotlib color.'.format(colors[count]))
Expand All @@ -779,32 +788,10 @@ def _paga_graph(
color_single = list(color_single)
color_single.append('grey')
fracs.append(1-sum(fracs))
a.pie(fracs, colors=color_single)
if not force_labels_to_front and node_labels is not None:
a.text(0.5, 0.5, node_labels[count],
verticalalignment='center',
horizontalalignment='center',
transform=a.transAxes,
size=fontsize)
# TODO: this is a terrible hack, but if we use the solution above (`not
# force_labels_to_front`), labels get hidden behind pies
if force_labels_to_front and node_labels is not None:
for count, n in enumerate(nx_g_solid.nodes()):
pie_size = groups_sizes[count] / base_scale_scatter
# all copy and paste from above
xx, yy = trans(pos[n]) # data coordinates
xa, ya = trans2((xx, yy)) # axis coordinates
# make sure a new axis is created
xa = ax_x_min + (xa - pie_size/2.0000001) * ax_len_x
ya = ax_y_min + (ya - pie_size/2.0000001) * ax_len_y
# clip, the fruchterman layout sometimes places below figure
if ya < 0: ya = 0
if xa < 0: xa = 0
a = pl.axes([xa, ya, pie_size * ax_len_x, pie_size * ax_len_y])
a.set_frame_on(False)
a.set_xticks([])
a.set_yticks([])
a.text(0.5, 0.5, node_labels[count],
pie_axs[count].pie(fracs, colors=color_single)
if node_labels is not None:
for ia, a in enumerate(pie_axs):
a.text(0.5, 0.5, node_labels[ia],
verticalalignment='center',
horizontalalignment='center',
transform=a.transAxes, size=fontsize)
Expand Down

0 comments on commit f3b769e

Please sign in to comment.