diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index bf03813b180..e07ad552f17 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -2,6 +2,10 @@ mod cert; +// SSL exception types (shared with rustls backend) +#[path = "ssl/error.rs"] +mod ssl_error; + // Conditional compilation for OpenSSL version-specific error codes cfg_if::cfg_if! { if #[cfg(ossl310)] { @@ -45,9 +49,16 @@ cfg_if::cfg_if! { } #[allow(non_upper_case_globals)] -#[pymodule(with(cert::ssl_cert, ossl101, ossl111, windows))] +#[pymodule(with(cert::ssl_cert, ssl_error::ssl_error, ossl101, ossl111, windows))] mod _ssl { use super::{bio, probe}; + + // Import error types used in this module (others are exposed via pymodule(with(...))) + use super::ssl_error::{ + PySSLCertVerificationError as PySslCertVerificationError, PySSLEOFError as PySslEOFError, + PySSLError as PySslError, PySSLWantReadError as PySslWantReadError, + PySSLWantWriteError as PySslWantWriteError, + }; use crate::{ common::lock::{ PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, @@ -56,8 +67,8 @@ mod _ssl { vm::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ - PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, - PyTypeRef, PyWeak, + PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, + PyWeak, }, class_or_notimplemented, convert::ToPyException, @@ -300,85 +311,6 @@ mod _ssl { parse_version_info(openssl_api_version) } - // SSL Exception Types - - /// An error occurred in the SSL implementation. - #[pyattr] - #[pyexception(name = "SSLError", base = PyOSError)] - #[derive(Debug)] - pub struct PySslError {} - - #[pyexception] - impl PySslError { - // Returns strerror attribute if available, otherwise str(args) - #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - // Try to get strerror attribute first (OSError compatibility) - if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) - && !vm.is_none(&strerror) - { - return strerror.str(vm); - } - - // Otherwise return str(args) - exc.args().as_object().str(vm) - } - } - - /// A certificate could not be verified. - #[pyattr] - #[pyexception(name = "SSLCertVerificationError", base = PySslError)] - #[derive(Debug)] - pub struct PySslCertVerificationError {} - - #[pyexception] - impl PySslCertVerificationError {} - - /// SSL/TLS session closed cleanly. - #[pyattr] - #[pyexception(name = "SSLZeroReturnError", base = PySslError)] - #[derive(Debug)] - pub struct PySslZeroReturnError {} - - #[pyexception] - impl PySslZeroReturnError {} - - /// Non-blocking SSL socket needs to read more data. - #[pyattr] - #[pyexception(name = "SSLWantReadError", base = PySslError)] - #[derive(Debug)] - pub struct PySslWantReadError {} - - #[pyexception] - impl PySslWantReadError {} - - /// Non-blocking SSL socket needs to write more data. - #[pyattr] - #[pyexception(name = "SSLWantWriteError", base = PySslError)] - #[derive(Debug)] - pub struct PySslWantWriteError {} - - #[pyexception] - impl PySslWantWriteError {} - - /// System error when attempting SSL operation. - #[pyattr] - #[pyexception(name = "SSLSyscallError", base = PySslError)] - #[derive(Debug)] - pub struct PySslSyscallError {} - - #[pyexception] - impl PySslSyscallError {} - - /// SSL/TLS connection terminated abruptly. - #[pyattr] - #[pyexception(name = "SSLEOFError", base = PySslError)] - #[derive(Debug)] - pub struct PySslEOFError {} - - #[pyexception] - impl PySslEOFError {} - type OpensslVersionInfo = (u8, u8, u8, u8, u8); const fn parse_version_info(mut n: i64) -> OpensslVersionInfo { let status = (n & 0xF) as u8; @@ -582,18 +514,53 @@ mod _ssl { Ok(buf) } - // Callback data stored in SSL context for SNI + // Callback data stored in SSL ex_data for SNI/msg callbacks struct SniCallbackData { ssl_context: PyRef, - vm_ptr: *const VirtualMachine, + // Use weak reference to avoid reference cycle: + // PySslSocket -> SslStream -> SSL -> ex_data -> SniCallbackData -> PySslSocket + ssl_socket_weak: PyRef, + } + + // Thread-local storage for VirtualMachine pointer during handshake + // SNI callback is only called during handshake which is synchronous + thread_local! { + static HANDSHAKE_VM: std::cell::Cell> = const { std::cell::Cell::new(None) }; + // SSL pointer during handshake - needed because connection lock is held during handshake + // and callbacks may need to access SSL without acquiring the lock + static HANDSHAKE_SSL_PTR: std::cell::Cell> = const { std::cell::Cell::new(None) }; } - impl Drop for SniCallbackData { + // RAII guard to set/clear thread-local handshake context + struct HandshakeVmGuard { + _ssl_ptr: *mut sys::SSL, + } + + impl HandshakeVmGuard { + fn new(vm: &VirtualMachine, ssl_ptr: *mut sys::SSL) -> Self { + HANDSHAKE_VM.with(|cell| cell.set(Some(vm as *const _))); + HANDSHAKE_SSL_PTR.with(|cell| cell.set(Some(ssl_ptr))); + HandshakeVmGuard { _ssl_ptr: ssl_ptr } + } + } + + impl Drop for HandshakeVmGuard { fn drop(&mut self) { - // PyRef will handle reference counting + HANDSHAKE_VM.with(|cell| cell.set(None)); + HANDSHAKE_SSL_PTR.with(|cell| cell.set(None)); } } + // Get SSL pointer - either from thread-local (during handshake) or from connection + fn get_ssl_ptr_for_context_change(connection: &PyRwLock) -> *mut sys::SSL { + // First check if we're in a handshake callback (lock already held) + if let Some(ptr) = HANDSHAKE_SSL_PTR.with(|cell| cell.get()) { + return ptr; + } + // Otherwise, acquire the lock normally + connection.read().ssl().as_ptr() + } + // Get or create an ex_data index for SNI callback data fn get_sni_ex_data_index() -> libc::c_int { use std::sync::LazyLock; @@ -610,17 +577,30 @@ mod _ssl { } // Free function for callback data + // NOTE: We don't free the data here because it's managed manually in do_handshake + // to avoid use-after-free when the SSL object is dropped after timeout unsafe extern "C" fn sni_callback_data_free( _parent: *mut libc::c_void, - ptr: *mut libc::c_void, + _ptr: *mut libc::c_void, _ad: *mut sys::CRYPTO_EX_DATA, _idx: libc::c_int, _argl: libc::c_long, _argp: *mut libc::c_void, ) { - if !ptr.is_null() { - unsafe { - let _ = Box::from_raw(ptr as *mut SniCallbackData); + // Intentionally empty - data is freed in cleanup_sni_ex_data() + } + + // Clean up SNI callback data from SSL ex_data + // Called after handshake to free the data and release references + unsafe fn cleanup_sni_ex_data(ssl_ptr: *mut sys::SSL) { + unsafe { + let idx = get_sni_ex_data_index(); + let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); + if !data_ptr.is_null() { + // Free the Box - this releases references to context and socket + let _ = Box::from_raw(data_ptr as *mut SniCallbackData); + // Clear the ex_data to prevent double-free + sys::SSL_set_ex_data(ssl_ptr, idx, std::ptr::null_mut()); } } } @@ -658,9 +638,13 @@ mod _ssl { let callback_data = &*(data_ptr as *const SniCallbackData); - // SAFETY: vm_ptr is stored during wrap_socket and is valid for the lifetime - // of the SSL connection. The handshake happens synchronously in the same thread. - let vm = &*callback_data.vm_ptr; + // Get VM from thread-local storage (set by HandshakeVmGuard in do_handshake) + let Some(vm_ptr) = HANDSHAKE_VM.with(|cell| cell.get()) else { + // VM not available - this shouldn't happen during handshake + *al = SSL_AD_INTERNAL_ERROR; + return SSL_TLSEXT_ERR_ALERT_FATAL; + }; + let vm = &*vm_ptr; // Get server name let servername = sys::SSL_get_servername(ssl_ptr, TLSEXT_NAMETYPE_host_name); @@ -674,20 +658,11 @@ mod _ssl { } }; - // Get SSL socket from SSL ex_data (stored as PySslSocket pointer) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); // Index 0 for SSL socket - let ssl_socket_obj = if !ssl_socket_ptr.is_null() { - let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); - // Try to get owner first - ssl_socket - .owner - .read() - .as_ref() - .and_then(|weak| weak.upgrade()) - .unwrap_or_else(|| vm.ctx.none()) - } else { - vm.ctx.none() - }; + // Get SSL socket from callback data via weak reference + let ssl_socket_obj = callback_data + .ssl_socket_weak + .upgrade() + .unwrap_or_else(|| vm.ctx.none()); // Call the Python callback match callback.call( @@ -735,81 +710,20 @@ mod _ssl { } // Message callback function called by OpenSSL - // Based on CPython's _PySSL_msg_callback in Modules/_ssl/debughelpers.c + // NOTE: This callback is intentionally a no-op to avoid deadlocks. + // The msg_callback can be called during various SSL operations (read, write, handshake), + // and invoking Python code from within these operations can cause deadlocks + // (see CPython bpo-43577). A proper implementation would require careful lock ordering. unsafe extern "C" fn _msg_callback( - write_p: libc::c_int, - version: libc::c_int, - content_type: libc::c_int, - buf: *const libc::c_void, - len: usize, - ssl_ptr: *mut sys::SSL, + _write_p: libc::c_int, + _version: libc::c_int, + _content_type: libc::c_int, + _buf: *const libc::c_void, + _len: usize, + _ssl_ptr: *mut sys::SSL, _arg: *mut libc::c_void, ) { - if ssl_ptr.is_null() { - return; - } - - unsafe { - // Get SSL socket from SSL_get_app_data (index 0) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); - if ssl_socket_ptr.is_null() { - return; - } - - let ssl_socket = &*(ssl_socket_ptr as *const PySslSocket); - - // Get the callback from the context - let callback_opt = ssl_socket.ctx.read().msg_callback.lock().clone(); - let Some(callback) = callback_opt else { - return; - }; - - // Get callback data from SSL ex_data (for VM) - let idx = get_sni_ex_data_index(); - let data_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); - if data_ptr.is_null() { - return; - } - - let callback_data = &*(data_ptr as *const SniCallbackData); - let vm = &*callback_data.vm_ptr; - - // Get SSL socket owner object - let ssl_socket_obj = ssl_socket - .owner - .read() - .as_ref() - .and_then(|weak| weak.upgrade()) - .unwrap_or_else(|| vm.ctx.none()); - - // Create the message bytes - let buf_slice = std::slice::from_raw_parts(buf as *const u8, len); - let msg_bytes = vm.ctx.new_bytes(buf_slice.to_vec()); - - // Determine direction string - let direction_str = if write_p != 0 { "write" } else { "read" }; - - // Call the Python callback - // Signature: callback(conn, direction, version, content_type, msg_type, data) - // For simplicity, we'll pass msg_type as 0 (would need more parsing to get the actual type) - match callback.call( - ( - ssl_socket_obj, - vm.ctx.new_str(direction_str), - vm.ctx.new_int(version), - vm.ctx.new_int(content_type), - vm.ctx.new_int(0), // msg_type - would need parsing - msg_bytes, - ), - vm, - ) { - Ok(_) => {} - Err(exc) => { - // Log the exception but don't propagate it - vm.run_unraisable(exc, None, vm.ctx.none()); - } - } - } + // Intentionally empty to avoid deadlocks } #[pyfunction(name = "RAND_pseudo_bytes")] @@ -850,7 +764,11 @@ mod _ssl { impl Constructor for PySslContext { type Args = i32; - fn py_new(cls: PyTypeRef, proto_version: Self::Args, vm: &VirtualMachine) -> PyResult { + fn py_new( + _cls: &Py, + proto_version: Self::Args, + vm: &VirtualMachine, + ) -> PyResult { let proto = SslVersion::try_from(proto_version) .map_err(|_| vm.new_value_error("invalid protocol version"))?; let method = match proto { @@ -932,16 +850,14 @@ mod _ssl { sys::X509_VERIFY_PARAM_set_flags(param, sys::X509_V_FLAG_TRUSTED_FIRST); } - PySslContext { + Ok(PySslContext { ctx: PyRwLock::new(builder), check_hostname: AtomicCell::new(check_hostname), protocol: proto, post_handshake_auth: PyMutex::new(false), sni_callback: PyMutex::new(None), msg_callback: PyMutex::new(None), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + }) } } @@ -981,12 +897,9 @@ mod _ssl { if ciphers.contains('\0') { return Err(exceptions::cstring_error(vm)); } - self.builder().set_cipher_list(ciphers).map_err(|_| { - vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "No cipher can be selected.".to_owned(), - ) - }) + self.builder() + .set_cipher_list(ciphers) + .map_err(|_| new_ssl_error(vm, "No cipher can be selected.")) } #[pymethod] @@ -1126,16 +1039,10 @@ mod _ssl { let set = !flags & new_flags; if clear != 0 && sys::X509_VERIFY_PARAM_clear_flags(param, clear) == 0 { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Failed to clear verify flags".to_owned(), - )); + return Err(new_ssl_error(vm, "Failed to clear verify flags")); } if set != 0 && sys::X509_VERIFY_PARAM_set_flags(param, set) == 0 { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Failed to set verify flags".to_owned(), - )); + return Err(new_ssl_error(vm, "Failed to set verify flags")); } Ok(()) } @@ -1477,10 +1384,13 @@ mod _ssl { let fp = rustpython_common::fileutils::fopen(path.as_path(), "rb").map_err(|e| { match e.kind() { - std::io::ErrorKind::NotFound => vm.new_exception_msg( - vm.ctx.exceptions.file_not_found_error.to_owned(), - e.to_string(), - ), + std::io::ErrorKind::NotFound => vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + None, + e.to_string(), + ) + .upcast(), _ => vm.new_os_error(e.to_string()), } })?; @@ -1670,15 +1580,15 @@ mod _ssl { ) -> PyResult<(ssl::Ssl, SslServerOrClient, Option)> { // Validate socket type and context protocol if server_side && ctx_ref.protocol == SslVersion::TlsClient { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), + return Err(new_ssl_error( + vm, + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context", )); } if !server_side && ctx_ref.protocol == SslVersion::TlsServer { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), + return Err(new_ssl_error( + vm, + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context", )); } @@ -1791,21 +1701,22 @@ mod _ssl { let py_ref = py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; - // Set SNI callback data if callback is configured - if zelf.sni_callback.lock().is_some() { + // Check if SNI callback is configured (minimize lock time) + let has_sni_callback = zelf.sni_callback.lock().is_some(); + + // Set SNI callback data if needed (after releasing the lock) + if has_sni_callback { + let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store callback data in SSL ex_data + // Store callback data in SSL ex_data - use weak reference to avoid cycle let callback_data = Box::new(SniCallbackData { ssl_context: zelf.clone(), - vm_ptr: vm as *const _, + ssl_socket_weak, }); let idx = get_sni_ex_data_index(); sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); - - // Store PyRef pointer (heap-allocated) in ex_data index 0 - sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); } } @@ -1851,21 +1762,22 @@ mod _ssl { let py_ref = py_ssl_socket.into_ref_with_type(vm, PySslSocket::class(&vm.ctx).to_owned())?; - // Set SNI callback data if callback is configured - if zelf.sni_callback.lock().is_some() { + // Check if SNI callback is configured (minimize lock time) + let has_sni_callback = zelf.sni_callback.lock().is_some(); + + // Set SNI callback data if needed (after releasing the lock) + if has_sni_callback { + let ssl_socket_weak = py_ref.as_object().downgrade(None, vm)?; unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store callback data in SSL ex_data + // Store callback data in SSL ex_data - use weak reference to avoid cycle let callback_data = Box::new(SniCallbackData { ssl_context: zelf.clone(), - vm_ptr: vm as *const _, + ssl_socket_weak, }); let idx = get_sni_ex_data_index(); sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); - - // Store PyRef pointer (heap-allocated) in ex_data index 0 - sys::SSL_set_ex_data(ssl_ptr, 0, &*py_ref as *const _ as *mut _); } } @@ -1992,10 +1904,7 @@ mod _ssl { } fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "Underlying socket has been closed.".to_owned(), - ) + new_ssl_error(vm, "Underlying socket has been closed.") } // BIO stream wrapper to implement Read/Write traits for MemoryBIO @@ -2152,12 +2061,13 @@ mod _ssl { } #[pygetset(setter)] fn set_context(&self, value: PyRef, vm: &VirtualMachine) -> PyResult<()> { - // Update the SSL context in the underlying SSL object - let stream = self.connection.read(); + // Get SSL pointer - use thread-local during handshake to avoid deadlock + // (connection lock is already held during handshake) + let ssl_ptr = get_ssl_ptr_for_context_change(&self.connection); // Set the new SSL_CTX on the SSL object unsafe { - let result = SSL_set_SSL_CTX(stream.ssl().as_ptr(), value.ctx().as_ptr()); + let result = SSL_set_SSL_CTX(ssl_ptr, value.ctx().as_ptr()); if result.is_null() { return Err(vm.new_runtime_error("Failed to set SSL context".to_owned())); } @@ -2275,6 +2185,12 @@ mod _ssl { .map(cipher_to_tuple) } + #[pymethod] + fn pending(&self) -> i32 { + let stream = self.connection.read(); + unsafe { sys::SSL_pending(stream.ssl().as_ptr()) } + } + #[pymethod] fn shared_ciphers(&self, vm: &VirtualMachine) -> Option { #[cfg(ossl110)] @@ -2450,8 +2366,8 @@ mod _ssl { // Non-blocking would block - this is okay for shutdown // Return the underlying socket } else { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), + return Err(new_ssl_error( + vm, format!("SSL shutdown failed: error code {}", err), )); } @@ -2491,9 +2407,13 @@ mod _ssl { let mut stream = self.connection.write(); let ssl_ptr = stream.ssl().as_ptr(); + // Set up thread-local VM and SSL pointer for callbacks + // This allows callbacks to access SSL without acquiring the connection lock + let _vm_guard = HandshakeVmGuard::new(vm, ssl_ptr); + // BIO mode: no timeout/select logic, just do handshake if stream.is_bio() { - return stream.do_handshake().map_err(|e| { + let result = stream.do_handshake().map_err(|e| { let exc = convert_ssl_error(vm, e); // If it's a cert verification error, set verify info if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { @@ -2501,6 +2421,10 @@ mod _ssl { } exc }); + // Clean up SNI ex_data after handshake (success or failure) + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; + return result; } // Socket mode: handle timeout and blocking @@ -2510,7 +2434,12 @@ mod _ssl { .timeout_deadline(); loop { let err = match stream.do_handshake() { - Ok(()) => return Ok(()), + Ok(()) => { + // Clean up SNI ex_data after successful handshake + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; + return Ok(()); + } Err(e) => e, }; let (needs, state) = stream @@ -2519,12 +2448,20 @@ mod _ssl { .socket_needs(&err, &timeout); match state { SelectRet::TimedOut => { + // Clean up SNI ex_data before returning error + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; return Err(socket::timeout_error_msg( vm, "The handshake operation timed out".to_owned(), - )); + ) + .upcast()); + } + SelectRet::Closed => { + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; + return Err(socket_closed_error(vm)); } - SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} SelectRet::IsBlocking | SelectRet::Ok => { // For blocking sockets, select() has completed successfully @@ -2539,6 +2476,9 @@ mod _ssl { if exc.class().is(PySslCertVerificationError::class(&vm.ctx)) { set_verify_error_info(&exc, ssl_ptr, vm); } + // Clean up SNI ex_data before returning error + // SAFETY: ssl_ptr is valid for the lifetime of stream + unsafe { cleanup_sni_ex_data(ssl_ptr) }; return Err(exc); } } @@ -2565,7 +2505,8 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The write operation timed out".to_owned(), - )); + ) + .upcast()); } SelectRet::Closed => return Err(socket_closed_error(vm)), _ => {} @@ -2584,7 +2525,8 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The write operation timed out".to_owned(), - )); + ) + .upcast()); } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} @@ -2727,7 +2669,8 @@ mod _ssl { return Err(socket::timeout_error_msg( vm, "The read operation timed out".to_owned(), - )); + ) + .upcast()); } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} @@ -3070,7 +3013,7 @@ mod _ssl { impl Constructor for PySslMemoryBio { type Args = (); - fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { + fn py_new(_cls: &Py, _args: Self::Args, vm: &VirtualMachine) -> PyResult { unsafe { let bio = sys::BIO_new(sys::BIO_s_mem()); if bio.is_null() { @@ -3080,12 +3023,10 @@ mod _ssl { sys::BIO_set_retry_read(bio); BIO_set_mem_eof_return(bio, -1); - PySslMemoryBio { + Ok(PySslMemoryBio { bio, eof_written: AtomicCell::new(false), - } - .into_ref_with_type(vm, cls) - .map(Into::into) + }) } } } @@ -3143,10 +3084,7 @@ mod _ssl { #[pymethod] fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult { if self.eof_written.load() { - return Err(vm.new_exception_msg( - PySslError::class(&vm.ctx).to_owned(), - "cannot write() after write_eof()".to_owned(), - )); + return Err(new_ssl_error(vm, "cannot write() after write_eof()")); } data.with_ref(|buf| unsafe { @@ -3235,6 +3173,12 @@ mod _ssl { } } + /// Helper function to create SSL error with proper OSError subtype handling + fn new_ssl_error(vm: &VirtualMachine, msg: impl ToString) -> PyBaseExceptionRef { + vm.new_os_subtype_error(PySslError::class(&vm.ctx).to_owned(), None, msg.to_string()) + .upcast() + } + #[track_caller] pub(crate) fn convert_openssl_error( vm: &VirtualMachine, @@ -3255,12 +3199,7 @@ mod _ssl { } else { vm.ctx.exceptions.os_error.to_owned() }; - let exc = vm.new_exception(exc_type, vec![vm.ctx.new_int(reason).into()]); - // Set errno attribute explicitly - let _ = exc - .as_object() - .set_attr("errno", vm.ctx.new_int(reason), vm); - return exc; + return vm.new_os_subtype_error(exc_type, Some(reason), "").upcast(); } let caller = std::panic::Location::caller(); @@ -3310,13 +3249,8 @@ mod _ssl { // Create exception instance let reason = sys::ERR_GET_REASON(e.code()); - let exc = vm.new_exception( - cls, - vec![vm.ctx.new_int(reason).into(), vm.ctx.new_str(msg).into()], - ); - - // Set attributes on instance, not class - let exc_obj: PyObjectRef = exc.into(); + let exc = vm.new_os_subtype_error(cls, Some(reason), msg); + let exc_obj: PyObjectRef = exc.upcast::().into(); // Set reason attribute (always set, even if just the error string) let reason_value = vm.ctx.new_str(errstr); @@ -3345,7 +3279,8 @@ mod _ssl { } None => { let cls = PySslError::class(&vm.ctx).to_owned(); - vm.new_exception_empty(cls) + vm.new_os_subtype_error(cls, None, "unknown SSL error") + .upcast() } } } @@ -3396,15 +3331,13 @@ mod _ssl { // this is an EOF in violation of protocol -> SSLEOFError // Need to set args[0] = SSL_ERROR_EOF for suppress_ragged_eofs check None => { - return vm.new_exception( - PySslEOFError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_EOF).into(), - vm.ctx - .new_str("EOF occurred in violation of protocol") - .into(), - ], - ); + return vm + .new_os_subtype_error( + PySslEOFError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_EOF as i32), + "EOF occurred in violation of protocol", + ) + .upcast(); } }, ssl::ErrorCode::SSL => { @@ -3417,15 +3350,13 @@ mod _ssl { let reason = sys::ERR_GET_REASON(err_code); let lib = sys::ERR_GET_LIB(err_code); if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING { - return vm.new_exception( - PySslEOFError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_EOF).into(), - vm.ctx - .new_str("EOF occurred in violation of protocol") - .into(), - ], - ); + return vm + .new_os_subtype_error( + PySslEOFError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_EOF as i32), + "EOF occurred in violation of protocol", + ) + .upcast(); } } return convert_openssl_error(vm, ssl_err.clone()); @@ -3440,7 +3371,7 @@ mod _ssl { "A failure in the SSL library occurred", ), }; - vm.new_exception_msg(cls, msg.to_owned()) + vm.new_os_subtype_error(cls, None, msg).upcast() } // SSL_FILETYPE_ASN1 part of _add_ca_certs in CPython @@ -3543,10 +3474,13 @@ mod _ssl { ) -> Result<(), PyBaseExceptionRef> { let root = Path::new(CERT_DIR); if !root.is_dir() { - return Err(vm.new_exception_msg( - vm.ctx.exceptions.file_not_found_error.to_owned(), - CERT_DIR.to_string(), - )); + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + None, + CERT_DIR.to_string(), + ) + .upcast()); } let mut combined_pem = String::new(); diff --git a/crates/stdlib/src/openssl/cert.rs b/crates/stdlib/src/openssl/cert.rs index 1139f0e26f0..1197bf4aa46 100644 --- a/crates/stdlib/src/openssl/cert.rs +++ b/crates/stdlib/src/openssl/cert.rs @@ -165,7 +165,8 @@ pub(crate) mod ssl_cert { format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]) } else if ip.len() == 16 { // IPv6 - format with all zeros visible (not compressed) - let ip_addr = std::net::Ipv6Addr::from(ip[0..16]); + let ip_addr = + std::net::Ipv6Addr::from(<[u8; 16]>::try_from(&ip[0..16]).unwrap()); let s = ip_addr.segments(); format!( "{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}", diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index e6f8c5fda7c..992b32e00ea 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -22,11 +22,14 @@ mod cert; // OpenSSL compatibility layer (abstracts rustls operations) mod compat; +// SSL exception types (shared with openssl backend) +mod error; + pub(crate) use _ssl::make_module; #[allow(non_snake_case)] #[allow(non_upper_case_globals)] -#[pymodule] +#[pymodule(with(error::ssl_error))] mod _ssl { use crate::{ common::{ @@ -37,15 +40,18 @@ mod _ssl { vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, - builtins::{ - PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyType, PyTypeRef, - }, + builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef}, convert::IntoPyException, function::{ArgBytesLike, ArgMemoryBuffer, FuncArgs, OptionalArg, PyComparisonValue}, stdlib::warnings, types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }, }; + + // Import error types used in this module (others are exposed via pymodule(with(...))) + use super::error::{ + PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error, + }; use std::{ collections::HashMap, sync::{ @@ -342,106 +348,6 @@ mod _ssl { #[pyattr] const ENCODING_PEM_AUX: i32 = 0x101; // PEM + 0x100 - #[pyattr] - #[pyexception(name = "SSLError", base = PyOSError)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLError(PyOSError); - - #[pyexception] - impl PySSLError { - // Returns strerror attribute if available, otherwise str(args) - #[pymethod] - fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { - // Try to get strerror attribute first (OSError compatibility) - if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) - && !vm.is_none(&strerror) - { - return strerror.str(vm); - } - - // Otherwise return str(args) - let args = exc.args(); - if args.len() == 1 { - args.as_slice()[0].str(vm) - } else { - args.as_object().str(vm) - } - } - } - - #[pyattr] - #[pyexception(name = "SSLZeroReturnError", base = PySSLError)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLZeroReturnError(PySSLError); - - #[pyexception] - impl PySSLZeroReturnError {} - - #[pyattr] - #[pyexception(name = "SSLWantReadError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLWantReadError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLWantWriteError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLWantWriteError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLSyscallError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLSyscallError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLEOFError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLEOFError(PySSLError); - - #[pyattr] - #[pyexception(name = "SSLCertVerificationError", base = PySSLError, impl)] - #[derive(Debug)] - #[repr(transparent)] - pub struct PySSLCertVerificationError(PySSLError); - - // Helper functions to create SSL exceptions with proper errno attribute - pub(super) fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLWantReadError::class(&vm.ctx).to_owned(), - Some(SSL_ERROR_WANT_READ), - "The operation did not complete (read)", - ) - } - - pub(super) fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLWantWriteError::class(&vm.ctx).to_owned(), - Some(SSL_ERROR_WANT_WRITE), - "The operation did not complete (write)", - ) - } - - pub(crate) fn create_ssl_eof_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLEOFError::class(&vm.ctx).to_owned(), - None, - "EOF occurred in violation of protocol", - ) - } - - pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { - vm.new_os_subtype_error( - PySSLZeroReturnError::class(&vm.ctx).to_owned(), - None, - "TLS/SSL connection has been closed (EOF)", - ) - } - /// Validate server hostname for TLS SNI /// /// Checks that the hostname: diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 4ccc590360a..ab3c81b7a4e 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -30,10 +30,13 @@ use rustpython_vm::{AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObjec use std::io::Read; use std::sync::{Arc, Once}; -// Import PySSLSocket and helper functions from parent module -use super::_ssl::{ - PySSLCertVerificationError, PySSLError, PySSLSocket, create_ssl_eof_error, - create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, +// Import PySSLSocket from parent module +use super::_ssl::PySSLSocket; + +// Import error types and helper functions from error module +use super::error::{ + PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_want_read_error, + create_ssl_want_write_error, create_ssl_zero_return_error, }; // SSL Verification Flags diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs new file mode 100644 index 00000000000..e31683ec72d --- /dev/null +++ b/crates/stdlib/src/ssl/error.rs @@ -0,0 +1,117 @@ +// SSL exception types shared between ssl (rustls) and openssl backends + +pub(crate) use ssl_error::*; + +#[pymodule(sub)] +pub(crate) mod ssl_error { + use crate::vm::{ + PyPayload, PyRef, PyResult, VirtualMachine, + builtins::{PyBaseExceptionRef, PyOSError, PyStrRef}, + types::Constructor, + }; + + // Error type constants (needed for create_ssl_want_read_error etc.) + pub(crate) const SSL_ERROR_WANT_READ: i32 = 2; + pub(crate) const SSL_ERROR_WANT_WRITE: i32 = 3; + + #[pyattr] + #[pyexception(name = "SSLError", base = PyOSError)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLError(PyOSError); + + #[pyexception] + impl PySSLError { + // Returns strerror attribute if available, otherwise str(args) + #[pymethod] + fn __str__(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyResult { + use crate::vm::AsObject; + // Try to get strerror attribute first (OSError compatibility) + if let Ok(strerror) = exc.as_object().get_attr("strerror", vm) + && !vm.is_none(&strerror) + { + return strerror.str(vm); + } + + // Otherwise return str(args) + let args = exc.args(); + if args.len() == 1 { + args.as_slice()[0].str(vm) + } else { + args.as_object().str(vm) + } + } + } + + #[pyattr] + #[pyexception(name = "SSLZeroReturnError", base = PySSLError)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLZeroReturnError(PySSLError); + + #[pyexception] + impl PySSLZeroReturnError {} + + #[pyattr] + #[pyexception(name = "SSLWantReadError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLWantReadError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLWantWriteError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLWantWriteError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLSyscallError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLSyscallError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLEOFError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLEOFError(PySSLError); + + #[pyattr] + #[pyexception(name = "SSLCertVerificationError", base = PySSLError, impl)] + #[derive(Debug)] + #[repr(transparent)] + pub struct PySSLCertVerificationError(PySSLError); + + // Helper functions to create SSL exceptions with proper errno attribute + pub fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLWantReadError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_WANT_READ), + "The operation did not complete (read)", + ) + } + + pub fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLWantWriteError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_WANT_WRITE), + "The operation did not complete (write)", + ) + } + + pub fn create_ssl_eof_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLEOFError::class(&vm.ctx).to_owned(), + None, + "EOF occurred in violation of protocol", + ) + } + + pub fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( + PySSLZeroReturnError::class(&vm.ctx).to_owned(), + None, + "TLS/SSL connection has been closed (EOF)", + ) + } +}