Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid trivial subs #545

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft

Avoid trivial subs #545

wants to merge 4 commits into from

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Jul 30, 2021

This proposes to avoid any trivial subs (do not call interpret if all subs are trivial).

Trivial subs can arise for example in eager MarkovProduct with step_names = {"prev": "prev", "curr": "curr"}:

def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    ...
    return Subs(result, step_names)

which then pollutes AdjointTape.tape under adjoint interpretation (#493 #544).

@eb8680
Copy link
Member

eb8680 commented Aug 3, 2021

I don't think *Meta.__call__ methods are the right place for this kind of simplification logic (these methods should do nothing except fill default values for optional arguments and possibly call to_funsor where appropriate), and it really shouldn't be necessary regardless - it sounds like the issue, as is so often unfortunately the case, is with the logic in funsor.adjoint, or perhaps with alpha-renaming. Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by adjoint?

Pragmatically, however, I'm OK with merging this if you can do that and explain how this fix would unblock what you're actually working on.

funsor/terms.py Outdated
(k, to_funsor(v, arg.inputs[k])) for k, v in subs if k in arg.inputs
(k, to_funsor(v, arg.inputs[k]))
for k, v in subs
if k in arg.inputs and k is not v
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if v is already a Variable?

@ordabayevy
Copy link
Member Author

Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by adjoint?

I discovered this issue by examining test_adjoint.py::test_sequential_sum_product_adjoint: xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?").

In that test under adjoint interpretation eager_markov_product calls sequential_sum_product (no issues there) and then calls Subs(results, step_names) where step_names = {"prev": "prev", "curr": "curr"}. Since the names in step_names are the identical, Subs(result, step_names) returns the same result, however, it also gets appended to the AdjointTape.tape and I think that doubles adjoint values.

Simple solution can be to check step_names and call Subs(result, step_names) only if names are not identical. I moved that logic to SubsMeta.__call__ thinking that it might help to avoid similar issues in the future.

def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    if step:
        result = sequential_sum_product(sum_op, prod_op, trans, time, dict(step))
    ...
    return Subs(result, step_names)

@ordabayevy
Copy link
Member Author

Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by adjoint?

Yes, I believe this can be boiled down to much simpler failing test with just couple of lines of code than test_adjoint.py::test_sequential_sum_product_adjoint: xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?"). Should I add such a test?

@eb8680
Copy link
Member

eb8680 commented Aug 3, 2021

Should I add such a test?

Yes, that would be very helpful!

with AdjointTape() as tape:
y = 2 * x
if use_subs:
y = y(i="i")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This substitution shouldn't change adjoint value.

elif test == "other":
y = y(j=0)
elif test == "reduce":
y = funsor.terms.Reduce(ops.add, y, frozenset())
Copy link
Member Author

@ordabayevy ordabayevy Aug 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, it seems that any (unary) atomic op that doesn't change the arg doubles the adjoint value.

@ordabayevy ordabayevy marked this pull request as draft September 22, 2021 17:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants