diff --git a/src/hessian.py b/src/hessian.py index 5d093fac2..7cf400b86 100644 --- a/src/hessian.py +++ b/src/hessian.py @@ -23,11 +23,6 @@ from forcebalance.output import getLogger from forcebalance.optimizer import Counter from forcebalance.vibration import read_reference_vdata, vib_overlap -from geometric.internal import PrimitiveInternalCoordinates, Distance, Angle, Dihedral, OutOfPlane - -import matplotlib.pyplot as plt -import matplotlib.colors as colors -from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size import copy logger = getLogger(__name__) @@ -39,7 +34,7 @@ class Hessian(Target): def __init__(self,options,tgt_opts,forcefield): """Initialization.""" - + # Initialize the SuperClass! super(Hessian,self).__init__(options,tgt_opts,forcefield) #======================================# @@ -73,11 +68,14 @@ def __init__(self,options,tgt_opts,forcefield): self.denom = 1 def _build_internal_coordinates(self): + + from geometric.internal import PrimitiveInternalCoordinates + m = Molecule(os.path.join(self.tgtdir, "input.mol2")) IC = PrimitiveInternalCoordinates(m) self.IC = IC - def read_reference_data(self): # HJ: copied from vibration.py and modified + def read_reference_data(self): # HJ: copied from vibration.py and modified """ Read the reference hessian data from a file. """ self.ref_Hq_flat = np.loadtxt(self.hfnm) Hq_size =int(np.sqrt(len(self.ref_Hq_flat))) @@ -89,15 +87,16 @@ def read_reference_data(self): # HJ: copied from vibration.py and modified return def get_wts(self): - + + from geometric.internal import Distance, Angle, Dihedral nb = len([ic for ic in self.IC.Internals if isinstance(ic,Distance) ]) nba = nb + len([ic for ic in self.IC.Internals if isinstance(ic,Angle) ]) nbap = nba + len([ic for ic in self.IC.Internals if isinstance(ic,Dihedral) ]) Hq_size =int(np.sqrt(len(self.ref_Hq_flat))) - if self.hess_normalize_type == 0 : + if self.hess_normalize_type == 0 : self.wts = np.ones(len(self.ref_Hq_flat)) - else: + else: raise NotImplementedError # normalize weights self.wts /= np.sum(self.wts) @@ -113,7 +112,7 @@ def indicate(self): def hessian_driver(self): if hasattr(self, 'engine') and hasattr(self.engine, 'normal_modes'): - if self.optimize_geometry == 1: + if self.optimize_geometry == 1: return self.engine.normal_modes(for_hessian_target=True) else: return self.engine.normal_modes(optimize=False, for_hessian_target=True) @@ -127,10 +126,11 @@ def converting_to_int_vec(self, xyz, dx): dq = multi_dot([Bmat,dx]) return dq - def calc_int_normal_mode(self, xyz, cart_normal_mode): + def calc_int_normal_mode(self, xyz, cart_normal_mode): + from geometric.internal import Distance, Angle, Dihedral, OutOfPlane ninternals_eff= len([ic for ic in self.IC.Internals if isinstance(ic,(Distance, Angle, Dihedral, OutOfPlane))]) int_normal_mode = [] - for idx, vec in enumerate(cart_normal_mode): + for idx, vec in enumerate(cart_normal_mode): # convert cartesian coordinates displacement to internal coordinates dq = self.converting_to_int_vec(xyz, vec) int_normal_mode.append(dq[:ninternals_eff]) # disregard Translations and Rotations @@ -141,14 +141,14 @@ def get(self, mvals, AGrad=False, AHess=False): Answer = {'X':0.0, 'G':np.zeros(self.FF.np), 'H':np.zeros((self.FF.np, self.FF.np))} def compute(mvals_): self.FF.make(mvals_) - Xx, Gx, Hx, freqs, normal_modes, M_opt = self.hessian_driver() - # convert into internal hessian + Xx, Gx, Hx, freqs, normal_modes, M_opt = self.hessian_driver() + # convert into internal hessian Xx *= 1/ Bohr2nm Gx *= Bohr2nm/ Hartree2kJmol Hx *= 
Bohr2nm**2/ Hartree2kJmol Hq = self.IC.calcHess(Xx, Gx, Hx) compute.Hq_flat = Hq.flatten() - compute.freqs = freqs + compute.freqs = freqs compute.normal_modes = normal_modes compute.M_opt = M_opt diff = Hq - self.ref_Hq @@ -167,13 +167,13 @@ def compute(mvals_): Answer['G'][p] = 2*np.dot(V, dV[p,:]) * len(compute.freqs) for q in self.pgrad: Answer['H'][p,q] = 2*np.dot(dV[p,:], dV[q,:]) * len(compute.freqs) - + if not in_fd(): self.Hq_flat = compute.Hq_flat self.Hq = self.Hq_flat.reshape(self.ref_Hq.shape) self.objective = Answer['X'] self.FF.make(mvals) - + if self.writelevel > 0: # 1. write HessianCompare.txt hessian_comparison = np.array([ @@ -183,11 +183,11 @@ def compute(mvals_): np.sqrt(self.wts)/self.denom ]).T np.savetxt("HessianCompare.txt", hessian_comparison, header="%11s %12s %12s %12s" % ("QMHessian", "MMHessian", "Delta(MM-QM)", "Weight"), fmt="% 12.6e") - - # 2. rearrange MM vibrational frequencies using overlap between normal modes in redundant internal coordinates + + # 2. rearrange MM vibrational frequencies using overlap between normal modes in redundant internal coordinates ref_int_normal_modes = self.calc_int_normal_mode(self.ref_xyz, self.ref_eigvecs) int_normal_modes = self.calc_int_normal_mode(np.array(compute.M_opt.xyzs[0]), compute.normal_modes) - a = np.array([[(1.0-np.abs(np.dot(v1/np.linalg.norm(v1),v2/np.linalg.norm(v2)))) for v2 in int_normal_modes] for v1 in ref_int_normal_modes]) + a = np.array([[(1.0-np.abs(np.dot(v1/np.linalg.norm(v1),v2/np.linalg.norm(v2)))) for v2 in int_normal_modes] for v1 in ref_int_normal_modes]) row, c2r = optimize.linear_sum_assignment(a) # old arrangement method, which uses overlap between mass weighted vibrational modes in cartesian coordinates # a = np.array([[(1.0-self.vib_overlap(v1, v2)) for v2 in compute.normal_modes] for v1 in self.ref_eigvecs]) @@ -195,9 +195,9 @@ def compute(mvals_): freqs_rearr = compute.freqs[c2r] normal_modes_rearr = compute.normal_modes[c2r] - + # 3. 
Save rearranged frequencies and normal modes into a file for post-analysis - with open('mm_vdata.txt', 'w') as outfile: + with open('mm_vdata.txt', 'w') as outfile: outfile.writelines('%s\n' % line for line in compute.M_opt.write_xyz([0])) outfile.write('\n') for freq, normal_mode in zip(freqs_rearr, normal_modes_rearr): @@ -211,12 +211,12 @@ def compute(mvals_): draw_vibfreq_scatter_plot_n_overlap_matrix(self.name, self.engine, self.ref_eigvals, self.ref_eigvecs, freqs_rearr, normal_modes_rearr) return Answer -def cal_corr_coef(A): +def cal_corr_coef(A): # equations from https://math.stackexchange.com/a/1393907 size = len(A) j = np.ones(size) r = np.array(range(1,size+1)) - r2 = r*r + r2 = r*r n = np.dot(np.dot(j, A),j.T) sumx=np.dot(np.dot(r, A),j.T) sumy=np.dot(np.dot(j, A),r.T) @@ -226,7 +226,10 @@ def cal_corr_coef(A): r = (n*sumxy - sumx*sumy)/(np.sqrt(n*sumx2 - (sumx)**2)* np.sqrt(n*sumy2 - (sumy)**2)) return r -def draw_normal_modes(elem, ref_xyz, ref_eigvals, ref_eigvecs, mm_xyz, freqs_rearr, normal_modes_rearr): +def draw_normal_modes(elem, ref_xyz, ref_eigvals, ref_eigvecs, mm_xyz, freqs_rearr, normal_modes_rearr): + + import matplotlib.pyplot as plt + # draw qm and mm normal mode overlay fig, axs = plt.subplots(len(normal_modes_rearr), 1, figsize=(4, 4*len(normal_modes_rearr)), subplot_kw={'projection':'3d'}) def render_normal_modes(elem, xyz, eigvecs, color, qm=False, ref_eigvals=None, eigvals_rearr=None): @@ -235,10 +238,10 @@ def render_normal_modes(elem, xyz, eigvecs, color, qm=False, ref_eigvals=None, e u, v, w = eigvec.T *5 origin = np.array([x, y, z]) axs[idx].quiver(*origin, u, v, w, color=color) - + axs[idx].set_xlabel('x') axs[idx].set_ylabel('y') - axs[idx].set_zlabel('z') + axs[idx].set_zlabel('z') if qm: axs[idx].set_title(f'normal mode #{idx} (blue:QM({ref_eigvals[idx]:.2f}), red:MM({eigvals_rearr[idx]:.2f}))') axs[idx].scatter(x, y, z, color='black', s=30) @@ -250,11 +253,15 @@ def render_normal_modes(elem, xyz, eigvecs, color, qm=False, ref_eigvals=None, e render_normal_modes(elem, ref_xyz, ref_eigvecs, 'blue', qm=True, ref_eigvals=ref_eigvals, eigvals_rearr=freqs_rearr) render_normal_modes(elem, np.array(mm_xyz), normal_modes_rearr, 'red') - + plt.tight_layout() - plt.savefig('mm_vdata.pdf') + plt.savefig('mm_vdata.pdf') def draw_vibfreq_scatter_plot_n_overlap_matrix(name, engine, ref_eigvals, ref_eigvecs, freqs_rearr, normal_modes_rearr): + + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size + plt.switch_backend('agg') fig, axs = plt.subplots(1,2, figsize=(10,6)) overlap_matrix = np.array([[(vib_overlap(engine, v1, v2)) for v2 in normal_modes_rearr] for v1 in ref_eigvecs]) @@ -265,15 +272,15 @@ def draw_vibfreq_scatter_plot_n_overlap_matrix(name, engine, ref_eigvals, ref_ei axs[0].legend() axs[0].set_xlabel(r'QM vibrational frequency ($cm^{-1}$)') axs[0].set_ylabel(r'MM vibrational frequency ($cm^{-1}$)') - mae = np.sum(np.abs(ref_eigvals - freqs_rearr))/ len(ref_eigvals) + mae = np.sum(np.abs(ref_eigvals - freqs_rearr))/ len(ref_eigvals) axs[0].set_title(f'QM vs. 
MM vibrational frequencies\n MAE= {mae:.2f}') x0,x1 = axs[0].get_xlim() y0,y1 = axs[0].get_ylim() axs[0].set_aspect((x1-x0)/(y1-y0)) - # move ax x axis to top - axs[1].xaxis.tick_top() - # move ax x ticks inside + # move ax x axis to top + axs[1].xaxis.tick_top() + # move ax x ticks inside axs[1].tick_params(axis="y", direction='in') axs[1].tick_params(axis="x", direction='in') # draw matrix @@ -286,15 +293,15 @@ def draw_vibfreq_scatter_plot_n_overlap_matrix(name, engine, ref_eigvals, ref_ei pad = axes_size.Fraction(pad_fraction, width) cax = divider.append_axes("right", size=width, pad=pad) cax.yaxis.tick_right() - cax.xaxis.set_visible(False) - plt.colorbar(im, cax=cax) + cax.xaxis.set_visible(False) + plt.colorbar(im, cax=cax) corr_coef = cal_corr_coef(overlap_matrix) err = np.linalg.norm(qm_overlap_matrix - overlap_matrix)/np.linalg.norm(qm_overlap_matrix) # measure of error in matrix (Relative error) axs[1].set_title(f'QM vs. MM normal modes\n Correlation coef. ={corr_coef:.4f}, Error={err:.4f}') - # # move ax x axis to top - # axs[2].xaxis.tick_top() - # # move ax x ticks inside + # # move ax x axis to top + # axs[2].xaxis.tick_top() + # # move ax x ticks inside # axs[2].tick_params(axis="y", direction='in') # axs[2].tick_params(axis="x", direction='in') # # draw matrix @@ -307,11 +314,11 @@ def draw_vibfreq_scatter_plot_n_overlap_matrix(name, engine, ref_eigvals, ref_ei # pad = axes_size.Fraction(pad_fraction, width) # cax = divider.append_axes("right", size=width, pad=pad) # cax.yaxis.tick_right() - # cax.xaxis.set_visible(False) - # plt.colorbar(im, cax=cax) + # cax.xaxis.set_visible(False) + # plt.colorbar(im, cax=cax) # axs[2].set_title(f'(QM normal modes for reference)') - plt.tight_layout() + plt.tight_layout() plt.subplots_adjust(top=0.85) fig.suptitle('Hessian: iteration %i\nSystem: %s' % (Counter(), name)) fig.savefig('vibfreq_scatter_plot_n_overlap_matrix.pdf') \ No newline at end of file diff --git a/src/smirnoff_hack.py b/src/smirnoff_hack.py index 83bbc77cf..88be607c2 100644 --- a/src/smirnoff_hack.py +++ b/src/smirnoff_hack.py @@ -1,120 +1,158 @@ ## HACK: Improve the performance of the openff forcefield.create_openmm_system() +import os from openff.toolkit.utils.toolkits import OpenEyeToolkitWrapper, RDKitToolkitWrapper, AmberToolsToolkitWrapper from openff.toolkit.topology.molecule import Molecule -# time based on total 540s evaluation -# cache for OE find_smarts_matches (save 300+ s) -oe_original_find_smarts_matches = OpenEyeToolkitWrapper.find_smarts_matches -OE_TOOLKIT_CACHE_find_smarts_matches = {} -def oe_cached_find_smarts_matches(self, molecule, smarts, aromaticity_model='OEAroModel_MDL'): - cache_key = hash((molecule, smarts, aromaticity_model)) - if cache_key not in OE_TOOLKIT_CACHE_find_smarts_matches: - OE_TOOLKIT_CACHE_find_smarts_matches[cache_key] = oe_original_find_smarts_matches(self, molecule, smarts, aromaticity_model=aromaticity_model) - return OE_TOOLKIT_CACHE_find_smarts_matches[cache_key] -# replace the original function with new one -OpenEyeToolkitWrapper.find_smarts_matches = oe_cached_find_smarts_matches - -# cache for RDK find_smarts_matches -rdk_original_find_smarts_matches = RDKitToolkitWrapper.find_smarts_matches -RDK_TOOLKIT_CACHE_find_smarts_matches = {} -def rdk_cached_find_smarts_matches(self, molecule, smarts, aromaticity_model='OEAroModel_MDL'): - cache_key = hash((molecule, smarts, aromaticity_model)) - if cache_key not in RDK_TOOLKIT_CACHE_find_smarts_matches: - RDK_TOOLKIT_CACHE_find_smarts_matches[cache_key] = 
rdk_original_find_smarts_matches(self, molecule, smarts, aromaticity_model=aromaticity_model) - return RDK_TOOLKIT_CACHE_find_smarts_matches[cache_key] -# replace the original function with new one -RDKitToolkitWrapper.find_smarts_matches = rdk_cached_find_smarts_matches - - -# cache for the validate function (save 94s) -from openff.toolkit.typing.chemistry.environment import ChemicalEnvironment -original_validate = ChemicalEnvironment.validate -TOOLKIT_CACHE_ChemicalEnvironment_validate = {} -def cached_validate(smirks, validate_valence_type=True, toolkit_registry=OpenEyeToolkitWrapper): - cache_key = hash((smirks, validate_valence_type, toolkit_registry)) - if cache_key not in TOOLKIT_CACHE_ChemicalEnvironment_validate: - TOOLKIT_CACHE_ChemicalEnvironment_validate[cache_key] = original_validate(smirks, validate_valence_type=validate_valence_type, toolkit_registry=toolkit_registry) - return TOOLKIT_CACHE_ChemicalEnvironment_validate[cache_key] -ChemicalEnvironment.validate = cached_validate - - -# cache for compute_partial_charges_am1bcc (save 69s) -# No longer needed as of 0.7.0 since all partial charge assignment is routed through ToolkitWrapper.assign_partial_charges -# original_compute_partial_charges_am1bcc = OpenEyeToolkitWrapper.compute_partial_charges_am1bcc -# TOOLKIT_CACHE_compute_partial_charges_am1bcc = {} -# def cached_compute_partial_charges_am1bcc(self, molecule, use_conformers=None, strict_n_conformers=False): -# cache_key = hash(molecule, use_conformers, strict_n_conformers) -# if cache_key not in TOOLKIT_CACHE_compute_partial_charges_am1bcc: -# TOOLKIT_CACHE_compute_partial_charges_am1bcc[cache_key] = original_compute_partial_charges_am1bcc(self, molecule, use_conformers=use_conformers, strict_n_conformers=strict_n_conformers) -# return TOOLKIT_CACHE_compute_partial_charges_am1bcc[cache_key] -# OpenEyeToolkitWrapper.compute_partial_charges_am1bcc = cached_compute_partial_charges_am1bcc - - -# Cache for OETK assign_partial_charges -oe_original_assign_partial_charges = OpenEyeToolkitWrapper.assign_partial_charges -OE_TOOLKIT_CACHE_assign_partial_charges = {} -def oe_cached_assign_partial_charges(self, molecule, partial_charge_method=None, use_conformers=None, strict_n_conformers=False, _cls=Molecule): - cache_key = hash((molecule, partial_charge_method, str(use_conformers), strict_n_conformers)) - if cache_key not in OE_TOOLKIT_CACHE_assign_partial_charges: - oe_original_assign_partial_charges(self, molecule, partial_charge_method=partial_charge_method, use_conformers=use_conformers, strict_n_conformers=strict_n_conformers, _cls=_cls) - OE_TOOLKIT_CACHE_assign_partial_charges[cache_key] = molecule.partial_charges - else: - molecule.partial_charges = OE_TOOLKIT_CACHE_assign_partial_charges[cache_key] - return -OpenEyeToolkitWrapper.assign_partial_charges = oe_cached_assign_partial_charges - - -# Cache for AmberTools assign_partial_charges -at_original_assign_partial_charges = AmberToolsToolkitWrapper.assign_partial_charges -AT_TOOLKIT_CACHE_assign_partial_charges = {} -def at_cached_assign_partial_charges(self, molecule, partial_charge_method=None, use_conformers=None, strict_n_conformers=False, _cls=Molecule): - cache_key = hash((molecule, partial_charge_method, str(use_conformers), strict_n_conformers)) - if cache_key not in AT_TOOLKIT_CACHE_assign_partial_charges: - at_original_assign_partial_charges(self, molecule, partial_charge_method=partial_charge_method, use_conformers=use_conformers, strict_n_conformers=strict_n_conformers, _cls=_cls) - 
AT_TOOLKIT_CACHE_assign_partial_charges[cache_key] = molecule.partial_charges - else: - molecule.partial_charges = AT_TOOLKIT_CACHE_assign_partial_charges[cache_key] - return -AmberToolsToolkitWrapper.assign_partial_charges = at_cached_assign_partial_charges - - -# cache the OE generate_conformers function (save 15s) -OE_TOOLKIT_CACHE_molecule_conformers = {} -oe_original_generate_conformers = OpenEyeToolkitWrapper.generate_conformers -def oe_cached_generate_conformers(self, molecule, n_conformers=1, rms_cutoff=None, clear_existing=True): - cache_key = hash((molecule, n_conformers, str(rms_cutoff), clear_existing)) - if cache_key not in OE_TOOLKIT_CACHE_molecule_conformers: - oe_original_generate_conformers(self, molecule, n_conformers=n_conformers, rms_cutoff=rms_cutoff, clear_existing=clear_existing) - OE_TOOLKIT_CACHE_molecule_conformers[cache_key] = molecule._conformers - molecule._conformers = OE_TOOLKIT_CACHE_molecule_conformers[cache_key] -OpenEyeToolkitWrapper.generate_conformers = oe_cached_generate_conformers - - -# cache the RDKit generate_conformers function -RDK_TOOLKIT_CACHE_molecule_conformers = {} -rdk_original_generate_conformers = RDKitToolkitWrapper.generate_conformers -def rdk_cached_generate_conformers(self, molecule, n_conformers=1, rms_cutoff=None, clear_existing=True): - cache_key = hash((molecule, n_conformers, str(rms_cutoff), clear_existing)) - if cache_key not in RDK_TOOLKIT_CACHE_molecule_conformers: - rdk_original_generate_conformers(self, molecule, n_conformers=n_conformers, rms_cutoff=rms_cutoff, clear_existing=clear_existing) - RDK_TOOLKIT_CACHE_molecule_conformers[cache_key] = molecule._conformers - molecule._conformers = RDK_TOOLKIT_CACHE_molecule_conformers[cache_key] -RDKitToolkitWrapper.generate_conformers = rdk_cached_generate_conformers - - -# final timing: 56s - -# cache the ForceField creation (no longer needed since using OpenFF API for parameter modifications) - -# import hashlib -# from openff.toolkit.typing.engines.smirnoff import ForceField -# SMIRNOFF_FORCE_FIELD_CACHE = {} -# def getForceField(*ffpaths): -# hasher = hashlib.md5() -# for path in ffpaths: -# with open(path, 'rb') as f: -# hasher.update(f.read()) -# cache_key = hasher.hexdigest() -# if cache_key not in SMIRNOFF_FORCE_FIELD_CACHE: -# SMIRNOFF_FORCE_FIELD_CACHE[cache_key] = ForceField(*ffpaths, allow_cosmetic_attributes=True) -# return SMIRNOFF_FORCE_FIELD_CACHE[cache_key] + +# Add a mechanism for disabling SMIRNOFF hack entirely as it is prone to breaking +# when upstream dependencies (especially the toolkit) are updated. +_SHOULD_CACHE = os.environ.get("ENABLE_FB_SMIRNOFF_CACHING") + +# Caching of SMIRNOFF functions is enabled by default, including when the +# "ENABLE_FB_SMIRNOFF_CACHING" environmental variable is not set. The user +# Can `export ENABLE_FB_SMIRNOFF_CACHING=False` to disable all caching. +if _SHOULD_CACHE is None: + _SHOULD_CACHE = True +else: + _SHOULD_CACHE = _SHOULD_CACHE.lower() in ["true", "1", "yes"] + + +def hash_molecule(molecule): + + atom_map = None + + if "atom_map" in molecule.properties: + # Store a copy of any existing atom map + atom_map = molecule.properties.pop("atom_map") + + cmiles = molecule.to_smiles(mapped=True) + + if atom_map is not None: + molecule.properties["atom_map"] = atom_map + + return cmiles + + +if _SHOULD_CACHE: + + print( + "SMIRNOFF functions will be replaced with cached versions to improve their " + "performance." 
+ ) + + # time based on total 540s evaluation + # cache for OE find_smarts_matches (save 300+ s) + oe_original_find_smarts_matches = OpenEyeToolkitWrapper.find_smarts_matches + OE_TOOLKIT_CACHE_find_smarts_matches = {} + def oe_cached_find_smarts_matches(self, molecule, smarts, aromaticity_model='OEAroModel_MDL'): + cache_key = hash((hash_molecule(molecule), smarts, aromaticity_model)) + if cache_key not in OE_TOOLKIT_CACHE_find_smarts_matches: + OE_TOOLKIT_CACHE_find_smarts_matches[cache_key] = oe_original_find_smarts_matches(self, molecule, smarts, aromaticity_model=aromaticity_model) + return OE_TOOLKIT_CACHE_find_smarts_matches[cache_key] + # replace the original function with new one + OpenEyeToolkitWrapper.find_smarts_matches = oe_cached_find_smarts_matches + + # cache for RDK find_smarts_matches + rdk_original_find_smarts_matches = RDKitToolkitWrapper.find_smarts_matches + RDK_TOOLKIT_CACHE_find_smarts_matches = {} + def rdk_cached_find_smarts_matches(self, molecule, smarts, aromaticity_model='OEAroModel_MDL'): + cache_key = hash((hash_molecule(molecule), smarts, aromaticity_model)) + if cache_key not in RDK_TOOLKIT_CACHE_find_smarts_matches: + RDK_TOOLKIT_CACHE_find_smarts_matches[cache_key] = rdk_original_find_smarts_matches(self, molecule, smarts, aromaticity_model=aromaticity_model) + return RDK_TOOLKIT_CACHE_find_smarts_matches[cache_key] + # replace the original function with new one + RDKitToolkitWrapper.find_smarts_matches = rdk_cached_find_smarts_matches + + + # cache for the validate function (save 94s) + from openff.toolkit.typing.chemistry.environment import ChemicalEnvironment + original_validate = ChemicalEnvironment.validate + TOOLKIT_CACHE_ChemicalEnvironment_validate = {} + def cached_validate(smirks, validate_valence_type=True, toolkit_registry=OpenEyeToolkitWrapper): + cache_key = hash((smirks, validate_valence_type, toolkit_registry)) + if cache_key not in TOOLKIT_CACHE_ChemicalEnvironment_validate: + TOOLKIT_CACHE_ChemicalEnvironment_validate[cache_key] = original_validate(smirks, validate_valence_type=validate_valence_type, toolkit_registry=toolkit_registry) + return TOOLKIT_CACHE_ChemicalEnvironment_validate[cache_key] + ChemicalEnvironment.validate = cached_validate + + + # cache for compute_partial_charges_am1bcc (save 69s) + # No longer needed as of 0.7.0 since all partial charge assignment is routed through ToolkitWrapper.assign_partial_charges + # original_compute_partial_charges_am1bcc = OpenEyeToolkitWrapper.compute_partial_charges_am1bcc + # TOOLKIT_CACHE_compute_partial_charges_am1bcc = {} + # def cached_compute_partial_charges_am1bcc(self, molecule, use_conformers=None, strict_n_conformers=False): + # cache_key = hash(molecule, use_conformers, strict_n_conformers) + # if cache_key not in TOOLKIT_CACHE_compute_partial_charges_am1bcc: + # TOOLKIT_CACHE_compute_partial_charges_am1bcc[cache_key] = original_compute_partial_charges_am1bcc(self, molecule, use_conformers=use_conformers, strict_n_conformers=strict_n_conformers) + # return TOOLKIT_CACHE_compute_partial_charges_am1bcc[cache_key] + # OpenEyeToolkitWrapper.compute_partial_charges_am1bcc = cached_compute_partial_charges_am1bcc + + + # Cache for OETK assign_partial_charges + oe_original_assign_partial_charges = OpenEyeToolkitWrapper.assign_partial_charges + OE_TOOLKIT_CACHE_assign_partial_charges = {} + def oe_cached_assign_partial_charges(self, molecule, partial_charge_method=None, use_conformers=None, strict_n_conformers=False, _cls=Molecule): + cache_key = hash((hash_molecule(molecule), 
partial_charge_method, str(use_conformers), strict_n_conformers)) + if cache_key not in OE_TOOLKIT_CACHE_assign_partial_charges: + oe_original_assign_partial_charges(self, molecule, partial_charge_method=partial_charge_method, use_conformers=use_conformers, strict_n_conformers=strict_n_conformers, _cls=_cls) + OE_TOOLKIT_CACHE_assign_partial_charges[cache_key] = molecule.partial_charges + else: + molecule.partial_charges = OE_TOOLKIT_CACHE_assign_partial_charges[cache_key] + return + OpenEyeToolkitWrapper.assign_partial_charges = oe_cached_assign_partial_charges + + + # Cache for AmberTools assign_partial_charges + at_original_assign_partial_charges = AmberToolsToolkitWrapper.assign_partial_charges + AT_TOOLKIT_CACHE_assign_partial_charges = {} + def at_cached_assign_partial_charges(self, molecule, partial_charge_method=None, use_conformers=None, strict_n_conformers=False, _cls=Molecule): + cache_key = hash((hash_molecule(molecule), partial_charge_method, str(use_conformers), strict_n_conformers)) + if cache_key not in AT_TOOLKIT_CACHE_assign_partial_charges: + at_original_assign_partial_charges(self, molecule, partial_charge_method=partial_charge_method, use_conformers=use_conformers, strict_n_conformers=strict_n_conformers, _cls=_cls) + AT_TOOLKIT_CACHE_assign_partial_charges[cache_key] = molecule.partial_charges + else: + molecule.partial_charges = AT_TOOLKIT_CACHE_assign_partial_charges[cache_key] + return + AmberToolsToolkitWrapper.assign_partial_charges = at_cached_assign_partial_charges + + + # cache the OE generate_conformers function (save 15s) + OE_TOOLKIT_CACHE_molecule_conformers = {} + oe_original_generate_conformers = OpenEyeToolkitWrapper.generate_conformers + def oe_cached_generate_conformers(self, molecule, n_conformers=1, rms_cutoff=None, clear_existing=True): + cache_key = hash((hash_molecule(molecule), n_conformers, str(rms_cutoff), clear_existing)) + if cache_key not in OE_TOOLKIT_CACHE_molecule_conformers: + oe_original_generate_conformers(self, molecule, n_conformers=n_conformers, rms_cutoff=rms_cutoff, clear_existing=clear_existing) + OE_TOOLKIT_CACHE_molecule_conformers[cache_key] = molecule._conformers + molecule._conformers = OE_TOOLKIT_CACHE_molecule_conformers[cache_key] + OpenEyeToolkitWrapper.generate_conformers = oe_cached_generate_conformers + + + # cache the RDKit generate_conformers function + RDK_TOOLKIT_CACHE_molecule_conformers = {} + rdk_original_generate_conformers = RDKitToolkitWrapper.generate_conformers + def rdk_cached_generate_conformers(self, molecule, n_conformers=1, rms_cutoff=None, clear_existing=True): + cache_key = hash((hash_molecule(molecule), n_conformers, str(rms_cutoff), clear_existing)) + if cache_key not in RDK_TOOLKIT_CACHE_molecule_conformers: + rdk_original_generate_conformers(self, molecule, n_conformers=n_conformers, rms_cutoff=rms_cutoff, clear_existing=clear_existing) + RDK_TOOLKIT_CACHE_molecule_conformers[cache_key] = molecule._conformers + molecule._conformers = RDK_TOOLKIT_CACHE_molecule_conformers[cache_key] + RDKitToolkitWrapper.generate_conformers = rdk_cached_generate_conformers + + + # final timing: 56s + + # cache the ForceField creation (no longer needed since using OpenFF API for parameter modifications) + + # import hashlib + # from openff.toolkit.typing.engines.smirnoff import ForceField + # SMIRNOFF_FORCE_FIELD_CACHE = {} + # def getForceField(*ffpaths): + # hasher = hashlib.md5() + # for path in ffpaths: + # with open(path, 'rb') as f: + # hasher.update(f.read()) + # cache_key = hasher.hexdigest() + # if 
cache_key not in SMIRNOFF_FORCE_FIELD_CACHE:
+    #         SMIRNOFF_FORCE_FIELD_CACHE[cache_key] = ForceField(*ffpaths, allow_cosmetic_attributes=True)
+    #     return SMIRNOFF_FORCE_FIELD_CACHE[cache_key]
diff --git a/src/smirnoffio.py b/src/smirnoffio.py
index 6c0da03b7..dab32d84e 100644
--- a/src/smirnoffio.py
+++ b/src/smirnoffio.py
@@ -528,9 +528,11 @@ def _update_positions(self, X1, disable_vsite):
             self.mod.getTopology().getNumAtoms() - self.pdb.topology.getNumAtoms()
         )

-        # Add placeholder positions for any v-sites.
-        sites = np.zeros((n_v_sites, 3))
-        X1 = np.vstack((X1, sites))
+        # Add placeholder positions for any v-sites.
+        if isinstance(X1, np.ndarray):
+            X1 = np.vstack([X1, np.zeros((n_v_sites, 3))])
+        else:
+            X1 = X1 + [Vec3(0.0, 0.0, 0.0)] * n_v_sites

         self.simulation.context.setPositions(X1 * angstrom)
         self.simulation.context.computeVirtualSites()
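Note on the frequency rearrangement in hessian.py above: each QM mode is paired with the MM mode it overlaps best by converting both sets of normal modes to internal-coordinate displacements and solving a linear assignment problem on the cost matrix 1 - |cosine overlap|, exactly as done with `optimize.linear_sum_assignment` in `get()`. A minimal, self-contained sketch of that matching step (random vectors stand in for the real internal-coordinate modes; the names here are illustrative and not part of the patch):

```python
import numpy as np
from scipy import optimize

def rearrange_modes(ref_modes, mm_modes):
    """Pair each reference (QM) mode with its best-overlapping MM mode.

    Both arrays have shape (n_modes, n_coords). The cost of pairing mode i
    with mode j is 1 - |cosine overlap|, and the Hungarian solver picks the
    permutation with the smallest total cost.
    """
    ref_hat = ref_modes / np.linalg.norm(ref_modes, axis=1, keepdims=True)
    mm_hat = mm_modes / np.linalg.norm(mm_modes, axis=1, keepdims=True)
    cost = 1.0 - np.abs(ref_hat @ mm_hat.T)
    row, c2r = optimize.linear_sum_assignment(cost)
    return c2r  # reorder the MM freqs/modes with this to follow the QM ordering

# Toy check: the "MM" modes are a shuffled, slightly noisy copy of the "QM" ones.
rng = np.random.default_rng(0)
qm = rng.normal(size=(5, 9))
perm = rng.permutation(5)
mm = qm[perm] + 0.01 * rng.normal(size=(5, 9))
c2r = rearrange_modes(qm, mm)
assert np.allclose(mm[c2r], qm, atol=0.1)  # rows are back in the QM order
```

In the patch itself, the returned column permutation (`c2r`) is applied to `compute.freqs` and `compute.normal_modes` before writing `mm_vdata.txt` and the comparison plots.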
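On `cal_corr_coef`: following the linked math.stackexchange answer, it treats each overlap-matrix entry A[i][j] as a weight on the index pair (i+1, j+1) and returns the Pearson correlation of those weighted indices, so an overlap matrix concentrated on the diagonal scores close to 1. A rough numerical cross-check of that interpretation (the discretization factor is arbitrary and only used for this sanity check):

```python
import numpy as np

def corr_coef_check(A, scale=1000):
    """Cross-check for cal_corr_coef: expand A[i, j] into round(A[i, j]*scale)
    copies of the sample (i+1, j+1) and take an ordinary Pearson correlation."""
    counts = np.rint(np.asarray(A) * scale).astype(int)
    xs, ys = [], []
    for i in range(counts.shape[0]):
        for j in range(counts.shape[1]):
            xs += [i + 1] * counts[i, j]
            ys += [j + 1] * counts[i, j]
    return np.corrcoef(xs, ys)[0, 1]

overlap = np.array([[0.9, 0.1, 0.0],
                    [0.1, 0.8, 0.2],
                    [0.0, 0.1, 0.9]])
print(corr_coef_check(overlap))  # ~0.88: the modes line up along the diagonal
```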
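The smirnoff_hack.py rewrite keeps the same monkey-patching strategy as before, but (a) keys every cache on a canonical mapped-SMILES string from `hash_molecule()` instead of `hash(molecule)`, and (b) lets the user opt out entirely with `export ENABLE_FB_SMIRNOFF_CACHING=False`. A stripped-down sketch of that pattern with a hypothetical toolkit class (no OpenFF dependency; all names below are placeholders, not the real wrappers):

```python
import os

# Same opt-out convention as the patch: caching is on unless the user disables it.
_SHOULD_CACHE = os.environ.get("ENABLE_FB_SMIRNOFF_CACHING", "true").lower() in ("true", "1", "yes")

class SlowToolkit:  # hypothetical stand-in for an OpenFF toolkit wrapper
    def find_matches(self, molecule, smarts):
        print("expensive call")  # imagine a slow SMARTS search here
        return [(0, 1)]

def molecule_key(molecule):
    # Stand-in for hash_molecule(): any canonical, conformer-independent
    # identifier works; the real code uses Molecule.to_smiles(mapped=True).
    return str(molecule)

if _SHOULD_CACHE:
    _original_find_matches = SlowToolkit.find_matches
    _FIND_MATCHES_CACHE = {}

    def _cached_find_matches(self, molecule, smarts):
        key = (molecule_key(molecule), smarts)
        if key not in _FIND_MATCHES_CACHE:
            _FIND_MATCHES_CACHE[key] = _original_find_matches(self, molecule, smarts)
        return _FIND_MATCHES_CACHE[key]

    # Replace the method on the class, exactly as smirnoff_hack.py does.
    SlowToolkit.find_matches = _cached_find_matches

tk = SlowToolkit()
tk.find_matches("CCO", "[#6:1]~[#8:2]")  # prints "expensive call"
tk.find_matches("CCO", "[#6:1]~[#8:2]")  # served from the cache, no print
```

In the real `hash_molecule()`, any existing `atom_map` entry is popped from `molecule.properties` before calling `to_smiles(mapped=True)` and restored afterwards, so a pre-existing atom map neither alters the generated mapped SMILES (and hence the cache key) nor gets lost.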