From 2f3aff17cf2bbd7e3f1bbbe4be897bd8c3ebe505 Mon Sep 17 00:00:00 2001 From: wjm41 Date: Thu, 24 Nov 2022 15:18:40 +0000 Subject: [PATCH] minor fix --- examples/pose_centroids_molplotly.ipynb | 59 ++++++++++++++++++++++--- molplotly/main.py | 3 +- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/examples/pose_centroids_molplotly.ipynb b/examples/pose_centroids_molplotly.ipynb index fde8075..bb8a637 100644 --- a/examples/pose_centroids_molplotly.ipynb +++ b/examples/pose_centroids_molplotly.ipynb @@ -614,6 +614,30 @@ 20.56523380975767 ], "yaxis": "y" + }, + { + "marker": { + "color": "red", + "size": 12, + "symbol": "x" + }, + "mode": "markers", + "showlegend": false, + "type": "scatter", + "x": [ + 20.710113587648642, + 4.312267729434556, + 18.935473315914678, + 17.68965247025577, + 20.61053154671133 + ], + "y": [ + 20.431212351814835, + 4.189062767055896, + 18.95856209889753, + 17.674840683395807, + 20.55944905349298 + ] } ], "layout": { @@ -1473,7 +1497,7 @@ } ], "source": [ - "import plotly.express as px\n", + "import plotly.graph_objects as go\n", "\n", "rmsd_df['pose_index'] = range(1, len(renamed_poses)+1)\n", "rmsd_df['cluster'] = kmeans.labels_\n", @@ -1490,6 +1514,15 @@ " labels={'1':'RMSD to pose 1',\n", " '2':'RMSD to pose 2',}\n", " )\n", + "scatter_fig.add_trace(\n", + " go.Scatter(\n", + " x = kmeans.cluster_centers_[:,0],\n", + " y = kmeans.cluster_centers_[:,1],\n", + " mode='markers',\n", + " marker=dict(color=\"red\", \n", + " symbol='x',\n", + " size=10),\n", + " showlegend=False,))\n", "scatter_fig.update_traces(marker=dict(size=12))\n" ] }, @@ -1497,7 +1530,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Adding `molplotly` and specifying `mol_col` allows the 3D coordinates of the pose to be shown on mouseover!" + "Adding `molplotly` and specifying `mol_col` allows the 3D coordinates of the pose to be shown on mouseover!\n", + "\n", + "Unfortunately there are issues with handling additional non-molecular traces in `molplotly` so we'll have to re-create the scatter plot without the centroids." ] }, { @@ -1520,7 +1555,7 @@ " " ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -1532,6 +1567,20 @@ "\n", "rmsd_df['mol'] = renamed_poses \n", "\n", + "scatter_fig = px.scatter(rmsd_df, \n", + " x=\"1\", \n", + " y=\"2\",\n", + " color='cluster',\n", + " hover_name='pose_index',\n", + " width=1000,\n", + " height=800,\n", + " title='Clustering of ligand poses',\n", + " labels={'1':'RMSD to pose 1',\n", + " '2':'RMSD to pose 2',}\n", + " )\n", + "\n", + "scatter_fig.update_traces(marker=dict(size=12))\n", + "\n", "app_clusters = molplotly.add_molecules(\n", " fig=scatter_fig,\n", " df=rmsd_df,\n", @@ -1553,7 +1602,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -1571,7 +1620,7 @@ " " ], "text/plain": [ - "" + "" ] }, "metadata": {}, diff --git a/molplotly/main.py b/molplotly/main.py index 5d5359c..0eb63ec 100644 --- a/molplotly/main.py +++ b/molplotly/main.py @@ -55,9 +55,10 @@ def find_grouping( if fig.data[0].hovertemplate is not None: col_names = re.findall(r"(.*?)=(?!%).*?<.*?>", fig.data[0].hovertemplate) + col_names = [re.sub(r"(.*)>", "", col_name) for col_name in col_names] if set(col_names) != set(cols): raise ValueError( - "marker_col/color_col/facet_col is misspecified because the dataframe grouping names don't match the names in the plotly figure.", + f"marker_col/color_col/facet_col is misspecified because the specified dataframe grouping names {cols} don't match the names in the plotly figure {col_names}.", ) df_grouped = df_data.groupby(col_names)