Skip to content

Commit b2b2de4

Browse files
authored
Make a few more safety improvements (#16)
1 parent 94092af commit b2b2de4

File tree

1 file changed

+15
-30
lines changed

1 file changed

+15
-30
lines changed

src/asgi/mod.rs

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{
33
ffi::CString,
44
fs::{read_dir, read_to_string},
55
path::{Path, PathBuf},
6-
sync::{Arc, OnceLock, RwLock, Weak},
6+
sync::{Arc, Mutex, OnceLock, Weak},
77
};
88

99
#[cfg(target_os = "linux")]
@@ -26,19 +26,17 @@ type HttpResponseResult = Result<HttpResponse, HandlerError>;
2626
static FALLBACK_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
2727

2828
fn fallback_handle() -> tokio::runtime::Handle {
29-
if let Ok(handle) = tokio::runtime::Handle::try_current() {
30-
handle
31-
} else {
29+
tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
3230
// No runtime exists, create a fallback one
3331
let rt = FALLBACK_RUNTIME.get_or_init(|| {
3432
tokio::runtime::Runtime::new().expect("Failed to create fallback tokio runtime")
3533
});
3634
rt.handle().clone()
37-
}
35+
})
3836
}
3937

4038
/// Global Python event loop handle storage
41-
static PYTHON_EVENT_LOOP: OnceLock<RwLock<Weak<EventLoopHandle>>> = OnceLock::new();
39+
static PYTHON_EVENT_LOOP: OnceLock<Mutex<Weak<EventLoopHandle>>> = OnceLock::new();
4240

4341
mod http;
4442
mod http_method;
@@ -90,22 +88,16 @@ unsafe impl Sync for EventLoopHandle {}
9088

9189
/// Ensure a Python event loop exists and return a handle to it
9290
fn ensure_python_event_loop() -> Result<Arc<EventLoopHandle>, HandlerError> {
93-
let weak_handle = PYTHON_EVENT_LOOP.get_or_init(|| RwLock::new(Weak::new()));
91+
let mut guard = PYTHON_EVENT_LOOP
92+
.get_or_init(|| Mutex::new(Weak::new()))
93+
.lock()?;
9494

9595
// Try to upgrade the weak reference
96-
if let Some(handle) = weak_handle.read()?.upgrade() {
97-
return Ok(handle);
98-
}
99-
100-
// Need write lock to create new handle
101-
let mut guard = weak_handle.write()?;
102-
103-
// Double-check in case another thread created it
10496
if let Some(handle) = guard.upgrade() {
10597
return Ok(handle);
10698
}
10799

108-
// Create new event loop handle
100+
// Create new handle
109101
let new_handle = Arc::new(create_event_loop_handle()?);
110102
*guard = Arc::downgrade(&new_handle);
111103

@@ -159,16 +151,10 @@ impl Asgi {
159151
docroot: Option<String>,
160152
app_target: Option<PythonHandlerTarget>,
161153
) -> Result<Self, HandlerError> {
162-
// Determine document root
163-
let docroot = PathBuf::from(if let Some(docroot) = docroot {
164-
docroot
165-
} else {
166-
current_dir()
167-
.map(|path| path.to_string_lossy().to_string())
168-
.map_err(HandlerError::CurrentDirectoryError)?
169-
});
170-
171154
let target = app_target.unwrap_or_default();
155+
let docroot = docroot
156+
.map(|d| Ok(PathBuf::from(d)))
157+
.unwrap_or_else(|| current_dir().map_err(HandlerError::CurrentDirectoryError))?;
172158

173159
// Get or create shared Python event loop
174160
let event_loop_handle = ensure_python_event_loop()?;
@@ -181,8 +167,10 @@ impl Asgi {
181167
.canonicalize()
182168
.map_err(HandlerError::EntrypointNotFoundError)?;
183169

184-
let code = read_to_string(entrypoint).map_err(HandlerError::EntrypointNotFoundError)?;
185-
let code = CString::new(code).map_err(HandlerError::StringCovertError)?;
170+
let code = read_to_string(entrypoint)
171+
.map_err(HandlerError::EntrypointNotFoundError)
172+
.and_then(|s| CString::new(s).map_err(HandlerError::StringCovertError))?;
173+
186174
let file_name =
187175
CString::new(format!("{}.py", target.file)).map_err(HandlerError::StringCovertError)?;
188176
let module_name =
@@ -374,9 +362,6 @@ fn setup_python_paths(py: Python, docroot: &Path) -> PyResult<()> {
374362

375363
/// Start a Python thread that runs the event loop forever
376364
fn start_python_event_loop_thread(event_loop: PyObject) {
377-
// Initialize Python for this thread
378-
pyo3::prepare_freethreaded_python();
379-
380365
Python::with_gil(|py| {
381366
// Set the event loop for this thread and run it
382367
let asyncio = py.import("asyncio")?;

0 commit comments

Comments
 (0)