diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index 3ac5441c4868..fc00b232cc3d 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -220,7 +220,13 @@ def end_capture(self, outputs):
             "compiled_autograd_graph",
             payload_fn=lambda: graph.print_readable(print_output=False),
         )
-        return self.compiler_fn(graph)
+
+        def runtime_wrapper(compiled_fn, inputs, sizes, hooks):
+            # insert debug code here
+            verbose_log.debug(f"lifted hooks={hooks}")
+            return compiled_fn(inputs, sizes, hooks)
+
+        return runtime_wrapper, self.compiler_fn(graph)
 
     def reorder_accumulate_grad_nodes(self):
         """
diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp
index dd5ea7cbd094..3dcbc671d4b4 100644
--- a/torch/csrc/dynamo/python_compiled_autograd.cpp
+++ b/torch/csrc/dynamo/python_compiled_autograd.cpp
@@ -228,6 +228,7 @@ struct CacheNode {
   std::vector<CacheKeyBuffer> key_storage;
   std::vector<SizeInput> expected_sizes;
 
+  THPObjectPtr runtime_wrapper;
   THPObjectPtr compiled_fn;
 };
 
@@ -504,7 +505,9 @@ CacheNode* _compiled_autograd_impl(
       }
     }
 
-    cache->compiled_fn = check(call_end_capture(py_compiler, state.outputs));
+    THPObjectPtr res(check(call_end_capture(py_compiler, state.outputs)));
+    cache->runtime_wrapper = Py_NewRef(PyTuple_GetItem(res.get(), 0)); // GetItem borrows; store owned refs
+    cache->compiled_fn = Py_NewRef(PyTuple_GetItem(res.get(), 1));
     state.debug_asserts();
   }
   // End cache miss region
@@ -552,7 +555,7 @@ variable_list compiled_autograd(
       &hooks);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
-      cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL)));
+      cache->runtime_wrapper.get(), cache->compiled_fn.get(), inputs.get(), sizes.get(), hooks.get(), NULL)));
   variable_list outputs = THPVariable_UnpackList(pyresult);
   TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
   return outputs;
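
For reference, a minimal standalone sketch (not part of the patch) of the calling convention the patch introduces: end_capture now returns a (runtime_wrapper, compiled_fn) pair, the C++ cache stores both, and every backward invocation goes through the wrapper, which is where the "insert debug code here" hook runs. end_capture_sketch and fake_compiled_fn below are hypothetical stand-ins, not PyTorch internals.

# Sketch only, assuming stand-in callables; mirrors the runtime_wrapper
# protocol from the patch, not the real compiler pipeline.
def end_capture_sketch(compiler_fn, graph):
    def runtime_wrapper(compiled_fn, inputs, sizes, hooks):
        # per-call debug code can go here without recompiling the graph
        print(f"lifted hooks={hooks}")
        return compiled_fn(inputs, sizes, hooks)

    return runtime_wrapper, compiler_fn(graph)


def fake_compiled_fn(inputs, sizes, hooks):
    # stand-in for the callable produced by self.compiler_fn(graph)
    return [x * 2 for x in inputs]


# Caller side (what the C++ change does with the cached pair):
runtime_wrapper, compiled_fn = end_capture_sketch(lambda g: fake_compiled_fn, graph=None)
outputs = runtime_wrapper(compiled_fn, inputs=[1.0, 2.0], sizes=[], hooks=[])
assert outputs == [2.0, 4.0]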