diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index c9ae7f06a..bd882b8fe 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -25,7 +25,8 @@ dependencies: - geometric # - gromacs =2019.1 - openff-toolkit >=0.11.3 + - rdkit !=2024.03.4 - openff-evaluator-base >= 0.4.1 - pint =0.20 # - openff-recharge - # - openeye-toolkits (Don't have a license file to use with GH Actions.) + - openeye-toolkits diff --git a/src/smirnoffio.py b/src/smirnoffio.py index 5e1cf37c6..adebef217 100644 --- a/src/smirnoffio.py +++ b/src/smirnoffio.py @@ -249,7 +249,27 @@ def smirnoff_update_pgrads(target): class SMIRNOFF(OpenMM): - """ Derived from Engine object for carrying out OpenMM calculations that use the SMIRNOFF force field. """ + """ + Derived from Engine object for carrying out OpenMM calculations that use the SMIRNOFF force field. + + Parameters + ---------- + name : str + ffxml : str + pdb : str + mol : + mol2 : list[str] + mol : Molecule + coords : str + platname : string + precision : string + nonbonded_cutoff + mmopts : dict + vsite_bonds : list + implicit_solvent : string + restrain_k : float + freeze_atoms : list + """ def __init__(self, name="openmm", **kwargs): self.valkwd = ['ffxml', 'pdb', 'mol2', 'platname', 'precision', 'mmopts', 'vsite_bonds', 'implicit_solvent', 'restrain_k', 'freeze_atoms'] @@ -419,8 +439,8 @@ def prepare(self, pbc=False, mmopts={}, **kwargs): n_virtual_sites = 0 self._has_virtual_sites = False - if 'VirtualSites' in interchange.handlers: - n_virtual_sites = len(interchange['VirtualSites'].slot_map) + if 'VirtualSites' in interchange.collections: + n_virtual_sites = len(interchange['VirtualSites'].key_map) if n_virtual_sites > 0: self._has_virtual_sites = True @@ -504,15 +524,6 @@ def update_simulation(self, **kwargs): # delattr(self, 'simulation') # self.vsprm = vsprm.copy() - has_vsites = False - for particle_idx in range(self.system.getNumParticles()): - if self.system.isVirtualSite(particle_idx): - has_vsites = True - - if has_vsites: - raise Exception("ForceBalance can't currently handle SMIRNOFF vsites. " - "Downgrade to ForceBalance 1.9.3 or earlier to handle those.") - if hasattr(self, 'simulation'): UpdateSimulationParameters(self.system, self.simulation) else: diff --git a/src/tests/files/opc/dimer.mol2 b/src/tests/files/opc/dimer.mol2 new file mode 100644 index 000000000..ff30eae2f --- /dev/null +++ b/src/tests/files/opc/dimer.mol2 @@ -0,0 +1,21 @@ +@MOLECULE +***** + 8 5 0 0 0 +SMALL +GASTEIGER + +@ATOM + 1 O 11.9600 8.3530 7.6660 O.3 1 HOH1 0.0000 + 2 H 11.8800 7.4390 7.3960 H 0 HOH0 0.0000 + 3 H 11.5470 8.8500 6.9600 H 0 HOH0 0.0000 + 4 M 10.0920 10.1310 9.5470 Du 1 HOH1 0.0000 + 5 O 9.3160 9.6510 9.8370 O.3 1 HOH1 0.0000 + 6 H 10.7430 9.4520 9.3730 H 1 HOH1 0.0000 + 7 H 11.8930 8.2960 7.5330 H 2 UNK2 0.0000 + 8 M 10.0750 9.9730 9.5630 Du 1 HOH1 0.0000 +@BOND + 1 1 2 1 + 2 1 3 1 + 3 4 5 1 + 4 5 8 1 + 5 6 8 1 diff --git a/src/tests/files/opc/dimer.pdb b/src/tests/files/opc/dimer.pdb new file mode 100644 index 000000000..fbd275150 --- /dev/null +++ b/src/tests/files/opc/dimer.pdb @@ -0,0 +1,13 @@ +REMARK 1 CREATED WITH OPENMM 8.1, 2023-12-05 +CRYST1 93.604 93.604 93.604 90.00 90.00 90.00 P 1 1 +HETATM 1 O HOH A 1 11.960 8.353 7.666 1.00 0.00 O +HETATM 2 H1 HOH A 1 11.880 7.439 7.396 1.00 0.00 H +HETATM 3 H2 HOH A 1 11.547 8.850 6.960 1.00 0.00 H +HETATM 4 EP HOH A 1 10.092 10.131 9.547 1.00 0.00 EP +TER 5 HOH A 1 +HETATM 6 O HOH B 1 9.316 9.651 9.837 1.00 0.00 O +HETATM 7 H1 HOH B 1 10.743 9.452 9.373 1.00 0.00 H +HETATM 8 H2 HOH B 1 11.893 8.296 7.533 1.00 0.00 H +HETATM 9 EP HOH B 1 10.075 9.973 9.563 1.00 0.00 EP +TER 10 HOH B 1 +END diff --git a/src/tests/files/opc/opc.offxml b/src/tests/files/opc/opc.offxml new file mode 100644 index 000000000..e58775b71 --- /dev/null +++ b/src/tests/files/opc/opc.offxml @@ -0,0 +1,139 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/tests/test_smirnoffio.py b/src/tests/test_smirnoffio.py new file mode 100644 index 000000000..0beb46995 --- /dev/null +++ b/src/tests/test_smirnoffio.py @@ -0,0 +1,59 @@ +import os + +import pytest + +from forcebalance.nifty import logger +from forcebalance.smirnoffio import SMIRNOFF + +from .__init__ import ForceBalanceTestCase + + +class TestSMIRNOFF(ForceBalanceTestCase): + """Test behavior of SMIRNOFF class""" + + @classmethod + def setup_class(cls): + """ + setup any state specific to the execution of the given class (which usually contains tests). + """ + super(TestSMIRNOFF, cls).setup_class() + + cls.cwd = os.path.dirname(os.path.realpath(__file__)) + + os.chdir(os.path.join(cls.cwd, "files", "opc")) + cls.tmpfolder = os.path.join(cls.cwd, "files", "opc", "temp") + + if not os.path.exists(cls.tmpfolder): + os.makedirs(cls.tmpfolder) + + os.chdir(cls.tmpfolder) + + pytest.importorskip("openeye.oechem") + from openeye import oechem + + if not oechem.OEChemIsLicensed(): + pytest.skip("Need OEChem license to run this test") + for file in ["dimer.pdb", "dimer.mol2", "opc.offxml"]: + os.system(f"ln -fs ../{file}") + + cls.engines = dict() + + try: + import openmm # noqa + + cls.engines["SMIRNOFF"] = SMIRNOFF( + mol2=["dimer.mol2"], + coords="dimer.pdb", + ffxml="opc.offxml", + platname="Reference", + precision="double", + ) + + except ModuleNotFoundError: + logger.warning("OpenMM cannot be imported, skipping OpenMM tests.") + + def test_energy_with_virtual_sites(self): + + data = {name: eng.energy_force() for name, eng in self.engines.items()} + + assert data['SMIRNOFF'] is not None \ No newline at end of file