Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provenance funsor #593

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/funsors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,11 @@ Constant
:undoc-members:
:show-inheritance:
:member-order: bysource

Provenance
----------
.. automodule:: funsor.provenance
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
4 changes: 4 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from funsor.integrate import Integrate
from funsor.interpreter import interpretation, reinterpret
from funsor.op_factory import make_op
from funsor.provenance import Provenance
from funsor.sum_product import MarkovProduct
from funsor.tensor import Tensor, function
from funsor.terms import (
Expand Down Expand Up @@ -47,6 +48,7 @@
montecarlo,
ops,
precondition,
provenance,
recipes,
sum_product,
terms,
Expand All @@ -71,6 +73,7 @@
"Number",
"Real",
"Reals",
"Provenance",
"Slice",
"Stack",
"Tensor",
Expand Down Expand Up @@ -105,6 +108,7 @@
"ops",
"precondition",
"pretty",
"provenance",
"quote",
"reals",
"recipes",
Expand Down
9 changes: 9 additions & 0 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from funsor.domains import Array, Real, Reals
from funsor.gaussian import Gaussian
from funsor.interpreter import gensym
from funsor.provenance import Provenance
from funsor.tensor import (
Tensor,
align_tensors,
Expand Down Expand Up @@ -458,6 +459,14 @@ def backenddist_to_funsor(
for param_name in funsor_dist_class._ast_fields
if param_name != "value"
]
provenance = frozenset().union(
*[param.provenance for param in params if isinstance(param, Provenance)]
)
if provenance:
params = [
param.term if isinstance(param, Provenance) else param for param in params
]
return Provenance(funsor_dist_class(*params), provenance)
return funsor_dist_class(*params)


Expand Down
2 changes: 2 additions & 0 deletions funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from funsor.delta import Delta
from funsor.gaussian import Gaussian, _norm2, _vm, align_gaussian
from funsor.interpretations import eager, normalize
from funsor.provenance import Provenance
from funsor.tensor import Tensor
from funsor.terms import (
Funsor,
Expand Down Expand Up @@ -139,6 +140,7 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars):
Tensor,
GaussianMixture,
EagerConstant,
Provenance,
),
)
def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs):
Expand Down
105 changes: 105 additions & 0 deletions funsor/provenance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict

import funsor.ops as ops
from funsor.tensor import Tensor
from funsor.terms import Binary, Funsor, FunsorMeta, Number, Unary, Variable, eager


class ProvenanceMeta(FunsorMeta):
"""
Wrapper to combine provenance information from the term.
"""

def __call__(cls, term, provenance):
while isinstance(term, Provenance):
provenance |= term.provenance
term = term.term

return super(ProvenanceMeta, cls).__call__(term, provenance)


class Provenance(Funsor, metaclass=ProvenanceMeta):
"""
Provenance funsor for tracking the dependence of terms on ``(name, point)``
of sampled random variables.

**References**

[1] David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind (2011)
Nonstandard Interpretations of Probabilistic Programs for Efficient Inference
http://papers.neurips.cc/paper/4309-nonstandard-interpretations-of-probabilistic-programs-for-efficient-inference.pdf

:param funsor term: A term that depends on tracked variables.
:param frozenset provenance: A set of tuples of the form ``(name, point)``.
"""

def __init__(self, term, provenance):
assert isinstance(term, Funsor)
assert isinstance(provenance, frozenset)

provenance_names = frozenset([name for name, point in provenance])
assert provenance_names.isdisjoint(term.inputs)
inputs = OrderedDict()
for name, point in provenance:
assert isinstance(name, str)
assert isinstance(point, Funsor)
assert name not in point.inputs
inputs.update({name: point.output})
inputs.update(point.inputs)

inputs.update(term.inputs)
output = term.output
fresh = provenance_names
bound = {}
super(Provenance, self).__init__(inputs, output, fresh, bound)
self.term = term
self.provenance = provenance

def eager_subs(self, subs):
assert isinstance(subs, tuple)
subs = OrderedDict(subs)
assert set(subs).issubset(self.fresh)
new_provenance = frozenset()
new_term = self.term
for name, point in self.provenance:
if name in subs:
value = subs[name]
if isinstance(value, Variable):
new_provenance |= frozenset([(value.name, point)])
continue

