Skip to content

Commit

Permalink
Fix context in rust task impl (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro authored Jan 8, 2025
1 parent d15c37e commit 4a8cf4a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 36 deletions.
3 changes: 3 additions & 0 deletions granian/_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
super().__init__()
self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb)

def _run(self, coro):
self._run_wctx(coro, contextvars.copy_context())

def cancel(self):
return False

Expand Down
113 changes: 77 additions & 36 deletions src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ pub(crate) struct CallbackScheduler {
schedule_fn: OnceLock<PyObject>,
aio_tenter: PyObject,
aio_texit: PyObject,
pykw_ctx: PyObject,
pym_lcs: PyObject,
pyname_aioblock: PyObject,
pyname_aiosend: PyObject,
pyname_aiothrow: PyObject,
pyname_donecb: PyObject,
pyname_loopcs: PyObject,
pynone: PyObject,
pyfalse: PyObject,
}
Expand All @@ -37,7 +36,7 @@ impl CallbackScheduler {
}

#[inline]
pub(crate) fn send(pyself: Py<Self>, py: Python, coro: PyObject) {
pub(crate) fn send(pyself: Py<Self>, py: Python, coro: PyObject, ctx: PyObject) {
let rself = pyself.get();
let ptr = pyself.as_ptr();

Expand All @@ -53,6 +52,11 @@ impl CallbackScheduler {
);
Bound::from_owned_ptr_or_err(py, res)
} {
let aio_ctxd = PyDict::new(py);
aio_ctxd
.set_item(pyo3::intern!(py, "context"), ctx.clone_ref(py))
.unwrap();

if unsafe {
let vptr = pyo3::ffi::PyObject_GetAttr(res.as_ptr(), rself.pyname_aioblock.as_ptr());
Bound::from_owned_ptr_or_err(py, vptr)
Expand All @@ -64,6 +68,7 @@ impl CallbackScheduler {
CallbackSchedulerWaker {
sched: pyself.clone_ref(py),
coro,
ctx: ctx.clone_ref(py),
},
)
.unwrap();
Expand All @@ -74,7 +79,7 @@ impl CallbackScheduler {
pyo3::ffi::PyObject_Call(
pyo3::ffi::PyObject_GetAttr(resp, rself.pyname_donecb.as_ptr()),
(waker,).into_py_any(py).unwrap().into_ptr(),
rself.pykw_ctx.as_ptr(),
aio_ctxd.as_ptr(),
);
}
} else {
Expand All @@ -83,16 +88,16 @@ impl CallbackScheduler {
CallbackSchedulerRef {
sched: pyself.clone_ref(py),
coro,
ctx: ctx.clone_ref(py),
},
)
.unwrap();

unsafe {
pyo3::ffi::PyObject_CallMethodOneArg(
#[allow(clippy::used_underscore_binding)]
rself._loop.as_ptr(),
rself.pyname_loopcs.as_ptr(),
sref.as_ptr(),
pyo3::ffi::PyObject_Call(
rself.pym_lcs.as_ptr(),
(sref,).into_py_any(py).unwrap().into_ptr(),
aio_ctxd.as_ptr(),
);
}
}
Expand Down Expand Up @@ -148,6 +153,11 @@ impl CallbackScheduler {
);
Bound::from_owned_ptr_or_err(py, res)
} {
let aio_ctxd = PyDict::new(py);
aio_ctxd
.set_item(pyo3::intern!(py, "context"), ctx.clone_ref(py))
.unwrap();

if unsafe {
let vptr = pyo3::ffi::PyObject_GetAttr(res.as_ptr(), rself.pyname_aioblock.as_ptr());
Bound::from_owned_ptr_or_err(py, vptr)
Expand All @@ -159,6 +169,7 @@ impl CallbackScheduler {
CallbackSchedulerWaker {
sched: pyself.clone_ref(py),
coro,
ctx: ctx.clone_ref(py),
},
)
.unwrap();
Expand All @@ -169,7 +180,7 @@ impl CallbackScheduler {
pyo3::ffi::PyObject_Call(
pyo3::ffi::PyObject_GetAttr(resp, rself.pyname_donecb.as_ptr()),
(waker,).into_py_any(py).unwrap().into_ptr(),
rself.pykw_ctx.as_ptr(),
aio_ctxd.as_ptr(),
);
}
} else {
Expand All @@ -178,17 +189,16 @@ impl CallbackScheduler {
CallbackSchedulerRef {
sched: pyself.clone_ref(py),
coro,
ctx: ctx.clone_ref(py),
},
)
.unwrap();

unsafe {
pyo3::ffi::PyObject_CallMethodObjArgs(
#[allow(clippy::used_underscore_binding)]
rself._loop.as_ptr(),
rself.pyname_loopcs.as_ptr(),
sref.as_ptr(),
std::ptr::null_mut::<PyObject>(),
pyo3::ffi::PyObject_Call(
rself.pym_lcs.as_ptr(),
(sref,).into_py_any(py).unwrap().into_ptr(),
aio_ctxd.as_ptr(),
);
}
}
Expand Down Expand Up @@ -228,8 +238,7 @@ impl CallbackScheduler {
aio_tenter: PyObject,
aio_texit: PyObject,
) -> Self {
let ctxd = PyDict::new(py);
ctxd.set_item(pyo3::intern!(py, "context"), ctx.clone_ref(py)).unwrap();
let pym_lcs = event_loop.getattr(py, pyo3::intern!(py, "call_soon")).unwrap();

Self {
_loop: event_loop,
Expand All @@ -240,12 +249,11 @@ impl CallbackScheduler {
aio_texit,
pyfalse: false.into_py_any(py).unwrap(),
pynone: py.None(),
pykw_ctx: ctxd.into_py_any(py).unwrap(),
pym_lcs,
pyname_aioblock: pyo3::intern!(py, "_asyncio_future_blocking").into_py_any(py).unwrap(),
pyname_aiosend: pyo3::intern!(py, "send").into_py_any(py).unwrap(),
pyname_aiothrow: pyo3::intern!(py, "throw").into_py_any(py).unwrap(),
pyname_donecb: pyo3::intern!(py, "add_done_callback").into_py_any(py).unwrap(),
pyname_loopcs: pyo3::intern!(py, "call_soon").into_py_any(py).unwrap(),
}
}

Expand All @@ -255,14 +263,31 @@ impl CallbackScheduler {
}

#[cfg(not(any(Py_3_12, Py_3_13)))]
fn _run(pyself: Py<Self>, py: Python, coro: PyObject) {
CallbackScheduler::send(pyself, py, coro);
fn _run_wctx(pyself: Py<Self>, py: Python, coro: PyObject, ctx: PyObject) {
unsafe {
pyo3::ffi::PyContext_Enter(ctx.as_ptr());
}

CallbackScheduler::send(pyself, py, coro, ctx.clone_ref(py));

unsafe {
pyo3::ffi::PyContext_Exit(ctx.as_ptr());
}
}

#[cfg(any(Py_3_12, Py_3_13))]
fn _run(pyself: Py<Self>, py: Python, coro: PyObject) {
let stepper = Py::new(py, CallbackSchedulerStep::new(py, pyself, coro)).unwrap();
fn _run_wctx(pyself: Py<Self>, py: Python, coro: PyObject, ctx: PyObject) {
let stepper = Py::new(py, CallbackSchedulerStep::new(pyself, coro, ctx.clone_ref(py))).unwrap();

unsafe {
pyo3::ffi::PyContext_Enter(ctx.as_ptr());
}

CallbackSchedulerStep::send(stepper, py);

unsafe {
pyo3::ffi::PyContext_Exit(ctx.as_ptr());
}
}
}

Expand All @@ -271,18 +296,18 @@ impl CallbackScheduler {
pub(crate) struct CallbackSchedulerStep {
sched: Py<CallbackScheduler>,
coro: PyObject,
ctx: PyObject,
futw: Mutex<Option<PyObject>>,
pyname_wake: PyObject,
}

#[cfg(any(Py_3_12, Py_3_13))]
impl CallbackSchedulerStep {
pub(crate) fn new(py: Python, sched: Py<CallbackScheduler>, coro: PyObject) -> Self {
pub(crate) fn new(sched: Py<CallbackScheduler>, coro: PyObject, ctx: PyObject) -> Self {
Self {
sched,
coro,
ctx,
futw: Mutex::new(None),
pyname_wake: pyo3::intern!(py, "wake").into_py_any(py).unwrap(),
}
}

Expand All @@ -309,6 +334,11 @@ impl CallbackSchedulerStep {
);
Bound::from_owned_ptr_or_err(py, res)
} {
let aio_ctxd = PyDict::new(py);
aio_ctxd
.set_item(pyo3::intern!(py, "context"), rself.ctx.clone_ref(py))
.unwrap();

if unsafe {
let vptr = pyo3::ffi::PyObject_GetAttr(res.as_ptr(), rsched.pyname_aioblock.as_ptr());
Bound::from_owned_ptr_or_err(py, vptr)
Expand All @@ -326,17 +356,16 @@ impl CallbackSchedulerStep {
pyo3::ffi::PyObject_Call(
pyo3::ffi::PyObject_GetAttr(resp, rsched.pyname_donecb.as_ptr()),
(pyself.clone_ref(py),).into_py_any(py).unwrap().into_ptr(),
rsched.pykw_ctx.as_ptr(),
aio_ctxd.as_ptr(),
);
}
} else {
unsafe {
let mptr = pyo3::ffi::PyObject_GetAttr(ptr, rself.pyname_wake.as_ptr());
pyo3::ffi::PyObject_CallMethodOneArg(
#[allow(clippy::used_underscore_binding)]
rsched._loop.as_ptr(),
rsched.pyname_loopcs.as_ptr(),
mptr,
let sref = pyself.getattr(py, pyo3::intern!(py, "_step")).unwrap();
pyo3::ffi::PyObject_Call(
rsched.pym_lcs.as_ptr(),
(sref,).into_py_any(py).unwrap().into_ptr(),
aio_ctxd.as_ptr(),
);
}
}
Expand Down Expand Up @@ -398,13 +427,19 @@ impl CallbackSchedulerStep {
pub(crate) struct CallbackSchedulerWaker {
sched: Py<CallbackScheduler>,
coro: PyObject,
ctx: PyObject,
}

#[pymethods]
impl CallbackSchedulerWaker {
fn __call__(&self, py: Python, fut: PyObject) {
match fut.call_method0(py, pyo3::intern!(py, "result")) {
Ok(_) => CallbackScheduler::send(self.sched.clone_ref(py), py, self.coro.clone_ref(py)),
Ok(_) => CallbackScheduler::send(
self.sched.clone_ref(py),
py,
self.coro.clone_ref(py),
self.ctx.clone_ref(py),
),
Err(err) => CallbackScheduler::throw(
self.sched.clone_ref(py),
py,
Expand All @@ -419,12 +454,18 @@ impl CallbackSchedulerWaker {
pub(crate) struct CallbackSchedulerRef {
sched: Py<CallbackScheduler>,
coro: PyObject,
ctx: PyObject,
}

#[pymethods]
impl CallbackSchedulerRef {
fn __call__(&self, py: Python) {
CallbackScheduler::send(self.sched.clone_ref(py), py, self.coro.clone_ref(py));
CallbackScheduler::send(
self.sched.clone_ref(py),
py,
self.coro.clone_ref(py),
self.ctx.clone_ref(py),
);
}
}

Expand Down

0 comments on commit 4a8cf4a

Please sign in to comment.