diff --git a/funsor/integrate.py b/funsor/integrate.py index ab6f98ad..2be76294 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -177,13 +177,23 @@ def eager_integrate(delta, integrand, reduced_vars): if reduced_vars.isdisjoint(delta_fresh): return None reduced_names = frozenset(v.name for v in reduced_vars) + new_integrand = integrand + new_log_measure = delta + + # reduced_vars that are in integrand.inputs are substituted in delta and integrand subs = tuple( (name, point) for name, (point, log_density) in delta.terms - if name in reduced_names + if name in reduced_names and name in integrand.inputs ) - new_integrand = Subs(integrand, subs) - new_log_measure = Subs(delta, subs) + if subs: + new_integrand = Subs(new_integrand, subs) + new_log_measure = Subs(new_log_measure, subs) + # reduced vars that are not in integrand.inputs are reduced over in delta + reduced_names = reduced_names.difference(integrand.inputs) + if reduced_names: + new_log_measure = new_log_measure.reduce(ops.logaddexp, reduced_names) + result = Integrate(new_log_measure, new_integrand, reduced_vars - delta_fresh) return result