Skip to content

Commit

Permalink
added stem method
Browse files Browse the repository at this point in the history
  • Loading branch information
Javier Sanchez authored and Javier Sanchez committed Jul 9, 2024
1 parent bd2c6e0 commit 9fbc5c2
Showing 1 changed file with 71 additions and 36 deletions.
107 changes: 71 additions & 36 deletions augur/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,54 +140,75 @@ def f(self, x, labels, pars_fid, sys_fid):
f_out : np.ndarray
Theory vector computed at x.
"""

if len(labels) != len(x):
raise ValueError('The labels should have the same length as the parameters!')
if hasattr(x, "__len__"):
if len(labels) != len(x):
raise ValueError('The labels should have the same length as the parameters!')
else:
if isinstance(x, list):
x = np.array(x)
if x.ndim == 1:
if isinstance(labels, list):
raise ValueError('x is a scalar and labels has more than one entry')

if isinstance(x, list):
x = np.array(x)
# Scalar variable
if isinstance(x, (float, int)):
_pars = pars_fid.copy()
_sys_pars = sys_fid.copy()
if labels in pars_fid.keys():
_pars.update({labels: x})
elif labels in sys_fid.keys():
_sys_pars.update({labels: x})
self.tools.reset()
self.lk.reset()
pmap = ParamsMap(_sys_pars)
cosmo = ccl.Cosmology(**_pars)
self.lk.update(pmap)
self.tools.update(pmap)
self.tools.prepare(cosmo)
f_out = self.lk.compute_theory_vector(self.tools)
# 1D
if x.ndim == 1:
_pars = pars_fid.copy()
_sys_pars = sys_fid.copy()
for i in range(len(labels)):
if labels[i] in pars_fid.keys():
_pars.update({labels[i]: x[i]})
elif labels[i] in sys_fid.keys():
_sys_pars.update({labels[i]: x[i]})
else:
raise ValueError(f'Parameter name {labels[i]} not recognized!')
self.tools.reset()
self.lk.reset()
pmap = ParamsMap(_sys_pars)
cosmo = ccl.Cosmology(**_pars)
self.lk.update(pmap)
self.tools.update(pmap)
self.tools.prepare(cosmo)
f_out = self.lk.compute_theory_vector(self.tools)
# 2D
elif x.ndim == 2:
f_out = []
for i in range(len(labels)):
_pars = pars_fid.copy()
_sys_pars = sys_fid.copy()
for i in range(len(labels)):
if labels[i] in pars_fid.keys():
_pars.update({labels[i]: x[i]})
elif labels[i] in sys_fid.keys():
_sys_pars.update({labels[i]: x[i]})
xi = x[i]
for j in range(len(labels)):
if labels[j] in pars_fid.keys():
_pars.update({labels[j]: xi[j]})
elif labels[j] in sys_fid.keys():
_sys_pars.update({labels[j]: xi[j]})
else:
raise ValueError(f'Parameter name {labels[i]} not recognized!')
raise ValueError(f'Parameter name {labels[j]} not recognized')
self.tools.reset()
self.lk.reset()
pmap = ParamsMap(_sys_pars)
cosmo = ccl.Cosmology(**_pars)
self.lk.update(pmap)
self.tools.update(pmap)
self.tools.prepare(cosmo)
f_out = self.lk.compute_theory_vector(self.tools)
elif x.ndim == 2:
f_out = []
for i in range(len(labels)):
_pars = pars_fid.copy()
_sys_pars = sys_fid.copy()
xi = x[i]
for j in range(len(labels)):
if labels[j] in pars_fid.keys():
_pars.update({labels[j]: xi[j]})
elif labels[j] in sys_fid.keys():
_sys_pars.update({labels[j]: xi[j]})
else:
raise ValueError(f'Parameter name {labels[j]} not recognized')
self.tools.reset()
self.lk.reset()
pmap = ParamsMap(_sys_pars)
cosmo = ccl.Cosmology(**_pars)
self.lk.update(pmap)
self.tools.update(pmap)
self.tools.prepare(cosmo)
f_out.append(self.lk.compute_theory_vector(self.tools))
return np.array(f_out)
f_out.append(self.lk.compute_theory_vector(self.tools))
return np.array(f_out)

def get_derivatives(self, force=False, method='5pt_stencil'):
def get_derivatives(self, force=False, method='stem'):
# Compute the derivatives with respect to the parameters in var_pars at x
if (self.derivatives is None) or (force):
if '5pt_stencil' in method:
Expand All @@ -204,6 +225,20 @@ def get_derivatives(self, force=False, method='5pt_stencil'):
self.req_params),
step=float(self.config['step']),
**ndkwargs)(self.x).T
elif 'stem' in method:
from derivative_calculator import DerivativeCalculator
_aux_ders = np.zeros((len(self.x), len(self.data_fid)))
for i in range(len(self.x)):
_aux = np.zeros((len(self.data_fid)))
for j in range(len(self.data_fid)):
calc = DerivativeCalculator(lambda y: self.f(y, self.var_pars[i],
self.pars_fid,
self.req_params)[j],
self.x[i],
dx=float(self.config['step']))
_aux[j] = calc.stem_method()
_aux_ders[i] = _aux
self.derivatives = np.array(_aux_ders)
else:
raise ValueError(f'Selected method: `{method}` is not available. \
Please select 5pt_stencil or numdifftools.')
Expand Down

0 comments on commit 9fbc5c2

Please sign in to comment.