Skip to content

Commit

Permalink
Refactor GIL acquisitions in async runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Jan 13, 2025
1 parent a69ea50 commit 5d7fae2
Show file tree
Hide file tree
Showing 13 changed files with 294 additions and 275 deletions.
3 changes: 0 additions & 3 deletions granian/_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
super().__init__()
self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb)

def _run(self, coro):
self._run_wctx(coro, contextvars.copy_context())

def cancel(self):
return False

Expand Down
62 changes: 62 additions & 0 deletions src/asgi/conversion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use pyo3::{
prelude::*,
types::{PyBytes, PyDict},
IntoPyObjectExt,
};
use tokio_tungstenite::tungstenite::Message;

use super::errors::error_flow;
use super::types::ASGIMessageType;
use crate::conversion::Utf8BytesToPy;

#[inline]
pub(crate) fn message_into_py(py: Python, message: ASGIMessageType) -> PyResult<Bound<PyAny>> {
let dict = PyDict::new(py);
match message {
ASGIMessageType::HTTPDisconnect => {
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.disconnect"))?;
}
ASGIMessageType::HTTPRequestBody((bytes, more)) => {
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?;
dict.set_item(pyo3::intern!(py, "body"), bytes.into_py_any(py)?)?;
dict.set_item(pyo3::intern!(py, "more_body"), more)?;
}
ASGIMessageType::WSConnect => {
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
}
_ => unreachable!(),
}
Ok(dict.into_any())
}

#[inline]
pub(crate) fn ws_message_into_py(py: Python, message: Message) -> PyResult<Bound<PyAny>> {
match message {
Message::Binary(message) => {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &message[..]))?;
Ok(dict.into_any())
}
Message::Text(message) => {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item(pyo3::intern!(py, "text"), Utf8BytesToPy(message))?;
Ok(dict.into_any())
}
Message::Close(frame) => {
let close_code: u16 = match frame {
Some(frame) => frame.code.into(),
_ => 1005,
};
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.disconnect"))?;
dict.set_item(pyo3::intern!(py, "code"), close_code)?;
Ok(dict.into_any())
}
v => {
log::warn!("Unsupported websocket message received {:?}", v);
error_flow!()
}
}
}
122 changes: 32 additions & 90 deletions src/asgi/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ use hyper::{
header::{HeaderMap, HeaderName, HeaderValue, SERVER as HK_SERVER},
Response, StatusCode,
};
use pyo3::{
prelude::*,
pybacked::PyBackedBytes,
types::{PyBytes, PyDict},
};
use pyo3::{prelude::*, pybacked::PyBackedBytes, types::PyDict};
use std::{
borrow::Cow,
sync::{atomic, Arc, Mutex},
Expand All @@ -27,7 +23,7 @@ use super::{
types::ASGIMessageType,
};
use crate::{
conversion::{BytesToPy, Utf8BytesToPy},
conversion::FutureResultToPy,
http::{response_404, HTTPResponse, HTTPResponseBody, HV_SERVER},
runtime::{empty_future_into_py, future_into_py_futlike, future_into_py_iter, Runtime, RuntimeRef},
ws::{HyperWebsocket, UpgradeData, WSRxStream, WSTxStream},
Expand Down Expand Up @@ -87,7 +83,6 @@ impl ASGIHTTPProtocol {
close: bool,
) -> PyResult<Bound<'p, PyAny>> {
let flow_hld = self.flow_tx_waiter.clone();
let pynone = py.None();

future_into_py_futlike(self.rt.clone(), py, async move {
match tx.send(Ok(body.into())).await {
Expand All @@ -101,7 +96,7 @@ impl ASGIHTTPProtocol {
flow_hld.notify_one();
}
}
Ok(pynone)
FutureResultToPy::None
})
}

Expand All @@ -117,11 +112,7 @@ impl ASGIHTTPProtocol {
let flow_hld = self.flow_tx_waiter.clone();
return future_into_py_futlike(self.rt.clone(), py, async move {
let () = flow_hld.notified().await;
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.disconnect"))?;
Ok(dict.into_any().unbind())
})
FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPDisconnect)
});
}

Expand All @@ -145,37 +136,29 @@ impl ASGIHTTPProtocol {
}

match chunk {
Ok(data) => Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.request"))?;
dict.set_item(pyo3::intern!(py, "body"), BytesToPy(data))?;
dict.set_item(pyo3::intern!(py, "more_body"), more_body)?;
Ok(dict.into_any().unbind())
}),
Ok(data) => FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPRequestBody((data, more_body))),
_ => {
flow_hld.notify_one();
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "http.disconnect"))?;
Ok(dict.into_any().unbind())
})
FutureResultToPy::ASGIMessage(ASGIMessageType::HTTPDisconnect)
}
}
})
}