# leave out the substituted provenance variable
# make sure that the value matches the point
assert value is point
else:
new_provenance |= frozenset([(name, point)])
return Provenance(new_term, new_provenance) if new_provenance else new_term

def _sample(self, sampled_vars, sample_inputs, rng_key):
result = self.term._sample(sampled_vars, sample_inputs, rng_key)
return Provenance(result, self.provenance)


@eager.register(Binary, ops.BinaryOp, Provenance, Provenance)
def eager_binary_provenance_provenance(op, lhs, rhs):
return Provenance(op(lhs.term, rhs.term), lhs.provenance | rhs.provenance)


@eager.register(Binary, ops.BinaryOp, Provenance, (Number, Tensor))
def eager_binary_provenance_tensor(op, lhs, rhs):
assert lhs.fresh.isdisjoint(rhs.inputs)
return Provenance(op(lhs.term, rhs), lhs.provenance)


@eager.register(Binary, ops.BinaryOp, (Number, Tensor), Provenance)
def eager_binary_tensor_provenance(op, lhs, rhs):
assert rhs.fresh.isdisjoint(lhs.inputs)
return Provenance(op(lhs, rhs.term), rhs.provenance)


@eager.register(Unary, ops.UnaryOp, Provenance)
def eager_unary(op, arg):
return Provenance(op(arg.term), arg.provenance)
18 changes: 8 additions & 10 deletions funsor/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict

import torch
from multipledispatch import dispatch

from funsor.constant import Constant
from funsor.provenance import Provenance
from funsor.tensor import tensor_to_funsor
from funsor.terms import to_data, to_funsor
from funsor.torch.provenance import ProvenanceTensor
Expand All @@ -31,15 +29,15 @@ def _quote(x, indent, out):


@to_funsor.register(ProvenanceTensor)
def provenance_to_funsor(x, output=None, dim_to_name=None):
ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name)
return Constant(OrderedDict(x._provenance), ret)
def provenancetensor_to_funsor(x, output=None, dim_to_name=None):
term = to_funsor(x._t, output=output, dim_to_name=dim_to_name)
return Provenance(term, x._provenance)


@to_data.register(Constant)
def constant_to_data(x, name_to_dim=None):
data = to_data(x.arg, name_to_dim=name_to_dim)
return ProvenanceTensor(data, provenance=frozenset(x.const_inputs.items()))
@to_data.register(Provenance)
def provenance_to_data(x, name_to_dim=None):
data = to_data(x.term, name_to_dim=name_to_dim)
return ProvenanceTensor(data, provenance=x.provenance)


to_funsor.register(torch.Tensor)(tensor_to_funsor)
Expand Down
30 changes: 2 additions & 28 deletions test/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from funsor.delta import Delta
from funsor.domains import Bint, Real
from funsor.tensor import Tensor
from funsor.terms import Number, Variable, to_data, to_funsor
from funsor.testing import assert_close, randn, requires_backend
from funsor.terms import Number, Variable
from funsor.testing import assert_close, randn


def test_eager_subs_variable():
Expand Down Expand Up @@ -81,29 +81,3 @@ def test_align():
for i in range(2):
for j in range(3):
assert x(a=0, b=b, i=i, j=j) == y(a=0, b=b, i=i, j=j)


@requires_backend("torch", reason="requires ProvenanceTensor")
def test_to_funsor():
import torch

from funsor.torch.provenance import ProvenanceTensor

data = torch.zeros(3, 3)
pt = ProvenanceTensor(data, frozenset({("x", Real)}))
c = to_funsor(pt)
assert c is Constant(OrderedDict(x=Real), Tensor(data))


@requires_backend("torch", reason="requires ProvenanceTensor")
def test_to_data():
import torch

from funsor.torch.provenance import ProvenanceTensor

data = torch.zeros(3, 3)
c = Constant(OrderedDict(x=Real), Tensor(data))
pt = to_data(c)
assert isinstance(pt, ProvenanceTensor)
assert pt._t is data
assert pt._provenance == frozenset({("x", Real)})