[wip][compiled autograd] runtime wrapper for verbose debugging
ghstack-source-id: 6a33c4e345f7cd06f49a4c41d92a5f24bf011a5b
Pull Request resolved: #125417
xmfan committed May 2, 2024
1 parent b03fb49 commit d9a0edd
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 7 additions & 1 deletion torch/_dynamo/compiled_autograd.py
@@ -220,7 +220,13 @@ def end_capture(self, outputs):
             "compiled_autograd_graph",
             payload_fn=lambda: graph.print_readable(print_output=False),
         )
-        return self.compiler_fn(graph)
+
+        def runtime_wrapper(compiled_fn, inputs, sizes, hooks):
+            # insert debug code here
+            verbose_log.debug(f"lifted hooks={hooks}")
+            return compiled_fn(inputs, sizes, hooks)
+
+        return runtime_wrapper, self.compiler_fn(graph)
 
     def reorder_accumulate_grad_nodes(self):
         """
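The core of the change is the indirection above: end_capture now returns a (runtime_wrapper, compiled_fn) pair rather than the bare compiled function, and callers invoke the wrapper on every backward call, so debug output fires on cache hits as well as at compile time. A minimal self-contained sketch of the same pattern, assuming an ordinary logging setup (end_capture_sketch and the lambda stand-in are illustrative, not code from this diff):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    verbose_log = logging.getLogger("compiled_autograd.verbose")

    def end_capture_sketch(compiled_fn):
        # Return the wrapper alongside the compiled function; the caller
        # invokes the wrapper, passing compiled_fn back as the first argument.
        def runtime_wrapper(compiled_fn, inputs, sizes, hooks):
            # Anything inserted here runs per call, not per compile.
            verbose_log.debug("lifted hooks=%s", hooks)
            return compiled_fn(inputs, sizes, hooks)

        return runtime_wrapper, compiled_fn

    # Mirror of what the C++ caller does with the returned pair:
    wrapper, fn = end_capture_sketch(lambda inputs, sizes, hooks: inputs)
    print(wrapper(fn, [1.0, 2.0], (), []))  # logs the hooks, prints [1.0, 2.0]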
7 changes: 5 additions & 2 deletions torch/csrc/dynamo/python_compiled_autograd.cpp
@@ -228,6 +228,7 @@ struct CacheNode {
   std::vector<CacheKeyBuffer> key_storage;
   std::vector<SizeInput> expected_sizes;
 
+  THPObjectPtr runtime_wrapper;
   THPObjectPtr compiled_fn;
 };
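In Python terms, each cache node now owns two callables instead of one. A rough dataclass analogue of the C++ struct (the name CacheNodeSketch is hypothetical, for illustration only):

    from dataclasses import dataclass
    from typing import Any, Callable

    @dataclass
    class CacheNodeSketch:
        # Mirrors the two owned PyObject fields on the C++ CacheNode.
        runtime_wrapper: Callable[..., Any]
        compiled_fn: Callable[..., Any]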

@@ -504,7 +505,9 @@ CacheNode* _compiled_autograd_impl(
     }
   }
 
-  cache->compiled_fn = check(call_end_capture(py_compiler, state.outputs));
+  THPObjectPtr res(check(call_end_capture(py_compiler, state.outputs)));
+  // PyTuple_GetItem returns borrowed refs; incref before THPObjectPtr takes ownership
+  PyObject* wrapper = PyTuple_GetItem(res.get(), 0);
+  Py_XINCREF(wrapper);
+  cache->runtime_wrapper = wrapper;
+  PyObject* fn = PyTuple_GetItem(res.get(), 1);
+  Py_XINCREF(fn);
+  cache->compiled_fn = fn;
   state.debug_asserts();
 } // End cache miss region
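The increfs matter because PyTuple_GetItem returns a borrowed reference: the tuple, not the caller, owns the item, while THPObjectPtr decrefs whatever it holds when destroyed. The same ownership rule is visible from pure Python (a rough demonstration; exact counts vary across interpreter versions):

    import sys

    item = object()
    tup = (item,)
    # References: the `item` variable, the tuple slot, and getrefcount's
    # own argument — so typically 3.
    print(sys.getrefcount(item))
    got = tup[0]  # at the C level: PyTuple_GetItem (borrow) + incref for the new name
    print(got is item)  # True: same object, one more owned reference now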

@@ -552,7 +555,7 @@ variable_list compiled_autograd(
       &hooks);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
-      cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL)));
+      cache->runtime_wrapper.get(), cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL)));
   variable_list outputs = THPVariable_UnpackList(pyresult);
   TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
   return outputs;
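Net effect at the call site: instead of compiled_fn(inputs, sizes, hooks), the driver now calls runtime_wrapper(compiled_fn, inputs, sizes, hooks) via PyObject_CallFunctionObjArgs. Spelled out in Python, with trivial stand-ins for the cached pair (illustrative only):

    def runtime_wrapper(fn, inputs, sizes, hooks):
        # Forwards straight through; the real wrapper adds debug logging first.
        return fn(inputs, sizes, hooks)

    def compiled_fn(inputs, sizes, hooks):
        return inputs  # stand-in: echo the inputs back

    # What the C++ call site now does:
    outputs = runtime_wrapper(compiled_fn, [1, 2], (), [])
    print(outputs)  # [1, 2]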