fn send<'p>(&self, py: Python<'p>, data: &Bound<'p, PyDict>) -> PyResult<Bound<'p, PyAny>> {
match adapt_message_type(py, data) {
Ok(ASGIMessageType::HTTPStart(intent)) => match self.response_started.load(atomic::Ordering::Relaxed) {
false => {
let mut response_intent = self.response_intent.lock().unwrap();
*response_intent = Some(intent);
self.response_started.store(true, atomic::Ordering::Relaxed);
empty_future_into_py(py)
Ok(ASGIMessageType::HTTPResponseStart(intent)) => {
match self.response_started.load(atomic::Ordering::Relaxed) {
false => {
let mut response_intent = self.response_intent.lock().unwrap();
*response_intent = Some(intent);
self.response_started.store(true, atomic::Ordering::Relaxed);
empty_future_into_py(py)
}
true => error_flow!(),
}
true => error_flow!(),
},
Ok(ASGIMessageType::HTTPBody((body, more))) => {
}
Ok(ASGIMessageType::HTTPResponseBody((body, more))) => {
match (
self.response_started.load(atomic::Ordering::Relaxed),
more,
Expand Down Expand Up @@ -224,7 +207,7 @@ impl ASGIHTTPProtocol {
_ => error_flow!(),
}
}
Ok(ASGIMessageType::HTTPFile(file_path)) => match (
Ok(ASGIMessageType::HTTPResponseFile(file_path)) => match (
self.response_started.load(atomic::Ordering::Relaxed),
self.tx.lock().unwrap().take(),
) {
Expand Down Expand Up @@ -325,7 +308,6 @@ impl ASGIWebsocketProtocol {
let accepted = self.accepted.clone();
let rx = self.ws_rx.clone();
let tx = self.ws_tx.clone();
let pynone = py.None();

future_into_py_futlike(self.rt.clone(), py, async move {
if let Some(mut upgrade) = upgrade {
Expand All @@ -342,36 +324,33 @@ impl ASGIWebsocketProtocol {
*wtx = Some(tx);
*wrx = Some(rx);
accepted.store(true, atomic::Ordering::Relaxed);
return Ok(pynone);
return FutureResultToPy::None;
}
}
}
}
Python::with_gil(|_| drop(pynone));
error_flow!()
FutureResultToPy::Err(error_flow!())
})
}

#[inline(always)]
fn send_message<'p>(&self, py: Python<'p>, data: Message) -> PyResult<Bound<'p, PyAny>> {
let transport = self.ws_tx.clone();
let closed = self.closed.clone();
let pynone = py.None();

future_into_py_futlike(self.rt.clone(), py, async move {
if let Some(ws) = &mut *(transport.lock().await) {
match ws.send(data).await {
Ok(()) => return Ok(pynone),
Ok(()) => return FutureResultToPy::None,
_ => {
if closed.load(atomic::Ordering::Relaxed) {
log::info!("Attempted to write to a closed websocket");
return Ok(pynone);
return FutureResultToPy::None;
}
}
};
};
Python::with_gil(|_| drop(pynone));
error_flow!()
FutureResultToPy::Err(error_flow!())
})
}

Expand All @@ -380,7 +359,6 @@ impl ASGIWebsocketProtocol {
let closed = self.closed.clone();
let ws_rx = self.ws_rx.clone();
let ws_tx = self.ws_tx.clone();
let pynone = py.None();

future_into_py_iter(self.rt.clone(), py, async move {
if let Some(tx) = ws_tx.lock().await.take() {
Expand All @@ -389,7 +367,7 @@ impl ASGIWebsocketProtocol {
.close()
.await;
}
Ok(pynone)
FutureResultToPy::None
})
}

Expand Down Expand Up @@ -422,11 +400,7 @@ impl ASGIWebsocketProtocol {
future_into_py_futlike(self.rt.clone(), py, async move {
let accepted = accepted.load(atomic::Ordering::Relaxed);
if !accepted {
return Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.connect"))?;
Ok(dict.into_any().unbind())
});
return FutureResultToPy::ASGIMessage(ASGIMessageType::WSConnect);
}

if let Some(ws) = &mut *(transport.lock().await) {
Expand All @@ -435,14 +409,14 @@ impl ASGIWebsocketProtocol {
Ok(Message::Ping(_) | Message::Pong(_)) => continue,
Ok(message @ Message::Close(_)) => {
closed.store(true, atomic::Ordering::Relaxed);
return ws_message_into_py(message);
return FutureResultToPy::ASGIWSMessage(message);
}
Ok(message) => return ws_message_into_py(message),
Ok(message) => return FutureResultToPy::ASGIWSMessage(message),
_ => break,
}
}
}
error_flow!()
FutureResultToPy::Err(error_flow!())
})
}

Expand All @@ -451,7 +425,7 @@ impl ASGIWebsocketProtocol {
Ok(ASGIMessageType::WSAccept(subproto)) => self.accept(py, subproto),
Ok(ASGIMessageType::WSClose) => self.close(py),
Ok(ASGIMessageType::WSMessage(message)) => self.send_message(py, message),
_ => future_into_py_iter::<_, _>(self.rt.clone(), py, async { error_message!() }),
_ => future_into_py_iter::<_, _>(self.rt.clone(), py, async { FutureResultToPy::Err(error_message!()) }),
}
}
}
Expand All @@ -462,12 +436,12 @@ fn adapt_message_type(py: Python, message: &Bound<PyDict>) -> Result<ASGIMessage
Ok(Some(item)) => {
let message_type: &str = item.extract()?;
match message_type {
"http.response.start" => Ok(ASGIMessageType::HTTPStart((
"http.response.start" => Ok(ASGIMessageType::HTTPResponseStart((
adapt_status_code(py, message)?,
adapt_headers(py, message),
))),
"http.response.body" => Ok(ASGIMessageType::HTTPBody(adapt_body(py, message))),
"http.response.pathsend" => Ok(ASGIMessageType::HTTPFile(adapt_file(py, message)?)),
"http.response.body" => Ok(ASGIMessageType::HTTPResponseBody(adapt_body(py, message))),
"http.response.pathsend" => Ok(ASGIMessageType::HTTPResponseFile(adapt_file(py, message)?)),
"websocket.accept" => {
let subproto: Option<String> = match message.get_item(pyo3::intern!(py, "subprotocol")) {
Ok(Some(item)) => item.extract::<String>().map(Some).unwrap_or(None),
Expand Down Expand Up @@ -554,35 +528,3 @@ fn ws_message_into_rs(py: Python, message: &Bound<PyDict>) -> PyResult<Message>
_ => error_message!(),
}
}

#[inline(always)]
fn ws_message_into_py(message: Message) -> PyResult<PyObject> {
match message {
Message::Binary(message) => Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item(pyo3::intern!(py, "bytes"), PyBytes::new(py, &message[..]))?;
Ok(dict.into_any().unbind())
}),
Message::Text(message) => Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.receive"))?;
dict.set_item(pyo3::intern!(py, "text"), Utf8BytesToPy(message))?;
Ok(dict.into_any().unbind())
}),
Message::Close(frame) => Python::with_gil(|py| {
let close_code: u16 = match frame {
Some(frame) => frame.code.into(),
_ => 1005,
};
let dict = PyDict::new(py);
dict.set_item(pyo3::intern!(py, "type"), pyo3::intern!(py, "websocket.disconnect"))?;
dict.set_item(pyo3::intern!(py, "code"), close_code)?;
Ok(dict.into_any().unbind())
}),
v => {
log::warn!("Unsupported websocket message received {:?}", v);
error_flow!()
}
}
}
3 changes: 2 additions & 1 deletion src/asgi/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use pyo3::prelude::*;

mod callbacks;
pub(crate) mod conversion;
mod errors;
mod http;
mod io;
pub(crate) mod serve;
mod types;
pub(crate) mod types;
mod utils;

pub(crate) fn init_pymodule(module: &Bound<PyModule>) -> PyResult<()> {
Expand Down
11 changes: 7 additions & 4 deletions src/asgi/types.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use hyper::HeaderMap;
use hyper::{body, HeaderMap};
use tokio_tungstenite::tungstenite::Message;

pub(crate) enum ASGIMessageType {
HTTPStart((u16, HeaderMap)),
HTTPBody((Box<[u8]>, bool)),
HTTPFile(String),
HTTPResponseStart((u16, HeaderMap)),
HTTPResponseBody((Box<[u8]>, bool)),
HTTPResponseFile(String),
HTTPDisconnect,
HTTPRequestBody((body::Bytes, bool)),
WSAccept(Option<String>),
WSConnect,
WSClose,
WSMessage(Message),
}
9 changes: 6 additions & 3 deletions src/asyncio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ pub(crate) fn empty_context(py: Python) -> PyResult<&Bound<PyAny>> {
.bind(py))
}

#[allow(dead_code)]
pub(crate) fn copy_context(py: Python) -> PyResult<Bound<PyAny>> {
contextvars(py)?.call_method0("copy_context")
pub(crate) fn copy_context(py: Python) -> PyObject {
let ctx = unsafe {
let ptr = pyo3::ffi::PyContext_CopyCurrent();
Bound::from_owned_ptr(py, ptr)
};
ctx.unbind()
}
Loading

0 comments on commit 5d7fae2

Please sign in to comment.