Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wjm41 committed Nov 24, 2022
1 parent c8cce63 commit 2f3aff1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
59 changes: 54 additions & 5 deletions examples/pose_centroids_molplotly.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,30 @@
20.56523380975767
],
"yaxis": "y"
},
{
"marker": {
"color": "red",
"size": 12,
"symbol": "x"
},
"mode": "markers",
"showlegend": false,
"type": "scatter",
"x": [
20.710113587648642,
4.312267729434556,
18.935473315914678,
17.68965247025577,
20.61053154671133
],
"y": [
20.431212351814835,
4.189062767055896,
18.95856209889753,
17.674840683395807,
20.55944905349298
]
}
],
"layout": {
Expand Down Expand Up @@ -1473,7 +1497,7 @@
}
],
"source": [
"import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"\n",
"rmsd_df['pose_index'] = range(1, len(renamed_poses)+1)\n",
"rmsd_df['cluster'] = kmeans.labels_\n",
Expand All @@ -1490,14 +1514,25 @@
" labels={'1':'RMSD to pose 1',\n",
" '2':'RMSD to pose 2',}\n",
" )\n",
"scatter_fig.add_trace(\n",
" go.Scatter(\n",
" x = kmeans.cluster_centers_[:,0],\n",
" y = kmeans.cluster_centers_[:,1],\n",
" mode='markers',\n",
" marker=dict(color=\"red\", \n",
" symbol='x',\n",
" size=10),\n",
" showlegend=False,))\n",
"scatter_fig.update_traces(marker=dict(size=12))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding `molplotly` and specifying `mol_col` allows the 3D coordinates of the pose to be shown on mouseover!"
"Adding `molplotly` and specifying `mol_col` allows the 3D coordinates of the pose to be shown on mouseover!\n",
"\n",
"Unfortunately there are issues with handling additional non-molecular traces in `molplotly` so we'll have to re-create the scatter plot without the centroids."
]
},
{
Expand All @@ -1520,7 +1555,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x13733e8c0>"
"<IPython.lib.display.IFrame at 0x139627d30>"
]
},
"metadata": {},
Expand All @@ -1532,6 +1567,20 @@
"\n",
"rmsd_df['mol'] = renamed_poses \n",
"\n",
"scatter_fig = px.scatter(rmsd_df, \n",
" x=\"1\", \n",
" y=\"2\",\n",
" color='cluster',\n",
" hover_name='pose_index',\n",
" width=1000,\n",
" height=800,\n",
" title='Clustering of ligand poses',\n",
" labels={'1':'RMSD to pose 1',\n",
" '2':'RMSD to pose 2',}\n",
" )\n",
"\n",
"scatter_fig.update_traces(marker=dict(size=12))\n",
"\n",
"app_clusters = molplotly.add_molecules(\n",
" fig=scatter_fig,\n",
" df=rmsd_df,\n",
Expand All @@ -1553,7 +1602,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -1571,7 +1620,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x13733f820>"
"<IPython.lib.display.IFrame at 0x13969be50>"
]
},
"metadata": {},
Expand Down
3 changes: 2 additions & 1 deletion molplotly/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ def find_grouping(

if fig.data[0].hovertemplate is not None:
col_names = re.findall(r"(.*?)=(?!%).*?<.*?>", fig.data[0].hovertemplate)
col_names = [re.sub(r"(.*)>", "", col_name) for col_name in col_names]
if set(col_names) != set(cols):
raise ValueError(
"marker_col/color_col/facet_col is misspecified because the dataframe grouping names don't match the names in the plotly figure.",
f"marker_col/color_col/facet_col is misspecified because the specified dataframe grouping names {cols} don't match the names in the plotly figure {col_names}.",
)

df_grouped = df_data.groupby(col_names)
Expand Down

0 comments on commit 2f3aff1

Please sign in to comment.