Speed up the MBAR calculation (#357)
* Adopted @mrshirts's advice to use the BAR result as the initial guess for MBAR,
  which appears to provide roughly a 5-fold speed-up.
  * MBAR can now take initial_f_k="BAR" to enable this speed-up (usage sketch below).
  * Changed the default of initial_f_k from None (all zeros) to "BAR".
* Optimised the convergence detection as well: the result of the previous MBAR run
  is used as the initial guess for the next MBAR run. Compared to using BAR as the
  initial guess, this still provides some additional speed-up.
* Updated CHANGES.
* Added tests.
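
A minimal usage sketch of the new default (the file list and temperature below are illustrative placeholders, not part of this commit; any set of u_nk DataFrames parsed by alchemlyb would do):

import alchemlyb
from alchemlyb.parsing.gmx import extract_u_nk
from alchemlyb.estimators import MBAR

# "xvg_files" is a hypothetical list of GROMACS dhdl.xvg paths.
u_nk = alchemlyb.concat([extract_u_nk(f, T=300) for f in xvg_files])

# New default: an initial BAR run seeds the MBAR free energies.
mbar_warm = MBAR(initial_f_k="BAR").fit(u_nk)

# Previous behaviour (start from all zeros) is still available.
mbar_cold = MBAR(initial_f_k=None).fit(u_nk)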
xiki-tempula authored May 17, 2024
1 parent 55870c8 commit 46cc83b
Showing 5 changed files with 73 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -64,7 +64,7 @@ jobs:
MPLBACKEND: agg

- name: Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
13 changes: 13 additions & 0 deletions CHANGES
@@ -14,6 +14,19 @@ The rules for this file:

------------------------------------------------------------------------------

**/**/2024 xiki-tempula

* 2.3.0

Changes
- Default value for keyword argument `initial_f_k` of the MBAR estimator was
  changed to "BAR" (run an initial BAR calculation before MBAR) instead of
  `None` (start from all zeros), as this change provides a sizable speedup (PR #357).

Enhancements
- The `BAR` result is used as the initial guess for the `MBAR` estimator. (PR #357)
- `forward_backward_convergence` uses the result from the previous step as the initial guess for the next step. (PR #357)

06/04/2024 hl2500, xiki-tempula

* 2.2.0
10 changes: 7 additions & 3 deletions src/alchemlyb/convergence/convergence.py
@@ -82,7 +82,7 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
raise ValueError(msg)
else:
# select estimator class by name
estimator_fit = estimators_dispatch[estimator](**kwargs).fit
my_estimator = estimators_dispatch[estimator](**kwargs)
logger.info(f"Use {estimator} estimator for convergence analysis.")

logger.info("Begin forward analysis")
@@ -94,7 +94,9 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
for data in df_list:
sample.append(data[: len(data) // num * i])
sample = concat(sample)
result = estimator_fit(sample)
result = my_estimator.fit(sample)
if estimator == "MBAR":
my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
forward_list.append(result.delta_f_.iloc[0, -1])
if estimator.lower() == "bar":
error = np.sqrt(
@@ -121,7 +123,9 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
for data in df_list:
sample.append(data[-len(data) // num * i :])
sample = concat(sample)
result = estimator_fit(sample)
result = my_estimator.fit(sample)
if estimator == "MBAR":
my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
backward_list.append(result.delta_f_.iloc[0, -1])
if estimator.lower() == "bar":
error = np.sqrt(
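
Callers of the convergence helper do not need to change anything; the warm start shown in the diff above is internal. A minimal sketch of the call, followed by the same reuse pattern spelled out by hand, assuming `df_list` is a list of u_nk DataFrames:

from alchemlyb import concat
from alchemlyb.convergence import forward_backward_convergence
from alchemlyb.estimators import MBAR

# The warm start happens inside the helper; the call site is unchanged.
df = forward_backward_convergence(df_list, estimator="MBAR", num=10)

# Equivalent idea by hand: reuse the previous fit's free energies as the
# initial guess for the next, larger subsample.
estimator = MBAR()
forward = []
for i in range(1, 11):
    sample = concat([data[: len(data) // 10 * i] for data in df_list])
    result = estimator.fit(sample)
    forward.append(result.delta_f_.iloc[0, -1])
    # Seed the next fit with the free energies converged on this subsample.
    estimator.initial_f_k = result.delta_f_.iloc[0, :]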
42 changes: 36 additions & 6 deletions src/alchemlyb/estimators/mbar_.py
@@ -1,7 +1,12 @@
from __future__ import annotations
from typing import Literal

import numpy as np
import pandas as pd
import pymbar
from sklearn.base import BaseEstimator

from . import BAR
from .base import _EstimatorMixOut


@@ -17,9 +22,18 @@ class MBAR(BaseEstimator, _EstimatorMixOut):
relative_tolerance : float, optional
Set to determine the relative tolerance convergence criteria.
initial_f_k : np.ndarray, float, shape=(K), optional
Set to the initial dimensionless free energies to use as a
guess (default None, which sets all :math:`f_k = 0`).
initial_f_k : np.ndarray, float, shape=(K), optional, or the string "BAR"
When `initial_f_k` is an `np.ndarray`, it is used directly as the initial
guess of the dimensionless free energies for the MBAR estimator.
When `initial_f_k` is ``None``, all initial free energies are set to 0.
When `initial_f_k` is "BAR", a BAR calculation is run first and its result
is used as the initial guess (default).
.. versionchanged:: 2.3.0
The new default is now "BAR" as it provides a substantial speedup
over the previous default `None`.
method : str, optional, default="robust"
The optimization routine to use. This can be any of the methods
@@ -71,14 +85,19 @@ def __init__(
self,
maximum_iterations=10000,
relative_tolerance=1.0e-7,
initial_f_k=None,
initial_f_k: np.ndarray | Literal["BAR"] | None = "BAR",
method="robust",
n_bootstraps=0,
verbose=False,
):
self.maximum_iterations = maximum_iterations
self.relative_tolerance = relative_tolerance
self.initial_f_k = initial_f_k
if isinstance(initial_f_k, str) and initial_f_k != "BAR":
raise ValueError(
f"Only `BAR` is supported as string input to `initial_f_k`. Got ({initial_f_k})."
)
else:
self.initial_f_k = initial_f_k
self.method = method
self.verbose = verbose
self.n_bootstraps = n_bootstraps
Expand Down Expand Up @@ -108,13 +127,24 @@ def fit(self, u_nk):
]
self._states_ = u_nk.columns.values.tolist()

if isinstance(self.initial_f_k, str) and self.initial_f_k == "BAR":
bar = BAR(
maximum_iterations=self.maximum_iterations,
relative_tolerance=self.relative_tolerance,
verbose=self.verbose,
)
bar.fit(u_nk)
initial_f_k = bar.delta_f_.iloc[0, :]
else:
initial_f_k = self.initial_f_k

self._mbar = pymbar.MBAR(
u_nk.T,
N_k,
maximum_iterations=self.maximum_iterations,
relative_tolerance=self.relative_tolerance,
verbose=self.verbose,
initial_f_k=self.initial_f_k,
initial_f_k=initial_f_k,
solver_protocol=self.method,
n_bootstraps=self.n_bootstraps,
)
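
For completeness, `initial_f_k` still accepts an explicit array of dimensionless free energies, and any string other than "BAR" is rejected at construction time, as the diff above shows. A brief sketch (the array values below are arbitrary placeholders and must match the number of lambda states in the data later passed to `fit`):

import numpy as np
from alchemlyb.estimators import MBAR

# Explicit initial guess: one dimensionless free energy per lambda state
# (five states here, with arbitrary placeholder values).
f_k_guess = np.array([0.0, 0.5, 1.1, 1.8, 2.4])
mbar = MBAR(initial_f_k=f_k_guess)

# Any string other than "BAR" raises ValueError at construction time.
try:
    MBAR(initial_f_k="aaa")
except ValueError as err:
    print(err)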
16 changes: 16 additions & 0 deletions src/alchemlyb/tests/test_fep_estimators.py
@@ -2,6 +2,7 @@
"""

import numpy as np
import pytest

import alchemlyb
@@ -151,3 +152,18 @@ def test_bootstrap(gmx_benzene_Coulomb_u_nk):

assert mbar_bootstrap_mean == mbar_mean
assert mbar_bootstrap_err != mbar_err


def test_wrong_initial_f_k():
with pytest.raises(
ValueError, match="Only `BAR` is supported as string input to `initial_f_k`"
):
MBAR(initial_f_k="aaa")


@pytest.mark.parametrize("initial_f_k", ["BAR", None])
def test_initial_f_k(gmx_benzene_Coulomb_u_nk, initial_f_k):
u_nk = alchemlyb.concat(gmx_benzene_Coulomb_u_nk)
mbar = MBAR(initial_f_k=initial_f_k)
mbar.fit(u_nk)
assert np.isclose(mbar.delta_f_.loc[0.00, 1.00], 3.0411556983908046)

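The parametrized test asserts that both settings converge to the same free energy. The same spot check can be done interactively; here `u_nk` is assumed to be the concatenated benzene Coulomb data from the `gmx_benzene_Coulomb_u_nk` fixture:

import numpy as np
from alchemlyb.estimators import MBAR

# u_nk is assumed to be the concatenated benzene Coulomb u_nk data,
# as provided by the gmx_benzene_Coulomb_u_nk fixture.
warm = MBAR(initial_f_k="BAR").fit(u_nk)
cold = MBAR(initial_f_k=None).fit(u_nk)

# Both initial guesses converge to the same end-state free energy.
assert np.isclose(warm.delta_f_.loc[0.00, 1.00], cold.delta_f_.loc[0.00, 1.00])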