diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index 86b038f2082..5060dced2b0 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -305,6 +305,85 @@ pub(crate) fn impl_pyclass_impl(attr: PunctuatedNestedMeta, item: Item) -> Resul Ok(tokens) } +/// Validates that when a base class is specified, the struct has the base type as its first field. +/// This ensures proper memory layout for subclassing (required for #[repr(transparent)] to work correctly). +fn validate_base_field(item: &Item, base_path: &syn::Path) -> Result<()> { + let Item::Struct(item_struct) = item else { + // Only validate structs - enums with base are already an error elsewhere + return Ok(()); + }; + + // Get the base type name for error messages + let base_name = base_path + .segments + .last() + .map(|s| s.ident.to_string()) + .unwrap_or_else(|| quote!(#base_path).to_string()); + + match &item_struct.fields { + syn::Fields::Named(fields) => { + let Some(first_field) = fields.named.first() else { + bail_span!( + item_struct, + "#[pyclass] with base = {base_name} requires the first field to be of type {base_name}, but the struct has no fields" + ); + }; + if !type_matches_path(&first_field.ty, base_path) { + bail_span!( + first_field, + "#[pyclass] with base = {base_name} requires the first field to be of type {base_name}" + ); + } + } + syn::Fields::Unnamed(fields) => { + let Some(first_field) = fields.unnamed.first() else { + bail_span!( + item_struct, + "#[pyclass] with base = {base_name} requires the first field to be of type {base_name}, but the struct has no fields" + ); + }; + if !type_matches_path(&first_field.ty, base_path) { + bail_span!( + first_field, + "#[pyclass] with base = {base_name} requires the first field to be of type {base_name}" + ); + } + } + syn::Fields::Unit => { + bail_span!( + item_struct, + "#[pyclass] with base = {base_name} requires the first field to be of type {base_name}, but the struct is a unit struct" + ); + } + } + + Ok(()) +} + +/// Check if a type matches a given path (handles simple cases like `Foo` or `path::to::Foo`) +fn type_matches_path(ty: &syn::Type, path: &syn::Path) -> bool { + // Compare by converting both to string representation for macro hygiene + let ty_str = quote!(#ty).to_string().replace(' ', ""); + let path_str = quote!(#path).to_string().replace(' ', ""); + + // Check if both are the same or if the type ends with the path's last segment + if ty_str == path_str { + return true; + } + + // Also match if just the last segment matches (e.g., foo::Bar matches Bar) + let syn::Type::Path(type_path) = ty else { + return false; + }; + let Some(type_last) = type_path.path.segments.last() else { + return false; + }; + let Some(path_last) = path.segments.last() else { + return false; + }; + type_last.ident == path_last.ident +} + fn generate_class_def( ident: &Ident, name: &str, @@ -339,7 +418,6 @@ fn generate_class_def( } else { quote!(false) }; - let basicsize = quote!(std::mem::size_of::<#ident>()); let is_pystruct = attrs.iter().any(|attr| { attr.path().is_ident("derive") && if let Ok(Meta::List(l)) = attr.parse_meta() { @@ -350,6 +428,25 @@ fn generate_class_def( false } }); + // Check if the type has #[repr(transparent)] - only then we can safely + // generate PySubclass impl (requires same memory layout as base type) + let is_repr_transparent = attrs.iter().any(|attr| { + attr.path().is_ident("repr") + && if let Ok(Meta::List(l)) = attr.parse_meta() { + l.nested + .into_iter() + .any(|n| n.get_ident().is_some_and(|p| p == "transparent")) + } else { + false + } + }); + // If repr(transparent) with a base, the type has the same memory layout as base, + // so basicsize should be 0 (no additional space beyond the base type) + let basicsize = if is_repr_transparent && base.is_some() { + quote!(0) + } else { + quote!(std::mem::size_of::<#ident>()) + }; if base.is_some() && is_pystruct { bail_span!(ident, "PyStructSequence cannot have `base` class attr",); } @@ -379,12 +476,31 @@ fn generate_class_def( } }); - let base_or_object = if let Some(base) = base { + let base_or_object = if let Some(ref base) = base { quote! { #base } } else { quote! { ::rustpython_vm::builtins::PyBaseObject } }; + // Generate PySubclass impl for #[repr(transparent)] types with base class + // (tuple struct assumed, so &self.0 works) + let subclass_impl = if !is_pystruct && is_repr_transparent { + base.as_ref().map(|typ| { + quote! { + impl ::rustpython_vm::class::PySubclass for #ident { + type Base = #typ; + + #[inline] + fn as_base(&self) -> &Self::Base { + &self.0 + } + } + } + }) + } else { + None + }; + let tokens = quote! { impl ::rustpython_vm::class::PyClassDef for #ident { const NAME: &'static str = #name; @@ -409,6 +525,8 @@ fn generate_class_def( #base_class } + + #subclass_impl }; Ok(tokens) } @@ -426,11 +544,16 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result Result::static_type() } + }; - // We need this to make extend mechanism work: quote! { + // static_assertions::const_assert!(std::mem::size_of::<#base_type>() <= std::mem::size_of::<#ident>()); impl ::rustpython_vm::PyPayload for #ident { + #[inline] + fn payload_type_id() -> ::std::any::TypeId { + <#base_type as ::rustpython_vm::PyPayload>::payload_type_id() + } + + #[inline] + fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool { + ::BASICSIZE <= obj.class().slots.basicsize && obj.class().fast_issubclass(::static_type()) + } + fn class(ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> { - ctx.types.#ctx_type_ident + #class_fn } } } } else { - quote! {} + if let Some(ctx_type_name) = class_meta.ctx_name()? { + let ctx_type_ident = Ident::new(&ctx_type_name, ident.span()); + quote! { + impl ::rustpython_vm::PyPayload for #ident { + fn class(ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> { + ctx.types.#ctx_type_ident + } + } + } + } else { + quote! {} + } }; let empty_impl = if let Some(attrs) = class_meta.impl_attrs()? { @@ -536,26 +687,6 @@ pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result let class_name = class_meta.class_name()?; let base_class_name = class_meta.base()?; - let impl_payload = if let Some(ctx_type_name) = class_meta.ctx_name()? { - let ctx_type_ident = Ident::new(&ctx_type_name, ident.span()); // FIXME span - - // We need this to make extend mechanism work: - quote! { - impl ::rustpython_vm::PyPayload for #ident { - fn class(ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> { - ctx.exceptions.#ctx_type_ident - } - } - } - } else { - quote! { - impl ::rustpython_vm::PyPayload for #ident { - fn class(_ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> { - ::static_type() - } - } - } - }; let impl_pyclass = if class_meta.has_impl()? { quote! { #[pyexception] @@ -568,7 +699,6 @@ pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result let ret = quote! { #[pyclass(module = false, name = #class_name, base = #base_class_name)] #item - #impl_payload #impl_pyclass }; Ok(ret) @@ -585,7 +715,8 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R let mut extra_attrs = Vec::new(); for nested in &attr { if let NestedMeta::Meta(Meta::List(MetaList { path, nested, .. })) = nested { - if path.is_ident("with") { + // If we already found the constructor trait, no need to keep looking for it + if !has_slot_new && path.is_ident("with") { // Check if Constructor is in the list for meta in nested { if let NestedMeta::Meta(Meta::Path(p)) = meta @@ -1078,9 +1209,8 @@ impl GetSetNursery { item_ident: Ident, ) -> Result<()> { assert!(!self.validated, "new item is not allowed after validation"); - if !matches!(kind, GetSetItemKind::Get) && !cfgs.is_empty() { - bail_span!(item_ident, "Only the getter can have #[cfg]",); - } + // Note: Both getter and setter can have #[cfg], but they must have matching cfgs + // since the map key is (name, cfgs). This ensures getter and setter are paired correctly. let entry = self.map.entry((name.clone(), cfgs)).or_default(); let func = match kind { GetSetItemKind::Get => &mut entry.0, diff --git a/crates/derive-impl/src/pystructseq.rs b/crates/derive-impl/src/pystructseq.rs index c43673fe975..6c34844696d 100644 --- a/crates/derive-impl/src/pystructseq.rs +++ b/crates/derive-impl/src/pystructseq.rs @@ -446,8 +446,10 @@ pub(crate) fn impl_pystruct_sequence( }; let output = quote! { - // The Python type struct (user-defined, possibly empty) - #pytype_vis struct #pytype_ident; + // The Python type struct - newtype wrapping PyTuple + #[derive(Debug)] + #[repr(transparent)] + #pytype_vis struct #pytype_ident(pub ::rustpython_vm::builtins::PyTuple); // PyClassDef for Python type impl ::rustpython_vm::class::PyClassDef for #pytype_ident { @@ -476,10 +478,37 @@ pub(crate) fn impl_pystruct_sequence( } } - // MaybeTraverse (empty - no GC fields in empty struct) + // Subtype uses base type's payload_type_id + impl ::rustpython_vm::PyPayload for #pytype_ident { + #[inline] + fn payload_type_id() -> ::std::any::TypeId { + <::rustpython_vm::builtins::PyTuple as ::rustpython_vm::PyPayload>::payload_type_id() + } + + #[inline] + fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool { + obj.class().fast_issubclass(::static_type()) + } + + fn class(_ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> { + ::static_type() + } + } + + // MaybeTraverse - delegate to inner PyTuple impl ::rustpython_vm::object::MaybeTraverse for #pytype_ident { - fn try_traverse(&self, _traverse_fn: &mut ::rustpython_vm::object::TraverseFn<'_>) { - // Empty struct has no fields to traverse + fn try_traverse(&self, traverse_fn: &mut ::rustpython_vm::object::TraverseFn<'_>) { + self.0.try_traverse(traverse_fn) + } + } + + // PySubclass for proper inheritance + impl ::rustpython_vm::class::PySubclass for #pytype_ident { + type Base = ::rustpython_vm::builtins::PyTuple; + + #[inline] + fn as_base(&self) -> &Self::Base { + &self.0 } } diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index b4e4dac88aa..08b05b56aa8 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -15,7 +15,7 @@ mod _socket { use crate::common::lock::{PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard}; use crate::vm::{ AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyBaseExceptionRef, PyListRef, PyStrRef, PyTupleRef, PyTypeRef}, + builtins::{PyBaseExceptionRef, PyListRef, PyOSError, PyStrRef, PyTupleRef, PyTypeRef}, common::os::ErrorExt, convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject}, function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption}, @@ -1826,6 +1826,11 @@ mod _socket { Self::Py(exc) } } + impl From> for IoOrPyException { + fn from(exc: PyRef) -> Self { + Self::Py(exc.upcast()) + } + } impl From for IoOrPyException { fn from(err: io::Error) -> Self { Self::Io(err) @@ -1844,7 +1849,7 @@ mod _socket { #[inline] fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { match self { - Self::Timeout => timeout_error(vm), + Self::Timeout => timeout_error(vm).upcast(), Self::Py(exc) => exc, Self::Io(err) => err.into_pyexception(vm), } @@ -2412,18 +2417,15 @@ mod _socket { SocketError::GaiError => gaierror(vm), SocketError::HError => herror(vm), }; - vm.new_exception( - exception_cls, - vec![vm.new_pyobj(err.error_num()), vm.ctx.new_str(strerr).into()], - ) - .into() + vm.new_os_subtype_error(exception_cls, Some(err.error_num()), strerr) + .into() } - fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + fn timeout_error(vm: &VirtualMachine) -> PyRef { timeout_error_msg(vm, "timed out".to_owned()) } - pub(crate) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { - vm.new_exception_msg(timeout(vm), msg) + pub(crate) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyRef { + vm.new_os_subtype_error(timeout(vm), None, msg) } fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String { diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 14e087668bb..c23062d639d 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -37,7 +37,9 @@ mod _ssl { vm::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, - builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef}, + builtins::{ + PyBaseExceptionRef, PyBytesRef, PyListRef, PyOSError, PyStrRef, PyType, PyTypeRef, + }, convert::IntoPyException, function::{ArgBytesLike, ArgMemoryBuffer, FuncArgs, OptionalArg, PyComparisonValue}, stdlib::warnings, @@ -340,13 +342,11 @@ mod _ssl { #[pyattr] const ENCODING_PEM_AUX: i32 = 0x101; // PEM + 0x100 - // Exception types - use rustpython_vm::builtins::PyOSError; - #[pyattr] #[pyexception(name = "SSLError", base = PyOSError)] #[derive(Debug)] - pub struct PySSLError {} + #[repr(transparent)] + pub struct PySSLError(PyOSError); #[pyexception] impl PySSLError { @@ -373,89 +373,72 @@ mod _ssl { #[pyattr] #[pyexception(name = "SSLZeroReturnError", base = PySSLError)] #[derive(Debug)] - pub struct PySSLZeroReturnError {} + #[repr(transparent)] + pub struct PySSLZeroReturnError(PySSLError); #[pyexception] impl PySSLZeroReturnError {} #[pyattr] - #[pyexception(name = "SSLWantReadError", base = PySSLError)] + #[pyexception(name = "SSLWantReadError", base = PySSLError, impl)] #[derive(Debug)] - pub struct PySSLWantReadError {} - - #[pyexception] - impl PySSLWantReadError {} + #[repr(transparent)] + pub struct PySSLWantReadError(PySSLError); #[pyattr] - #[pyexception(name = "SSLWantWriteError", base = PySSLError)] + #[pyexception(name = "SSLWantWriteError", base = PySSLError, impl)] #[derive(Debug)] - pub struct PySSLWantWriteError {} - - #[pyexception] - impl PySSLWantWriteError {} + #[repr(transparent)] + pub struct PySSLWantWriteError(PySSLError); #[pyattr] - #[pyexception(name = "SSLSyscallError", base = PySSLError)] + #[pyexception(name = "SSLSyscallError", base = PySSLError, impl)] #[derive(Debug)] - pub struct PySSLSyscallError {} - - #[pyexception] - impl PySSLSyscallError {} + #[repr(transparent)] + pub struct PySSLSyscallError(PySSLError); #[pyattr] - #[pyexception(name = "SSLEOFError", base = PySSLError)] + #[pyexception(name = "SSLEOFError", base = PySSLError, impl)] #[derive(Debug)] - pub struct PySSLEOFError {} - - #[pyexception] - impl PySSLEOFError {} + #[repr(transparent)] + pub struct PySSLEOFError(PySSLError); #[pyattr] - #[pyexception(name = "SSLCertVerificationError", base = PySSLError)] + #[pyexception(name = "SSLCertVerificationError", base = PySSLError, impl)] #[derive(Debug)] - pub struct PySSLCertVerificationError {} - - #[pyexception] - impl PySSLCertVerificationError {} + #[repr(transparent)] + pub struct PySSLCertVerificationError(PySSLError); // Helper functions to create SSL exceptions with proper errno attribute - pub(super) fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - // args = (errno, message) - vm.new_exception( + pub(super) fn create_ssl_want_read_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( PySSLWantReadError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_WANT_READ).into(), - vm.ctx - .new_str("The operation did not complete (read)") - .into(), - ], + Some(SSL_ERROR_WANT_READ), + "The operation did not complete (read)", ) } - pub(super) fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - // args = (errno, message) - vm.new_exception( + pub(super) fn create_ssl_want_write_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( PySSLWantWriteError::class(&vm.ctx).to_owned(), - vec![ - vm.ctx.new_int(SSL_ERROR_WANT_WRITE).into(), - vm.ctx - .new_str("The operation did not complete (write)") - .into(), - ], + Some(SSL_ERROR_WANT_WRITE), + "The operation did not complete (write)", ) } - pub(crate) fn create_ssl_eof_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg( + pub(crate) fn create_ssl_eof_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( PySSLEOFError::class(&vm.ctx).to_owned(), - "EOF occurred in violation of protocol".to_owned(), + None, + "EOF occurred in violation of protocol", ) } - pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg( + pub(crate) fn create_ssl_zero_return_error(vm: &VirtualMachine) -> PyRef { + vm.new_os_subtype_error( PySSLZeroReturnError::class(&vm.ctx).to_owned(), - "TLS/SSL connection has been closed (EOF)".to_owned(), + None, + "TLS/SSL connection has been closed (EOF)", ) } @@ -1250,7 +1233,12 @@ mod _ssl { let msg = io_err.to_string(); if msg.contains("Failed to decrypt") || msg.contains("wrong password") { // Wrong password error - vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), msg) + vm.new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + msg, + ) + .upcast() } else { // [SSL] PEM lib super::compat::SslError::create_ssl_error_with_reason( @@ -1282,14 +1270,13 @@ mod _ssl { // Validate certificate and key match cert::validate_cert_key_match(&certs, &key).map_err(|e| { - vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - if e.contains("key values mismatch") { - "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned() - } else { - e - }, - ) + let msg = if e.contains("key values mismatch") { + "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned() + } else { + e + }; + vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), Some(0), msg) + .upcast() })?; // Auto-build certificate chain: if only leaf cert is in file, try to add CA certs @@ -1311,18 +1298,23 @@ mod _ssl { // Additional validation: Create CertifiedKey to ensure rustls accepts it let signing_key = rustls::crypto::aws_lc_rs::sign::any_supported_type(&key).map_err(|_| { - vm.new_exception_msg( + vm.new_os_subtype_error( PySSLError::class(&vm.ctx).to_owned(), - "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned(), + None, + "[SSL: KEY_VALUES_MISMATCH] key values mismatch", ) + .upcast() })?; let certified_key = CertifiedKey::new(full_chain.clone(), signing_key); if certified_key.keys_match().is_err() { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "[SSL: KEY_VALUES_MISMATCH] key values mismatch", + ) + .upcast()); } // Add cert/key pair to collection (OpenSSL allows multiple cert/key pairs) @@ -1517,9 +1509,7 @@ mod _ssl { } if *self.x509_cert_count.read() == 0 { - return Err(vm.new_os_error( - "Failed to load certificates from Windows store".to_owned(), - )); + return Err(vm.new_os_error("Failed to load certificates from Windows store")); } Ok(()) @@ -1626,8 +1616,10 @@ mod _ssl { let cipher_str = ciphers.as_str(); // Parse cipher string and store selected ciphers - let selected_ciphers = parse_cipher_string(cipher_str) - .map_err(|e| vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), e))?; + let selected_ciphers = parse_cipher_string(cipher_str).map_err(|e| { + vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, e) + .upcast() + })?; // Store in context *self.selected_ciphers.write() = Some(selected_ciphers); @@ -1875,16 +1867,17 @@ mod _ssl { // Check if file exists if !std::path::Path::new(&path_str).exists() { - // Create FileNotFoundError with errno=ENOENT (2) using args - let exc = vm.new_exception( + // Create FileNotFoundError with errno=ENOENT (2) + let exc = vm.new_os_subtype_error( vm.ctx.exceptions.file_not_found_error.to_owned(), - vec![ - vm.ctx.new_int(2).into(), // errno = ENOENT (2) - vm.ctx.new_str("No such file or directory").into(), - vm.ctx.new_str(path_str.clone()).into(), // filename - ], + Some(2), // errno = ENOENT (2) + "No such file or directory", ); - return Err(exc); + // Set filename attribute + let _ = exc + .as_object() + .set_attr("filename", vm.ctx.new_str(path_str.clone()), vm); + return Err(exc.upcast()); } // Validate that the file contains DH parameters @@ -1988,16 +1981,22 @@ mod _ssl { // Validate socket type and context protocol if args.server_side && zelf.protocol == PROTOCOL_TLS_CLIENT { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context", + ) + .upcast()); } if !args.server_side && zelf.protocol == PROTOCOL_TLS_SERVER { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context", + ) + .upcast()); } // Create _SSLSocket instance @@ -2052,16 +2051,22 @@ mod _ssl { // Validate socket type and context protocol if server_side && zelf.protocol == PROTOCOL_TLS_CLIENT { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot create a server socket with a PROTOCOL_TLS_CLIENT context", + ) + .upcast()); } if !server_side && zelf.protocol == PROTOCOL_TLS_SERVER { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "Cannot create a client socket with a PROTOCOL_TLS_SERVER context".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot create a client socket with a PROTOCOL_TLS_SERVER context", + ) + .upcast()); } // Create _SSLSocket instance with BIO mode @@ -2209,20 +2214,26 @@ mod _ssl { // Preserve specific error messages from cert.rs let err_msg = e.to_string(); if err_msg.contains("no start line") { - // no start line: cadata does not contain a certificate - vm.new_exception_msg( + vm.new_os_subtype_error( PySSLError::class(&vm.ctx).to_owned(), - "no start line: cadata does not contain a certificate".to_string(), + None, + "no start line: cadata does not contain a certificate", ) + .upcast() } else if err_msg.contains("not enough data") { - // not enough data: cadata does not contain a certificate - vm.new_exception_msg( + vm.new_os_subtype_error( PySSLError::class(&vm.ctx).to_owned(), - "not enough data: cadata does not contain a certificate".to_string(), + None, + "not enough data: cadata does not contain a certificate", ) + .upcast() } else { - // Generic PEM error - vm.new_exception_msg(PySSLError::class(&vm.ctx).to_owned(), err_msg) + vm.new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + err_msg, + ) + .upcast() } }) } @@ -2774,10 +2785,13 @@ mod _ssl { // - Re-acquire connection lock after callback // - Call: connection.send_fatal_alert(AlertDescription::InternalError) // - Then close connection - let exc = vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "SNI callback returned invalid type".to_owned(), - ); + let exc: PyBaseExceptionRef = vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "SNI callback returned invalid type", + ) + .upcast(); let _ = exc.as_object().set_attr( "reason", vm.ctx.new_str("TLSV1_ALERT_INTERNAL_ERROR"), @@ -3287,7 +3301,7 @@ mod _ssl { len: OptionalArg, buffer: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult { // Convert len to usize, defaulting to 1024 if not provided // -1 means read all available data (treat as large buffer size) let len_val = len.unwrap_or(PEM_BUFSIZE as isize); @@ -3330,10 +3344,13 @@ mod _ssl { // After unwrap()/shutdown(), read operations should fail with SSLError let shutdown_state = *self.shutdown_state.lock(); if shutdown_state != ShutdownState::NotStarted { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "cannot read after shutdown".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "cannot read after shutdown", + ) + .upcast()); } // Helper function to handle return value based on buffer presence @@ -3380,10 +3397,13 @@ mod _ssl { } Err(crate::ssl::compat::SslError::Eof) => { // EOF occurred in violation of protocol (unexpected closure) - Err(vm.new_exception_msg( - PySSLEOFError::class(&vm.ctx).to_owned(), - "EOF occurred in violation of protocol".to_owned(), - )) + Err(vm + .new_os_subtype_error( + PySSLEOFError::class(&vm.ctx).to_owned(), + None, + "EOF occurred in violation of protocol", + ) + .upcast()) } Err(crate::ssl::compat::SslError::ZeroReturn) => { // Clean closure with close_notify - return empty data @@ -3391,13 +3411,15 @@ mod _ssl { } Err(crate::ssl::compat::SslError::WantRead) => { // Non-blocking mode: would block - Err(create_ssl_want_read_error(vm)) + Err(create_ssl_want_read_error(vm).upcast()) } Err(crate::ssl::compat::SslError::WantWrite) => { // Non-blocking mode: would block on write - Err(create_ssl_want_write_error(vm)) + Err(create_ssl_want_write_error(vm).upcast()) + } + Err(crate::ssl::compat::SslError::Timeout(msg)) => { + Err(timeout_error_msg(vm, msg).upcast()) } - Err(crate::ssl::compat::SslError::Timeout(msg)) => Err(timeout_error_msg(vm, msg)), Err(crate::ssl::compat::SslError::Py(e)) => { // Python exception - pass through Err(e) @@ -3453,10 +3475,13 @@ mod _ssl { // After unwrap()/shutdown(), write operations should fail with SSLError let shutdown_state = *self.shutdown_state.lock(); if shutdown_state != ShutdownState::NotStarted { - return Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "cannot write after shutdown".to_owned(), - )); + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "cannot write after shutdown", + ) + .upcast()); } { @@ -3511,7 +3536,9 @@ mod _ssl { Ok(_) => {} Err(e) => { if is_blocking_io_error(&e, vm) { - return Err(create_ssl_want_write_error(vm)); + return Err( + create_ssl_want_write_error(vm).upcast() + ); } return Err(e); } @@ -3902,7 +3929,7 @@ mod _ssl { // Still waiting for peer's close-notify // Raise SSLWantReadError to signal app needs to transfer data // This is correct for non-blocking sockets and BIO mode - return Err(create_ssl_want_read_error(vm)); + return Err(create_ssl_want_read_error(vm).upcast()); } // Both close-notify exchanged, shutdown complete *self.shutdown_state.lock() = ShutdownState::Completed; @@ -4066,15 +4093,17 @@ mod _ssl { // The rustls TLS library does not support requesting client certificates // after the initial handshake is completed. // Raise SSLError instead of NotImplementedError for compatibility - Err(vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - "Post-handshake authentication is not supported by the rustls backend. \ + Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Post-handshake authentication is not supported by the rustls backend. \ The rustls TLS library does not provide an API to request client certificates \ after the initial handshake. Consider requesting the client certificate \ during the initial handshake by setting the appropriate verify_mode before \ - calling do_handshake()." - .to_owned(), - )) + calling do_handshake().", + ) + .upcast()) } #[pymethod] diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index a55b3058884..798542f210a 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -222,7 +222,11 @@ pub(super) fn create_ssl_cert_verification_error( let msg = format!("[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: {verify_message}",); - let exc = vm.new_exception_msg(PySSLCertVerificationError::class(&vm.ctx).to_owned(), msg); + let exc = vm.new_os_subtype_error( + PySSLCertVerificationError::class(&vm.ctx).to_owned(), + None, + msg, + ); // Set verify_code and verify_message attributes // Ignore errors as they're extremely rare (e.g., out of memory) @@ -248,7 +252,7 @@ pub(super) fn create_ssl_cert_verification_error( vm, )?; - Ok(exc) + Ok(exc.upcast()) } /// Unified TLS connection type (client or server) @@ -481,10 +485,7 @@ impl SslError { let msg = message.into(); // SSLError args should be (errno, message) format // FIXME: Use 1 as generic SSL error code - let exc = vm.new_exception( - PySSLError::class(&vm.ctx).to_owned(), - vec![vm.new_pyobj(1i32), vm.new_pyobj(msg)], - ); + let exc = vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), Some(1), msg); // Set library and reason attributes // Ignore errors as they're extremely rare (e.g., out of memory) @@ -497,7 +498,7 @@ impl SslError { exc.as_object() .set_attr("reason", vm.ctx.new_str(reason).as_object().to_owned(), vm); - exc + exc.upcast() } /// Create SSLError with library and reason from ssl_data codes @@ -542,19 +543,22 @@ impl SslError { /// Convert to Python exception pub fn into_py_err(self, vm: &VirtualMachine) -> PyBaseExceptionRef { match self { - SslError::WantRead => create_ssl_want_read_error(vm), - SslError::WantWrite => create_ssl_want_write_error(vm), - SslError::Timeout(msg) => timeout_error_msg(vm, msg), + SslError::WantRead => create_ssl_want_read_error(vm).upcast(), + SslError::WantWrite => create_ssl_want_write_error(vm).upcast(), + SslError::Timeout(msg) => timeout_error_msg(vm, msg).upcast(), SslError::Syscall(msg) => { // Create SSLError with library=None for syscall errors during SSL operations Self::create_ssl_error_with_reason(vm, None, &msg, msg.clone()) } - SslError::Ssl(msg) => vm.new_exception_msg( - PySSLError::class(&vm.ctx).to_owned(), - format!("SSL error: {msg}"), - ), - SslError::ZeroReturn => create_ssl_zero_return_error(vm), - SslError::Eof => create_ssl_eof_error(vm), + SslError::Ssl(msg) => vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + format!("SSL error: {msg}"), + ) + .upcast(), + SslError::ZeroReturn => create_ssl_zero_return_error(vm).upcast(), + SslError::Eof => create_ssl_eof_error(vm).upcast(), SslError::CertVerification(cert_err) => { // Use the proper cert verification error creator create_ssl_cert_verification_error(vm, &cert_err).expect("unlikely to happen") diff --git a/crates/vm/src/builtins/bool.rs b/crates/vm/src/builtins/bool.rs index 9b519dbbde5..6b3ddd8241a 100644 --- a/crates/vm/src/builtins/bool.rs +++ b/crates/vm/src/builtins/bool.rs @@ -1,8 +1,7 @@ use super::{PyInt, PyStrRef, PyType, PyTypeRef}; use crate::common::format::FormatSpec; use crate::{ - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, TryFromBorrowedObject, - VirtualMachine, + AsObject, Context, Py, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine, class::PyClassImpl, convert::{IntoPyException, ToPyObject, ToPyResult}, function::{FuncArgs, OptionalArg}, @@ -21,10 +20,15 @@ impl ToPyObject for bool { impl<'a> TryFromBorrowedObject<'a> for bool { fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult { - if obj.fast_isinstance(vm.ctx.types.int_type) { - Ok(get_value(obj)) - } else { - Err(vm.new_type_error(format!("Expected type bool, not {}", obj.class().name()))) + // Python takes integers as a legit bool value + match obj.downcast_ref::() { + Some(int_obj) => { + let int_val = int_obj.as_bigint(); + Ok(!int_val.is_zero()) + } + None => { + Err(vm.new_type_error(format!("Expected type bool, not {}", obj.class().name()))) + } } } } @@ -81,19 +85,14 @@ impl PyObjectRef { } } -#[pyclass(name = "bool", module = false, base = PyInt)] -pub struct PyBool; - -impl PyPayload for PyBool { - #[inline] - fn class(ctx: &Context) -> &'static Py { - ctx.types.bool_type - } -} +#[pyclass(name = "bool", module = false, base = PyInt, ctx = "bool_type")] +#[repr(transparent)] +pub struct PyBool(pub PyInt); impl Debug for PyBool { - fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { - todo!() + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let value = !self.0.as_bigint().is_zero(); + write!(f, "PyBool({})", value) } } @@ -221,5 +220,9 @@ pub(crate) fn init(context: &Context) { // Retrieve inner int value: pub(crate) fn get_value(obj: &PyObject) -> bool { - !obj.downcast_ref::().unwrap().as_bigint().is_zero() + !obj.downcast_ref::() + .unwrap() + .0 + .as_bigint() + .is_zero() } diff --git a/crates/vm/src/builtins/int.rs b/crates/vm/src/builtins/int.rs index b210d41823d..8fe85267cd0 100644 --- a/crates/vm/src/builtins/int.rs +++ b/crates/vm/src/builtins/int.rs @@ -256,7 +256,7 @@ impl PyInt { if cls.is(vm.ctx.types.int_type) { Ok(vm.ctx.new_int(value)) } else if cls.is(vm.ctx.types.bool_type) { - Ok(vm.ctx.new_bool(!value.into().eq(&BigInt::zero()))) + Ok(vm.ctx.new_bool(!value.into().eq(&BigInt::zero())).upcast()) } else { Self::from(value).into_ref_with_type(vm, cls) } diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 8850de36152..6b5a02dea73 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -171,3 +171,18 @@ pub trait PyClassImpl: PyClassDef { slots } } + +/// Trait for Python subclasses that can provide a reference to their base type. +/// +/// This trait is automatically implemented by the `#[pyclass]` macro when +/// `base = SomeType` is specified. It provides safe reference access to the +/// base type's payload. +/// +/// For subclasses with `#[repr(transparent)]` +/// which enables ownership transfer via `into_base()`. +pub trait PySubclass: crate::PyPayload { + type Base: crate::PyPayload; + + /// Returns a reference to the base type's payload. + fn as_base(&self) -> &Self::Base; +} diff --git a/crates/vm/src/exception_group.rs b/crates/vm/src/exception_group.rs index 5eb011960e1..8d033b26110 100644 --- a/crates/vm/src/exception_group.rs +++ b/crates/vm/src/exception_group.rs @@ -46,7 +46,8 @@ pub(super) mod types { #[pyexception(name, base = PyBaseException, ctx = "base_exception_group")] #[derive(Debug)] - pub struct PyBaseExceptionGroup {} + #[repr(transparent)] + pub struct PyBaseExceptionGroup(PyBaseException); #[pyexception] impl PyBaseExceptionGroup { diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index 4a8bea23557..04dd78fb448 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -964,32 +964,8 @@ impl ExceptionZoo { extend_exception!(PyUnboundLocalError, ctx, excs.unbound_local_error); // os errors: - let errno_getter = - ctx.new_readonly_getset("errno", excs.os_error, |exc: PyBaseExceptionRef| { - let args = exc.args(); - args.first() - .filter(|_| args.len() > 1 && args.len() <= 5) - .cloned() - }); - let strerror_getter = - ctx.new_readonly_getset("strerror", excs.os_error, |exc: PyBaseExceptionRef| { - let args = exc.args(); - args.get(1) - .filter(|_| args.len() >= 2 && args.len() <= 5) - .cloned() - }); - extend_exception!(PyOSError, ctx, excs.os_error, { - // POSIX exception code - "errno" => errno_getter.clone(), - // exception strerror - "strerror" => strerror_getter.clone(), - // exception filename - "filename" => ctx.none(), - // second exception filename - "filename2" => ctx.none(), - }); - #[cfg(windows)] - excs.os_error.set_str_attr("winerror", ctx.none(), ctx); + // PyOSError now uses struct fields with pygetset, no need for dynamic attributes + extend_exception!(PyOSError, ctx, excs.os_error); extend_exception!(PyBlockingIOError, ctx, excs.blocking_io_error); extend_exception!(PyChildProcessError, ctx, excs.child_process_error); @@ -1219,11 +1195,14 @@ pub(crate) fn errno_to_exc_type(_errno: i32, _vm: &VirtualMachine) -> Option<&'s pub(super) mod types { use crate::common::lock::PyRwLock; + use crate::object::{MaybeTraverse, Traverse, TraverseFn}; #[cfg_attr(target_arch = "wasm32", allow(unused_imports))] use crate::{ - AsObject, PyObjectRef, PyRef, PyResult, VirtualMachine, + AsObject, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + VirtualMachine, builtins::{ - PyInt, PyStrRef, PyTupleRef, PyTypeRef, traceback::PyTracebackRef, tuple::IntoPyTuple, + PyInt, PyStrRef, PyTupleRef, PyType, PyTypeRef, traceback::PyTracebackRef, + tuple::IntoPyTuple, }, convert::ToPyResult, function::{ArgBytesLike, FuncArgs}, @@ -1255,23 +1234,28 @@ pub(super) mod types { #[pyexception(name, base = PyBaseException, ctx = "system_exit", impl)] #[derive(Debug)] - pub struct PySystemExit {} + #[repr(transparent)] + pub struct PySystemExit(PyBaseException); #[pyexception(name, base = PyBaseException, ctx = "generator_exit", impl)] #[derive(Debug)] - pub struct PyGeneratorExit {} + #[repr(transparent)] + pub struct PyGeneratorExit(PyBaseException); #[pyexception(name, base = PyBaseException, ctx = "keyboard_interrupt", impl)] #[derive(Debug)] - pub struct PyKeyboardInterrupt {} + #[repr(transparent)] + pub struct PyKeyboardInterrupt(PyBaseException); #[pyexception(name, base = PyBaseException, ctx = "exception_type", impl)] #[derive(Debug)] - pub struct PyException {} + #[repr(transparent)] + pub struct PyException(PyBaseException); #[pyexception(name, base = PyException, ctx = "stop_iteration")] #[derive(Debug)] - pub struct PyStopIteration {} + #[repr(transparent)] + pub struct PyStopIteration(PyException); #[pyexception] impl PyStopIteration { @@ -1289,31 +1273,37 @@ pub(super) mod types { #[pyexception(name, base = PyException, ctx = "stop_async_iteration", impl)] #[derive(Debug)] - pub struct PyStopAsyncIteration {} + #[repr(transparent)] + pub struct PyStopAsyncIteration(PyException); #[pyexception(name, base = PyException, ctx = "arithmetic_error", impl)] #[derive(Debug)] - pub struct PyArithmeticError {} + #[repr(transparent)] + pub struct PyArithmeticError(PyException); #[pyexception(name, base = PyArithmeticError, ctx = "floating_point_error", impl)] #[derive(Debug)] - pub struct PyFloatingPointError {} - + #[repr(transparent)] + pub struct PyFloatingPointError(PyArithmeticError); #[pyexception(name, base = PyArithmeticError, ctx = "overflow_error", impl)] #[derive(Debug)] - pub struct PyOverflowError {} + #[repr(transparent)] + pub struct PyOverflowError(PyArithmeticError); #[pyexception(name, base = PyArithmeticError, ctx = "zero_division_error", impl)] #[derive(Debug)] - pub struct PyZeroDivisionError {} + #[repr(transparent)] + pub struct PyZeroDivisionError(PyArithmeticError); #[pyexception(name, base = PyException, ctx = "assertion_error", impl)] #[derive(Debug)] - pub struct PyAssertionError {} + #[repr(transparent)] + pub struct PyAssertionError(PyException); #[pyexception(name, base = PyException, ctx = "attribute_error")] #[derive(Debug)] - pub struct PyAttributeError {} + #[repr(transparent)] + pub struct PyAttributeError(PyException); #[pyexception] impl PyAttributeError { @@ -1340,15 +1330,18 @@ pub(super) mod types { #[pyexception(name, base = PyException, ctx = "buffer_error", impl)] #[derive(Debug)] - pub struct PyBufferError {} + #[repr(transparent)] + pub struct PyBufferError(PyException); #[pyexception(name, base = PyException, ctx = "eof_error", impl)] #[derive(Debug)] - pub struct PyEOFError {} + #[repr(transparent)] + pub struct PyEOFError(PyException); #[pyexception(name, base = PyException, ctx = "import_error")] #[derive(Debug)] - pub struct PyImportError {} + #[repr(transparent)] + pub struct PyImportError(PyException); #[pyexception] impl PyImportError { @@ -1393,19 +1386,23 @@ pub(super) mod types { #[pyexception(name, base = PyImportError, ctx = "module_not_found_error", impl)] #[derive(Debug)] - pub struct PyModuleNotFoundError {} + #[repr(transparent)] + pub struct PyModuleNotFoundError(PyImportError); #[pyexception(name, base = PyException, ctx = "lookup_error", impl)] #[derive(Debug)] - pub struct PyLookupError {} + #[repr(transparent)] + pub struct PyLookupError(PyException); #[pyexception(name, base = PyLookupError, ctx = "index_error", impl)] #[derive(Debug)] - pub struct PyIndexError {} + #[repr(transparent)] + pub struct PyIndexError(PyLookupError); #[pyexception(name, base = PyLookupError, ctx = "key_error")] #[derive(Debug)] - pub struct PyKeyError {} + #[repr(transparent)] + pub struct PyKeyError(PyLookupError); #[pyexception] impl PyKeyError { @@ -1425,80 +1422,180 @@ pub(super) mod types { #[pyexception(name, base = PyException, ctx = "memory_error", impl)] #[derive(Debug)] - pub struct PyMemoryError {} + #[repr(transparent)] + pub struct PyMemoryError(PyException); #[pyexception(name, base = PyException, ctx = "name_error", impl)] #[derive(Debug)] - pub struct PyNameError {} + #[repr(transparent)] + pub struct PyNameError(PyException); #[pyexception(name, base = PyNameError, ctx = "unbound_local_error", impl)] #[derive(Debug)] - pub struct PyUnboundLocalError {} + #[repr(transparent)] + pub struct PyUnboundLocalError(PyNameError); #[pyexception(name, base = PyException, ctx = "os_error")] - #[derive(Debug)] - pub struct PyOSError {} + pub struct PyOSError { + base: PyException, + errno: PyAtomicRef>, + strerror: PyAtomicRef>, + filename: PyAtomicRef>, + filename2: PyAtomicRef>, + #[cfg(windows)] + winerror: PyAtomicRef>, + } + + impl crate::class::PySubclass for PyOSError { + type Base = PyException; + fn as_base(&self) -> &Self::Base { + &self.base + } + } + + impl std::fmt::Debug for PyOSError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyOSError").finish_non_exhaustive() + } + } + + unsafe impl Traverse for PyOSError { + fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { + self.base.try_traverse(tracer_fn); + if let Some(obj) = self.errno.deref() { + tracer_fn(obj); + } + if let Some(obj) = self.strerror.deref() { + tracer_fn(obj); + } + if let Some(obj) = self.filename.deref() { + tracer_fn(obj); + } + if let Some(obj) = self.filename2.deref() { + tracer_fn(obj); + } + #[cfg(windows)] + if let Some(obj) = self.winerror.deref() { + tracer_fn(obj); + } + } + } // OS Errors: - #[pyexception] - impl PyOSError { - #[cfg(not(target_arch = "wasm32"))] - fn optional_new(args: Vec, vm: &VirtualMachine) -> Option { - let len = args.len(); - if (2..=5).contains(&len) { - let errno = &args[0]; - errno - .downcast_ref::() - .and_then(|errno| errno.try_to_primitive::(vm).ok()) - .and_then(|errno| super::errno_to_exc_type(errno, vm)) - .and_then(|typ| vm.invoke_exception(typ.to_owned(), args.to_vec()).ok()) + impl Constructor for PyOSError { + type Args = FuncArgs; + + fn py_new(_cls: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + let len = args.args.len(); + // CPython only sets errno/strerror when args len is 2-5 + let (errno, strerror) = if (2..=5).contains(&len) { + (Some(args.args[0].clone()), Some(args.args[1].clone())) + } else { + (None, None) + }; + let filename = if (3..=5).contains(&len) { + Some(args.args[2].clone()) } else { None - } + }; + let filename2 = if len == 5 { + args.args.get(4).cloned() + } else { + None + }; + // Truncate args for base exception when 3-5 args + let base_args = if (3..=5).contains(&len) { + args.args[..2].to_vec() + } else { + args.args.to_vec() + }; + let base_exception = PyBaseException::new(base_args, vm); + Ok(Self { + base: PyException(base_exception), + errno: errno.into(), + strerror: strerror.into(), + filename: filename.into(), + filename2: filename2.into(), + #[cfg(windows)] + winerror: None.into(), + }) } - #[cfg(not(target_arch = "wasm32"))] - #[pyslot] - pub fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + + fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { // We need this method, because of how `CPython` copies `init` // from `BaseException` in `SimpleExtendsException` macro. // See: `BaseException_new` if *cls.name() == *vm.ctx.exceptions.os_error.name() { - match Self::optional_new(args.args.to_vec(), vm) { - Some(error) => error.to_pyresult(vm), - None => PyBaseException::slot_new(cls, args, vm), + let args_vec = args.args.to_vec(); + let len = args_vec.len(); + if (2..=5).contains(&len) { + let errno = &args_vec[0]; + if let Some(error) = errno + .downcast_ref::() + .and_then(|errno| errno.try_to_primitive::(vm).ok()) + .and_then(|errno| super::errno_to_exc_type(errno, vm)) + .and_then(|typ| vm.invoke_exception(typ.to_owned(), args_vec).ok()) + { + return error.to_pyresult(vm); + } } - } else { - PyBaseException::slot_new(cls, args, vm) } + let payload = Self::py_new(&cls, args, vm)?; + payload.into_ref_with_type(vm, cls).map(Into::into) } - #[cfg(target_arch = "wasm32")] - #[pyslot] - pub fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - PyBaseException::slot_new(cls, args, vm) - } + } + + #[pyexception(with(Constructor))] + impl PyOSError { #[pyslot] #[pymethod(name = "__init__")] pub fn slot_init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { let len = args.args.len(); let mut new_args = args; - if (3..=5).contains(&len) { - zelf.set_attr("filename", new_args.args[2].clone(), vm)?; + + // All OSError subclasses use #[repr(transparent)] wrapping PyOSError, + // so we can safely access the PyOSError fields through pointer cast + // SAFETY: All OSError subclasses (FileNotFoundError, etc.) are + // #[repr(transparent)] wrappers around PyOSError with identical memory layout + #[allow(deprecated)] + let exc: &Py = zelf.downcast_ref::().unwrap(); + + // SAFETY: slot_init is called during object initialization, + // so fields are None and swap result can be safely ignored + if len <= 5 { + // Only set errno/strerror when args len is 2-5 (CPython behavior) + if 2 <= len { + let _ = unsafe { exc.errno.swap(Some(new_args.args[0].clone())) }; + let _ = unsafe { exc.strerror.swap(Some(new_args.args[1].clone())) }; + } + if 3 <= len { + let _ = unsafe { exc.filename.swap(Some(new_args.args[2].clone())) }; + } #[cfg(windows)] - if let Some(winerror) = new_args.args.get(3) { - zelf.set_attr("winerror", winerror.clone(), vm)?; - // Convert winerror to errno and replace args[0] (CPython behavior) - if let Some(winerror_int) = winerror - .downcast_ref::() + if 4 <= len { + let winerror = new_args.args.get(3).cloned(); + // Store original winerror + let _ = unsafe { exc.winerror.swap(winerror.clone()) }; + + // Convert winerror to errno and update errno + args[0] + if let Some(errno) = winerror + .as_ref() + .and_then(|w| w.downcast_ref::()) .and_then(|w| w.try_to_primitive::(vm).ok()) + .map(crate::common::os::winerror_to_errno) { - let errno = crate::common::os::winerror_to_errno(winerror_int); - new_args.args[0] = vm.new_pyobj(errno); + let errno_obj = vm.new_pyobj(errno); + let _ = unsafe { exc.errno.swap(Some(errno_obj.clone())) }; + new_args.args[0] = errno_obj; } } - if let Some(filename2) = new_args.args.get(4) { - zelf.set_attr("filename2", filename2.clone(), vm)?; + if len == 5 { + let _ = unsafe { exc.filename2.swap(new_args.args.get(4).cloned()) }; } + } + // args are truncated to 2 for compatibility (only when 2-5 args) + if (3..=5).contains(&len) { new_args.args.truncate(2); } PyBaseException::slot_init(zelf, new_args, vm) @@ -1582,23 +1679,80 @@ pub(super) mod types { } result.into_pytuple(vm) } + + // Getters and setters for OSError fields + #[pygetset] + fn errno(&self) -> Option { + self.errno.to_owned() + } + + #[pygetset(setter)] + fn set_errno(&self, value: Option, vm: &VirtualMachine) { + self.errno.swap_to_temporary_refs(value, vm); + } + + #[pygetset(name = "strerror")] + fn get_strerror(&self) -> Option { + self.strerror.to_owned() + } + + #[pygetset(setter, name = "strerror")] + fn set_strerror(&self, value: Option, vm: &VirtualMachine) { + self.strerror.swap_to_temporary_refs(value, vm); + } + + #[pygetset] + fn filename(&self) -> Option { + self.filename.to_owned() + } + + #[pygetset(setter)] + fn set_filename(&self, value: Option, vm: &VirtualMachine) { + self.filename.swap_to_temporary_refs(value, vm); + } + + #[pygetset] + fn filename2(&self) -> Option { + self.filename2.to_owned() + } + + #[pygetset(setter)] + fn set_filename2(&self, value: Option, vm: &VirtualMachine) { + self.filename2.swap_to_temporary_refs(value, vm); + } + + #[cfg(windows)] + #[pygetset] + fn winerror(&self) -> Option { + self.winerror.to_owned() + } + + #[cfg(windows)] + #[pygetset(setter)] + fn set_winerror(&self, value: Option, vm: &VirtualMachine) { + self.winerror.swap_to_temporary_refs(value, vm); + } } #[pyexception(name, base = PyOSError, ctx = "blocking_io_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyBlockingIOError {} + pub struct PyBlockingIOError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "child_process_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyChildProcessError {} + pub struct PyChildProcessError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "connection_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyConnectionError {} + pub struct PyConnectionError(PyOSError); #[pyexception(name, base = PyConnectionError, ctx = "broken_pipe_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyBrokenPipeError {} + pub struct PyBrokenPipeError(PyConnectionError); #[pyexception( name, @@ -1606,8 +1760,9 @@ pub(super) mod types { ctx = "connection_aborted_error", impl )] + #[repr(transparent)] #[derive(Debug)] - pub struct PyConnectionAbortedError {} + pub struct PyConnectionAbortedError(PyConnectionError); #[pyexception( name, @@ -1615,64 +1770,79 @@ pub(super) mod types { ctx = "connection_refused_error", impl )] + #[repr(transparent)] #[derive(Debug)] - pub struct PyConnectionRefusedError {} + pub struct PyConnectionRefusedError(PyConnectionError); #[pyexception(name, base = PyConnectionError, ctx = "connection_reset_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyConnectionResetError {} + pub struct PyConnectionResetError(PyConnectionError); #[pyexception(name, base = PyOSError, ctx = "file_exists_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyFileExistsError {} + pub struct PyFileExistsError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "file_not_found_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyFileNotFoundError {} + pub struct PyFileNotFoundError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "interrupted_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyInterruptedError {} + pub struct PyInterruptedError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "is_a_directory_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyIsADirectoryError {} + pub struct PyIsADirectoryError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "not_a_directory_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyNotADirectoryError {} + pub struct PyNotADirectoryError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "permission_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyPermissionError {} + pub struct PyPermissionError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "process_lookup_error", impl)] + #[repr(transparent)] #[derive(Debug)] - pub struct PyProcessLookupError {} + pub struct PyProcessLookupError(PyOSError); #[pyexception(name, base = PyOSError, ctx = "timeout_error", impl)] #[derive(Debug)] - pub struct PyTimeoutError {} + #[repr(transparent)] + pub struct PyTimeoutError(PyOSError); #[pyexception(name, base = PyException, ctx = "reference_error", impl)] #[derive(Debug)] - pub struct PyReferenceError {} + #[repr(transparent)] + pub struct PyReferenceError(PyException); #[pyexception(name, base = PyException, ctx = "runtime_error", impl)] #[derive(Debug)] - pub struct PyRuntimeError {} + #[repr(transparent)] + pub struct PyRuntimeError(PyException); #[pyexception(name, base = PyRuntimeError, ctx = "not_implemented_error", impl)] #[derive(Debug)] - pub struct PyNotImplementedError {} + #[repr(transparent)] + pub struct PyNotImplementedError(PyRuntimeError); #[pyexception(name, base = PyRuntimeError, ctx = "recursion_error", impl)] #[derive(Debug)] - pub struct PyRecursionError {} + #[repr(transparent)] + pub struct PyRecursionError(PyRuntimeError); #[pyexception(name, base = PyException, ctx = "syntax_error")] #[derive(Debug)] - pub struct PySyntaxError {} + #[repr(transparent)] + pub struct PySyntaxError(PyException); #[pyexception] impl PySyntaxError { @@ -1766,7 +1936,8 @@ pub(super) mod types { ctx = "incomplete_input_error" )] #[derive(Debug)] - pub struct PyIncompleteInputError {} + #[repr(transparent)] + pub struct PyIncompleteInputError(PySyntaxError); #[pyexception] impl PyIncompleteInputError { @@ -1784,31 +1955,38 @@ pub(super) mod types { #[pyexception(name, base = PySyntaxError, ctx = "indentation_error", impl)] #[derive(Debug)] - pub struct PyIndentationError {} + #[repr(transparent)] + pub struct PyIndentationError(PySyntaxError); #[pyexception(name, base = PyIndentationError, ctx = "tab_error", impl)] #[derive(Debug)] - pub struct PyTabError {} + #[repr(transparent)] + pub struct PyTabError(PyIndentationError); #[pyexception(name, base = PyException, ctx = "system_error", impl)] #[derive(Debug)] - pub struct PySystemError {} + #[repr(transparent)] + pub struct PySystemError(PyException); #[pyexception(name, base = PyException, ctx = "type_error", impl)] #[derive(Debug)] - pub struct PyTypeError {} + #[repr(transparent)] + pub struct PyTypeError(PyException); #[pyexception(name, base = PyException, ctx = "value_error", impl)] #[derive(Debug)] - pub struct PyValueError {} + #[repr(transparent)] + pub struct PyValueError(PyException); #[pyexception(name, base = PyValueError, ctx = "unicode_error", impl)] #[derive(Debug)] - pub struct PyUnicodeError {} + #[repr(transparent)] + pub struct PyUnicodeError(PyValueError); #[pyexception(name, base = PyUnicodeError, ctx = "unicode_decode_error")] #[derive(Debug)] - pub struct PyUnicodeDecodeError {} + #[repr(transparent)] + pub struct PyUnicodeDecodeError(PyUnicodeError); #[pyexception] impl PyUnicodeDecodeError { @@ -1859,7 +2037,8 @@ pub(super) mod types { #[pyexception(name, base = PyUnicodeError, ctx = "unicode_encode_error")] #[derive(Debug)] - pub struct PyUnicodeEncodeError {} + #[repr(transparent)] + pub struct PyUnicodeEncodeError(PyUnicodeError); #[pyexception] impl PyUnicodeEncodeError { @@ -1910,7 +2089,8 @@ pub(super) mod types { #[pyexception(name, base = PyUnicodeError, ctx = "unicode_translate_error")] #[derive(Debug)] - pub struct PyUnicodeTranslateError {} + #[repr(transparent)] + pub struct PyUnicodeTranslateError(PyUnicodeError); #[pyexception] impl PyUnicodeTranslateError { @@ -1958,54 +2138,67 @@ pub(super) mod types { #[cfg(feature = "jit")] #[pyexception(name, base = PyException, ctx = "jit_error", impl)] #[derive(Debug)] - pub struct PyJitError {} + #[repr(transparent)] + pub struct PyJitError(PyException); // Warnings #[pyexception(name, base = PyException, ctx = "warning", impl)] #[derive(Debug)] - pub struct PyWarning {} + #[repr(transparent)] + pub struct PyWarning(PyException); #[pyexception(name, base = PyWarning, ctx = "deprecation_warning", impl)] #[derive(Debug)] - pub struct PyDeprecationWarning {} + #[repr(transparent)] + pub struct PyDeprecationWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "pending_deprecation_warning", impl)] #[derive(Debug)] - pub struct PyPendingDeprecationWarning {} + #[repr(transparent)] + pub struct PyPendingDeprecationWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "runtime_warning", impl)] #[derive(Debug)] - pub struct PyRuntimeWarning {} + #[repr(transparent)] + pub struct PyRuntimeWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "syntax_warning", impl)] #[derive(Debug)] - pub struct PySyntaxWarning {} + #[repr(transparent)] + pub struct PySyntaxWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "user_warning", impl)] #[derive(Debug)] - pub struct PyUserWarning {} + #[repr(transparent)] + pub struct PyUserWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "future_warning", impl)] #[derive(Debug)] - pub struct PyFutureWarning {} + #[repr(transparent)] + pub struct PyFutureWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "import_warning", impl)] #[derive(Debug)] - pub struct PyImportWarning {} + #[repr(transparent)] + pub struct PyImportWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "unicode_warning", impl)] #[derive(Debug)] - pub struct PyUnicodeWarning {} + #[repr(transparent)] + pub struct PyUnicodeWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "bytes_warning", impl)] #[derive(Debug)] - pub struct PyBytesWarning {} + #[repr(transparent)] + pub struct PyBytesWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "resource_warning", impl)] #[derive(Debug)] - pub struct PyResourceWarning {} + #[repr(transparent)] + pub struct PyResourceWarning(PyWarning); #[pyexception(name, base = PyWarning, ctx = "encoding_warning", impl)] #[derive(Debug)] - pub struct PyEncodingWarning {} + #[repr(transparent)] + pub struct PyEncodingWarning(PyWarning); } diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index 6530abdbaba..e04b87de594 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -15,7 +15,6 @@ use super::{ ext::{AsObject, PyRefExact, PyResult}, payload::PyPayload, }; -use crate::object::traverse::{MaybeTraverse, Traverse, TraverseFn}; use crate::object::traverse_object::PyObjVTable; use crate::{ builtins::{PyDictRef, PyType, PyTypeRef}, @@ -27,6 +26,10 @@ use crate::{ }, vm::VirtualMachine, }; +use crate::{ + class::StaticType, + object::traverse::{MaybeTraverse, Traverse, TraverseFn}, +}; use itertools::Itertools; use std::{ any::TypeId, @@ -1070,6 +1073,35 @@ impl PyRef { } } +impl PyRef +where + T::Base: std::fmt::Debug, +{ + /// Converts this reference to the base type (ownership transfer). + /// # Safety + /// T and T::Base must have compatible layouts in size_of::() bytes. + #[inline] + pub fn into_base(self) -> PyRef { + let obj: PyObjectRef = self.into(); + match obj.downcast() { + Ok(base_ref) => base_ref, + Err(_) => unsafe { std::hint::unreachable_unchecked() }, + } + } + #[inline] + pub fn upcast(self) -> PyRef + where + T: StaticType, + { + debug_assert!(T::static_type().is_subtype(U::static_type())); + let obj: PyObjectRef = self.into(); + match obj.downcast::() { + Ok(upcast_ref) => upcast_ref, + Err(_) => unsafe { std::hint::unreachable_unchecked() }, + } + } +} + impl Borrow for PyRef where T: PyPayload, diff --git a/crates/vm/src/object/payload.rs b/crates/vm/src/object/payload.rs index b2211761492..cf903871179 100644 --- a/crates/vm/src/object/payload.rs +++ b/crates/vm/src/object/payload.rs @@ -16,6 +16,15 @@ cfg_if::cfg_if! { } } +#[cold] +pub(crate) fn cold_downcast_type_error( + vm: &VirtualMachine, + class: &Py, + obj: &PyObject, +) -> PyBaseExceptionRef { + vm.new_downcast_type_error(class, obj) +} + pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { #[inline] fn payload_type_id() -> std::any::TypeId { @@ -38,17 +47,8 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { return Ok(()); } - #[cold] - fn raise_downcast_type_error( - vm: &VirtualMachine, - class: &Py, - obj: &PyObject, - ) -> PyBaseExceptionRef { - vm.new_downcast_type_error(class, obj) - } - let class = Self::class(&vm.ctx); - Err(raise_downcast_type_error(vm, class, obj)) + Err(cold_downcast_type_error(vm, class, obj)) } fn class(ctx: &Context) -> &'static Py; @@ -101,6 +101,22 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { { let exact_class = Self::class(&vm.ctx); if cls.fast_issubclass(exact_class) { + if exact_class.slots.basicsize != cls.slots.basicsize { + #[cold] + #[inline(never)] + fn _into_ref_size_error( + vm: &VirtualMachine, + cls: &PyTypeRef, + exact_class: &Py, + ) -> PyBaseExceptionRef { + vm.new_type_error(format!( + "cannot create '{}' instance: size differs from base type '{}'", + cls.name(), + exact_class.name() + )) + } + return Err(_into_ref_size_error(vm, &cls, exact_class)); + } Ok(self._into_ref(cls, &vm.ctx)) } else { #[cold] diff --git a/crates/vm/src/stdlib/ast/pyast.rs b/crates/vm/src/stdlib/ast/pyast.rs index b891a605dc2..e36635fe4b9 100644 --- a/crates/vm/src/stdlib/ast/pyast.rs +++ b/crates/vm/src/stdlib/ast/pyast.rs @@ -5,13 +5,14 @@ use crate::common::ascii; macro_rules! impl_node { ( - $(#[$meta:meta])* + #[pyclass(module = $_mod:literal, name = $_name:literal, base = $base:ty)] $vis:vis struct $name:ident, fields: [$($field:expr),* $(,)?], attributes: [$($attr:expr),* $(,)?] $(,)? ) => { - $(#[$meta])* - $vis struct $name; + #[pyclass(module = $_mod, name = $_name, base = $base)] + #[repr(transparent)] + $vis struct $name($base); #[pyclass(flags(HAS_DICT, BASETYPE))] impl $name { @@ -39,12 +40,12 @@ macro_rules! impl_node { }; // Without attributes ( - $(#[$meta:meta])* + #[pyclass(module = $_mod:literal, name = $_name:literal, base = $base:ty)] $vis:vis struct $name:ident, fields: [$($field:expr),* $(,)?] $(,)? ) => { impl_node!( - $(#[$meta])* + #[pyclass(module = $_mod, name = $_name, base = $base)] $vis struct $name, fields: [$($field),*], attributes: [], @@ -52,12 +53,12 @@ macro_rules! impl_node { }; // Without fields ( - $(#[$meta:meta])* + #[pyclass(module = $_mod:literal, name = $_name:literal, base = $base:ty)] $vis:vis struct $name:ident, attributes: [$($attr:expr),* $(,)?] $(,)? ) => { impl_node!( - $(#[$meta])* + #[pyclass(module = $_mod, name = $_name, base = $base)] $vis struct $name, fields: [], attributes: [$($attr),*], @@ -65,11 +66,11 @@ macro_rules! impl_node { }; // Without fields and attributes ( - $(#[$meta:meta])* + #[pyclass(module = $_mod:literal, name = $_name:literal, base = $base:ty)] $vis:vis struct $name:ident $(,)? ) => { impl_node!( - $(#[$meta])* + #[pyclass(module = $_mod, name = $_name, base = $base)] $vis struct $name, fields: [], attributes: [], @@ -78,7 +79,7 @@ macro_rules! impl_node { } #[pyclass(module = "_ast", name = "mod", base = NodeAst)] -pub(crate) struct NodeMod; +pub(crate) struct NodeMod(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeMod {} @@ -102,7 +103,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "stmt", base = NodeAst)] -pub(crate) struct NodeStmt; +#[repr(transparent)] +pub(crate) struct NodeStmt(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeStmt {} @@ -301,7 +303,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "expr", base = NodeAst)] -pub(crate) struct NodeExpr; +#[repr(transparent)] +pub(crate) struct NodeExpr(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeExpr {} @@ -495,7 +498,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "expr_context", base = NodeAst)] -pub(crate) struct NodeExprContext; +#[repr(transparent)] +pub(crate) struct NodeExprContext(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeExprContext {} @@ -518,7 +522,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "boolop", base = NodeAst)] -pub(crate) struct NodeBoolOp; +#[repr(transparent)] +pub(crate) struct NodeBoolOp(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeBoolOp {} @@ -534,7 +539,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "operator", base = NodeAst)] -pub(crate) struct NodeOperator; +#[repr(transparent)] +pub(crate) struct NodeOperator(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeOperator {} @@ -605,7 +611,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "unaryop", base = NodeAst)] -pub(crate) struct NodeUnaryOp; +#[repr(transparent)] +pub(crate) struct NodeUnaryOp(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeUnaryOp {} @@ -631,7 +638,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "cmpop", base = NodeAst)] -pub(crate) struct NodeCmpOp; +#[repr(transparent)] +pub(crate) struct NodeCmpOp(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeCmpOp {} @@ -692,7 +700,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "excepthandler", base = NodeAst)] -pub(crate) struct NodeExceptHandler; +#[repr(transparent)] +pub(crate) struct NodeExceptHandler(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeExceptHandler {} @@ -744,7 +753,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "pattern", base = NodeAst)] -pub(crate) struct NodePattern; +#[repr(transparent)] +pub(crate) struct NodePattern(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodePattern {} @@ -805,7 +815,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "type_ignore", base = NodeAst)] -pub(crate) struct NodeTypeIgnore; +#[repr(transparent)] +pub(crate) struct NodeTypeIgnore(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeTypeIgnore {} @@ -818,7 +829,8 @@ impl_node!( ); #[pyclass(module = "_ast", name = "type_param", base = NodeAst)] -pub(crate) struct NodeTypeParam; +#[repr(transparent)] +pub(crate) struct NodeTypeParam(NodeAst); #[pyclass(flags(HAS_DICT, BASETYPE))] impl NodeTypeParam {} diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/ctypes.rs index 2daaa6f3abd..70aee7378d3 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/ctypes.rs @@ -15,7 +15,7 @@ use crate::builtins::PyModule; use crate::class::PyClassImpl; use crate::{Py, PyRef, VirtualMachine}; -pub use crate::stdlib::ctypes::base::{PyCData, PyCSimple, PyCSimpleType}; +pub use crate::stdlib::ctypes::base::{CDataObject, PyCData, PyCSimple, PyCSimpleType}; pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { let ctx = &vm.ctx; @@ -47,7 +47,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule] pub(crate) mod _ctypes { - use super::base::{CDataObject, PyCSimple}; + use super::base::{CDataObject, PyCData, PyCSimple}; use crate::builtins::PyTypeRef; use crate::class::StaticType; use crate::convert::ToPyObject; @@ -369,13 +369,12 @@ pub(crate) mod _ctypes { Err(vm.new_attribute_error(format!("class must define a '_type_' attribute which must be\n a single character string containing one of {SIMPLE_TYPE_CHARS}, currently it is {tp_str}."))) } else { let size = get_size(&tp_str); + let cdata = CDataObject::from_bytes(vec![0u8; size], None); Ok(PyCSimple { + _base: PyCData::new(cdata.clone()), _type_: tp_str, value: AtomicCell::new(vm.ctx.none()), - cdata: rustpython_common::lock::PyRwLock::new(CDataObject::from_bytes( - vec![0u8; size], - None, - )), + cdata: rustpython_common::lock::PyRwLock::new(cdata), }) } } else { diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/ctypes/array.rs index 98274e388bf..fe12a781d9f 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs +++ b/crates/vm/src/stdlib/ctypes/array.rs @@ -23,8 +23,9 @@ use rustpython_vm::stdlib::ctypes::base::PyCData; /// PyCArrayType - metatype for Array types /// CPython stores array info (type, length) in StgInfo via type_data #[pyclass(name = "PyCArrayType", base = PyType, module = "_ctypes")] -#[derive(Debug, Default, PyPayload)] -pub struct PyCArrayType {} +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCArrayType(PyType); /// Create a new Array type with StgInfo stored in type_data (CPython style) pub fn create_array_type_with_stg_info(stg_info: StgInfo, vm: &VirtualMachine) -> PyResult { @@ -194,11 +195,13 @@ impl PyCArrayType { }; // Create instance + let cdata = CDataObject::from_bytes(data, None); let instance = PyCArray { + _base: PyCData::new(cdata.clone()), typ: PyRwLock::new(element_type), length: AtomicCell::new(length), element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(CDataObject::from_bytes(data, None)), + cdata: PyRwLock::new(cdata), } .into_pyobject(vm); @@ -235,8 +238,8 @@ impl AsNumber for PyCArrayType { metaclass = "PyCArrayType", module = "_ctypes" )] -#[derive(PyPayload)] pub struct PyCArray { + _base: PyCData, /// Element type - can be a simple type (c_int) or an array type (c_int * 5) pub(super) typ: PyRwLock, pub(super) length: AtomicCell, @@ -301,11 +304,13 @@ impl Constructor for PyCArray { } } + let cdata = CDataObject::from_bytes(buffer, None); PyCArray { + _base: PyCData::new(cdata.clone()), typ: PyRwLock::new(element_type), length: AtomicCell::new(length), element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(CDataObject::from_bytes(buffer, None)), + cdata: PyRwLock::new(cdata), } .into_ref_with_type(vm, cls) .map(Into::into) @@ -530,11 +535,13 @@ impl PyCArray { .unwrap_or(0); let element_size = if length > 0 { size / length } else { 0 }; + let cdata = CDataObject::from_bytes(bytes.to_vec(), None); Ok(PyCArray { + _base: PyCData::new(cdata.clone()), typ: PyRwLock::new(element_type.into()), length: AtomicCell::new(length), element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(CDataObject::from_bytes(bytes.to_vec(), None)), + cdata: PyRwLock::new(cdata), } .into_pyobject(vm)) } @@ -596,14 +603,13 @@ impl PyCArray { .unwrap_or(0); let element_size = if length > 0 { size / length } else { 0 }; + let cdata = CDataObject::from_bytes(data.to_vec(), Some(buffer.obj.clone())); Ok(PyCArray { + _base: PyCData::new(cdata.clone()), typ: PyRwLock::new(element_type.into()), length: AtomicCell::new(length), element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(CDataObject::from_bytes( - data.to_vec(), - Some(buffer.obj.clone()), - )), + cdata: PyRwLock::new(cdata), } .into_pyobject(vm)) } @@ -656,11 +662,13 @@ impl PyCArray { .unwrap_or(0); let element_size = if length > 0 { size / length } else { 0 }; + let cdata = CDataObject::from_bytes(data.to_vec(), None); Ok(PyCArray { + _base: PyCData::new(cdata.clone()), typ: PyRwLock::new(element_type.into()), length: AtomicCell::new(length), element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(CDataObject::from_bytes(data.to_vec(), None)), + cdata: PyRwLock::new(cdata), } .into_pyobject(vm)) } @@ -741,11 +749,13 @@ impl PyCArray { let element_size = if length > 0 { size / length } else { 0 }; // Create instance + let cdata = CDataObject::from_bytes(data, None); let instance = PyCArray { + _base: PyCData::new(cdata.clone()), typ: PyRwLock::new(element_type.into()), length: AtomicCell::new(length), element_size: AtomicCell::new(element_size), - cdata: PyRwLock::new(CDataObject::from_bytes(data, None)), + cdata: PyRwLock::new(cdata), } .into_pyobject(vm); diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/ctypes/base.rs index 80d42fba3ce..a4664ad3671 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ b/crates/vm/src/stdlib/ctypes/base.rs @@ -219,6 +219,14 @@ pub struct PyCData { pub cdata: PyRwLock, } +impl PyCData { + pub fn new(cdata: CDataObject) -> Self { + Self { + cdata: PyRwLock::new(cdata), + } + } +} + #[pyclass(flags(BASETYPE))] impl PyCData { #[pygetset] @@ -228,8 +236,9 @@ impl PyCData { } #[pyclass(module = "_ctypes", name = "PyCSimpleType", base = PyType)] -#[derive(Debug, PyPayload, Default)] -pub struct PyCSimpleType {} +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCSimpleType(PyType); #[pyclass(flags(BASETYPE), with(AsNumber))] impl PyCSimpleType { @@ -411,8 +420,8 @@ impl AsNumber for PyCSimpleType { base = PyCData, metaclass = "PyCSimpleType" )] -#[derive(PyPayload)] pub struct PyCSimple { + pub _base: PyCData, pub _type_: String, pub value: AtomicCell, pub cdata: PyRwLock, @@ -647,10 +656,12 @@ impl Constructor for PyCSimple { .unwrap_or(false); let buffer = value_to_bytes_endian(&_type_, &value, swapped, vm); + let cdata = CDataObject::from_bytes(buffer, None); PyCSimple { + _base: PyCData::new(cdata.clone()), _type_, value: AtomicCell::new(value), - cdata: PyRwLock::new(CDataObject::from_bytes(buffer, None)), + cdata: PyRwLock::new(cdata), } .into_ref_with_type(vm, cls) .map(Into::into) diff --git a/crates/vm/src/stdlib/ctypes/field.rs b/crates/vm/src/stdlib/ctypes/field.rs index 8d6da5808a7..e760f07d035 100644 --- a/crates/vm/src/stdlib/ctypes/field.rs +++ b/crates/vm/src/stdlib/ctypes/field.rs @@ -8,8 +8,9 @@ use super::structure::PyCStructure; use super::union::PyCUnion; #[pyclass(name = "PyCFieldType", base = PyType, module = "_ctypes")] -#[derive(PyPayload, Debug)] +#[derive(Debug)] pub struct PyCFieldType { + pub _base: PyType, #[allow(dead_code)] pub(super) inner: PyCField, } diff --git a/crates/vm/src/stdlib/ctypes/function.rs b/crates/vm/src/stdlib/ctypes/function.rs index d202410b14a..b4e600f77ba 100644 --- a/crates/vm/src/stdlib/ctypes/function.rs +++ b/crates/vm/src/stdlib/ctypes/function.rs @@ -4,7 +4,7 @@ use crate::builtins::{PyNone, PyStr, PyTuple, PyTupleRef, PyType, PyTypeRef}; use crate::convert::ToPyObject; use crate::function::FuncArgs; use crate::stdlib::ctypes::PyCData; -use crate::stdlib::ctypes::base::{PyCSimple, ffi_type_from_str}; +use crate::stdlib::ctypes::base::{CDataObject, PyCSimple, ffi_type_from_str}; use crate::stdlib::ctypes::thunk::PyCThunk; use crate::types::Representable; use crate::types::{Callable, Constructor}; @@ -162,8 +162,8 @@ impl ReturnType for PyNone { } #[pyclass(module = "_ctypes", name = "CFuncPtr", base = PyCData)] -#[derive(PyPayload)] pub struct PyCFuncPtr { + _base: PyCData, pub name: PyRwLock>, pub ptr: PyRwLock>, #[allow(dead_code)] @@ -194,6 +194,7 @@ impl Constructor for PyCFuncPtr { if args.args.is_empty() { return PyCFuncPtr { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), ptr: PyRwLock::new(None), needs_free: AtomicCell::new(false), arg_types: PyRwLock::new(None), @@ -212,6 +213,7 @@ impl Constructor for PyCFuncPtr { if let Ok(addr) = first_arg.try_int(vm) { let ptr_val = addr.as_bigint().to_usize().unwrap_or(0); return PyCFuncPtr { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), ptr: PyRwLock::new(Some(CodePtr(ptr_val as *mut _))), needs_free: AtomicCell::new(false), arg_types: PyRwLock::new(None), @@ -271,6 +273,7 @@ impl Constructor for PyCFuncPtr { }; return PyCFuncPtr { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), ptr: PyRwLock::new(code_ptr), needs_free: AtomicCell::new(false), arg_types: PyRwLock::new(None), @@ -314,6 +317,7 @@ impl Constructor for PyCFuncPtr { let thunk_ref: PyRef = thunk.into_ref(&vm.ctx); return PyCFuncPtr { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), ptr: PyRwLock::new(Some(code_ptr)), needs_free: AtomicCell::new(true), arg_types: PyRwLock::new(arg_type_vec), diff --git a/crates/vm/src/stdlib/ctypes/pointer.rs b/crates/vm/src/stdlib/ctypes/pointer.rs index 156c4e54ee5..735034e7936 100644 --- a/crates/vm/src/stdlib/ctypes/pointer.rs +++ b/crates/vm/src/stdlib/ctypes/pointer.rs @@ -4,13 +4,14 @@ use rustpython_common::lock::PyRwLock; use crate::builtins::{PyType, PyTypeRef}; use crate::function::FuncArgs; use crate::protocol::PyNumberMethods; -use crate::stdlib::ctypes::PyCData; +use crate::stdlib::ctypes::{CDataObject, PyCData}; use crate::types::{AsNumber, Constructor}; use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; #[pyclass(name = "PyCPointerType", base = PyType, module = "_ctypes")] -#[derive(PyPayload, Debug, Default)] -pub struct PyCPointerType {} +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCPointerType(PyType); #[pyclass(flags(IMMUTABLETYPE), with(AsNumber))] impl PyCPointerType { @@ -60,8 +61,9 @@ impl AsNumber for PyCPointerType { metaclass = "PyCPointerType", module = "_ctypes" )] -#[derive(Debug, PyPayload)] +#[derive(Debug)] pub struct PyCPointer { + _base: PyCData, contents: PyRwLock, } @@ -75,6 +77,7 @@ impl Constructor for PyCPointer { // Create a new PyCPointer instance with the provided value PyCPointer { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), contents: PyRwLock::new(initial_contents), } .into_ref_with_type(vm, cls) @@ -124,6 +127,7 @@ impl PyCPointer { } // Pointer just stores the address value Ok(PyCPointer { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), contents: PyRwLock::new(vm.ctx.new_int(address).into()), } .into_ref_with_type(vm, cls)? @@ -168,6 +172,7 @@ impl PyCPointer { let ptr_val = usize::from_ne_bytes(ptr_bytes.try_into().expect("size is checked above")); Ok(PyCPointer { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), contents: PyRwLock::new(vm.ctx.new_int(ptr_val).into()), } .into_ref_with_type(vm, cls)? @@ -204,6 +209,7 @@ impl PyCPointer { let ptr_val = usize::from_ne_bytes(ptr_bytes.try_into().expect("size is checked above")); Ok(PyCPointer { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), contents: PyRwLock::new(vm.ctx.new_int(ptr_val).into()), } .into_ref_with_type(vm, cls)? @@ -259,6 +265,7 @@ impl PyCPointer { // For pointer types, we return a pointer to the symbol address Ok(PyCPointer { + _base: PyCData::new(CDataObject::from_bytes(vec![], None)), contents: PyRwLock::new(vm.ctx.new_int(symbol_address).into()), } .into_ref_with_type(vm, cls)? diff --git a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/ctypes/structure.rs index d3cdea69c72..f32d6865cb6 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ b/crates/vm/src/stdlib/ctypes/structure.rs @@ -15,8 +15,9 @@ use std::fmt::Debug; /// PyCStructType - metaclass for Structure #[pyclass(name = "PyCStructType", base = PyType, module = "_ctypes")] -#[derive(Debug, PyPayload, Default)] -pub struct PyCStructType {} +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCStructType(PyType); impl Constructor for PyCStructType { type Args = FuncArgs; @@ -218,8 +219,8 @@ pub struct FieldInfo { base = PyCData, metaclass = "PyCStructType" )] -#[derive(PyPayload)] pub struct PyCStructure { + _base: PyCData, /// Common CDataObject for memory buffer pub(super) cdata: PyRwLock, /// Field information (name -> FieldInfo) @@ -295,8 +296,10 @@ impl Constructor for PyCStructure { // Initialize buffer with zeros let mut stg_info = StgInfo::new(total_size, max_align); stg_info.length = fields_map.len(); + let cdata = CDataObject::from_stg_info(&stg_info); let instance = PyCStructure { - cdata: PyRwLock::new(CDataObject::from_stg_info(&stg_info)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), fields: PyRwLock::new(fields_map.clone()), }; @@ -364,8 +367,10 @@ impl PyCStructure { }; // Create instance + let cdata = CDataObject::from_bytes(data, None); Ok(PyCStructure { - cdata: PyRwLock::new(CDataObject::from_bytes(data, None)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), fields: PyRwLock::new(IndexMap::new()), } .into_ref_with_type(vm, cls)? @@ -415,8 +420,10 @@ impl PyCStructure { let data = bytes[offset..offset + size].to_vec(); // Create instance + let cdata = CDataObject::from_bytes(data, Some(source)); Ok(PyCStructure { - cdata: PyRwLock::new(CDataObject::from_bytes(data, Some(source))), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), fields: PyRwLock::new(IndexMap::new()), } .into_ref_with_type(vm, cls)? @@ -458,8 +465,10 @@ impl PyCStructure { let data = source_bytes[offset..offset + size].to_vec(); // Create instance + let cdata = CDataObject::from_bytes(data, None); Ok(PyCStructure { - cdata: PyRwLock::new(CDataObject::from_bytes(data, None)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), fields: PyRwLock::new(IndexMap::new()), } .into_ref_with_type(vm, cls)? diff --git a/crates/vm/src/stdlib/ctypes/union.rs b/crates/vm/src/stdlib/ctypes/union.rs index 37d8e4f688b..e6873e87506 100644 --- a/crates/vm/src/stdlib/ctypes/union.rs +++ b/crates/vm/src/stdlib/ctypes/union.rs @@ -13,8 +13,9 @@ use rustpython_common::lock::PyRwLock; /// PyCUnionType - metaclass for Union #[pyclass(name = "UnionType", base = PyType, module = "_ctypes")] -#[derive(Debug, PyPayload, Default)] -pub struct PyCUnionType {} +#[derive(Debug)] +#[repr(transparent)] +pub struct PyCUnionType(PyType); impl Constructor for PyCUnionType { type Args = FuncArgs; @@ -121,8 +122,8 @@ impl PyCUnionType {} /// PyCUnion - base class for Union #[pyclass(module = "_ctypes", name = "Union", base = PyCData, metaclass = "PyCUnionType")] -#[derive(PyPayload)] pub struct PyCUnion { + _base: PyCData, /// Common CDataObject for memory buffer pub(super) cdata: PyRwLock, } @@ -173,8 +174,10 @@ impl Constructor for PyCUnion { // Initialize buffer with zeros let stg_info = StgInfo::new(max_size, max_align); + let cdata = CDataObject::from_stg_info(&stg_info); PyCUnion { - cdata: PyRwLock::new(CDataObject::from_stg_info(&stg_info)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), } .into_ref_with_type(vm, cls) .map(Into::into) @@ -204,8 +207,10 @@ impl PyCUnion { return Err(vm.new_value_error("NULL pointer access".to_owned())); } let stg_info = StgInfo::new(size, 1); + let cdata = CDataObject::from_stg_info(&stg_info); Ok(PyCUnion { - cdata: PyRwLock::new(CDataObject::from_stg_info(&stg_info)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), } .into_ref_with_type(vm, cls)? .into()) @@ -249,8 +254,10 @@ impl PyCUnion { let bytes = buffer.obj_bytes(); let data = bytes[offset..offset + size].to_vec(); + let cdata = CDataObject::from_bytes(data, None); Ok(PyCUnion { - cdata: PyRwLock::new(CDataObject::from_bytes(data, None)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), } .into_ref_with_type(vm, cls)? .into()) @@ -286,8 +293,10 @@ impl PyCUnion { // Copy data from source let data = source_bytes[offset..offset + size].to_vec(); + let cdata = CDataObject::from_bytes(data, None); Ok(PyCUnion { - cdata: PyRwLock::new(CDataObject::from_bytes(data, None)), + _base: PyCData::new(cdata.clone()), + cdata: PyRwLock::new(cdata), } .into_ref_with_type(vm, cls)? .into()) diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/io.rs index e5a438b86de..3d67591d567 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/io.rs @@ -69,7 +69,7 @@ impl ToPyException for std::io::Error { .set_attr("winerror", vm.new_pyobj(winerror), vm) .unwrap(); } - exc + exc.upcast() } } @@ -87,9 +87,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] fileio::extend_module(vm, &module).unwrap(); - let unsupported_operation = _io::UNSUPPORTED_OPERATION - .get_or_init(|| _io::make_unsupportedop(ctx)) - .clone(); + let unsupported_operation = _io::unsupported_operation().to_owned(); extend_module!(vm, &module, { "UnsupportedOperation" => unsupported_operation, "BlockingIOError" => ctx.exceptions.blocking_io_error.to_owned(), @@ -150,7 +148,7 @@ mod _io { AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, TryFromObject, builtins::{ - PyBaseExceptionRef, PyByteArray, PyBytes, PyBytesRef, PyIntRef, PyMemoryView, PyStr, + PyBaseExceptionRef, PyBool, PyByteArray, PyBytes, PyBytesRef, PyMemoryView, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef, }, class::StaticType, @@ -204,7 +202,8 @@ mod _io { } pub fn new_unsupported_operation(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { - vm.new_exception_msg(UNSUPPORTED_OPERATION.get().unwrap().clone(), msg) + vm.new_os_subtype_error(unsupported_operation().to_owned(), None, msg) + .upcast() } fn _unsupported(vm: &VirtualMachine, zelf: &PyObject, operation: &str) -> PyResult { @@ -425,7 +424,7 @@ mod _io { #[pyattr] #[pyclass(name = "_IOBase")] - #[derive(Debug, PyPayload)] + #[derive(Debug, Default, PyPayload)] pub struct _IOBase; #[pyclass(with(IterNext, Iterable, Destructor), flags(BASETYPE, HAS_DICT))] @@ -456,7 +455,7 @@ mod _io { } #[pyattr] - fn __closed(ctx: &Context) -> PyIntRef { + fn __closed(ctx: &Context) -> PyRef { ctx.new_bool(false) } @@ -639,7 +638,9 @@ mod _io { #[pyattr] #[pyclass(name = "_RawIOBase", base = _IOBase)] - pub(super) struct _RawIOBase; + #[derive(Debug, Default)] + #[repr(transparent)] + pub(super) struct _RawIOBase(_IOBase); #[pyclass(flags(BASETYPE, HAS_DICT))] impl _RawIOBase { @@ -697,7 +698,9 @@ mod _io { #[pyattr] #[pyclass(name = "_BufferedIOBase", base = _IOBase)] - struct _BufferedIOBase; + #[derive(Debug, Default)] + #[repr(transparent)] + struct _BufferedIOBase(_IOBase); #[pyclass(flags(BASETYPE))] impl _BufferedIOBase { @@ -760,8 +763,9 @@ mod _io { // TextIO Base has no public constructor #[pyattr] #[pyclass(name = "_TextIOBase", base = _IOBase)] - #[derive(Debug, PyPayload)] - struct _TextIOBase; + #[derive(Debug, Default)] + #[repr(transparent)] + struct _TextIOBase(_IOBase); #[pyclass(flags(BASETYPE))] impl _TextIOBase { @@ -1760,8 +1764,9 @@ mod _io { #[pyattr] #[pyclass(name = "BufferedReader", base = _BufferedIOBase)] - #[derive(Debug, Default, PyPayload)] + #[derive(Debug, Default)] struct BufferedReader { + _base: _BufferedIOBase, data: PyThreadMutex, } @@ -1829,8 +1834,9 @@ mod _io { #[pyattr] #[pyclass(name = "BufferedWriter", base = _BufferedIOBase)] - #[derive(Debug, Default, PyPayload)] + #[derive(Debug, Default)] struct BufferedWriter { + _base: _BufferedIOBase, data: PyThreadMutex, } @@ -1874,8 +1880,9 @@ mod _io { #[pyattr] #[pyclass(name = "BufferedRandom", base = _BufferedIOBase)] - #[derive(Debug, Default, PyPayload)] + #[derive(Debug, Default)] struct BufferedRandom { + _base: _BufferedIOBase, data: PyThreadMutex, } @@ -1934,8 +1941,9 @@ mod _io { #[pyattr] #[pyclass(name = "BufferedRWPair", base = _BufferedIOBase)] - #[derive(Debug, Default, PyPayload)] + #[derive(Debug, Default)] struct BufferedRWPair { + _base: _BufferedIOBase, read: BufferedReader, write: BufferedWriter, } @@ -2366,8 +2374,9 @@ mod _io { #[pyattr] #[pyclass(name = "TextIOWrapper", base = _TextIOBase)] - #[derive(Debug, Default, PyPayload)] + #[derive(Debug, Default)] struct TextIOWrapper { + _base: _TextIOBase, data: PyThreadMutex>, } @@ -3646,8 +3655,9 @@ mod _io { #[pyattr] #[pyclass(name = "StringIO", base = _TextIOBase)] - #[derive(Debug, PyPayload)] + #[derive(Debug)] struct StringIO { + _base: _TextIOBase, buffer: PyRwLock, closed: AtomicCell, } @@ -3677,6 +3687,7 @@ mod _io { .map_or_else(Vec::new, |v| v.as_bytes().to_vec()); Ok(Self { + _base: Default::default(), buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), closed: AtomicCell::new(false), }) @@ -3789,8 +3800,9 @@ mod _io { #[pyattr] #[pyclass(name = "BytesIO", base = _BufferedIOBase)] - #[derive(Debug, PyPayload)] + #[derive(Debug)] struct BytesIO { + _base: _BufferedIOBase, buffer: PyRwLock, closed: AtomicCell, exports: AtomicCell, @@ -3805,6 +3817,7 @@ mod _io { .map_or_else(Vec::new, |input| input.as_bytes().to_vec()); Ok(Self { + _base: Default::default(), buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), closed: AtomicCell::new(false), exports: AtomicCell::new(0), @@ -4229,11 +4242,7 @@ mod _io { } } - rustpython_common::static_cell! { - pub(super) static UNSUPPORTED_OPERATION: PyTypeRef; - } - - pub(super) fn make_unsupportedop(ctx: &Context) -> PyTypeRef { + fn create_unsupported_operation(ctx: &Context) -> PyTypeRef { use crate::types::PyTypeSlots; PyType::new_heap( "UnsupportedOperation", @@ -4249,6 +4258,13 @@ mod _io { .unwrap() } + pub fn unsupported_operation() -> &'static Py { + rustpython_common::static_cell! { + static CELL: PyTypeRef; + } + CELL.get_or_init(|| create_unsupported_operation(Context::genesis())) + } + #[pyfunction] fn text_encoding( encoding: PyObjectRef, @@ -4423,8 +4439,9 @@ mod fileio { #[pyattr] #[pyclass(module = "io", name, base = _RawIOBase)] - #[derive(Debug, PyPayload)] + #[derive(Debug)] pub(super) struct FileIO { + _base: _RawIOBase, fd: AtomicCell, closefd: AtomicCell, mode: AtomicCell, @@ -4446,6 +4463,7 @@ mod fileio { impl Default for FileIO { fn default() -> Self { Self { + _base: Default::default(), fd: AtomicCell::new(-1), closefd: AtomicCell::new(true), mode: AtomicCell::new(Mode::empty()), diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index b75f601c8db..f6ffd66759d 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -525,13 +525,15 @@ pub(super) mod _os { return Err(vm.new_value_error("embedded null byte")); } if key.is_empty() || key.contains(&b'=') { - return Err(vm.new_errno_error( + let x = vm.new_errno_error( 22, format!( "Invalid argument: {}", std::str::from_utf8(key).unwrap_or("") ), - )); + ); + + return Err(x.upcast()); } let key = super::bytes_as_os_str(key, vm)?; // SAFETY: requirements forwarded from the caller diff --git a/crates/vm/src/types/structseq.rs b/crates/vm/src/types/structseq.rs index b2ff5868d45..be0a1c9a70c 100644 --- a/crates/vm/src/types/structseq.rs +++ b/crates/vm/src/types/structseq.rs @@ -184,10 +184,12 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { /// Convert a Data struct into a PyStructSequence instance. fn from_data(data: Self::Data, vm: &VirtualMachine) -> PyTupleRef { + let tuple = + ::into_tuple(data, vm); let typ = Self::static_type(); - data.into_tuple(vm) + tuple .into_ref_with_type(vm, typ.to_owned()) - .unwrap() + .expect("Every PyStructSequence must be a valid tuple. This is a RustPython bug.") } #[pyslot] diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs index 191c090f121..fbda71dc1f6 100644 --- a/crates/vm/src/vm/context.rs +++ b/crates/vm/src/vm/context.rs @@ -1,9 +1,10 @@ use crate::{ PyResult, VirtualMachine, builtins::{ - PyBaseException, PyByteArray, PyBytes, PyComplex, PyDict, PyDictRef, PyEllipsis, PyFloat, - PyFrozenSet, PyInt, PyIntRef, PyList, PyListRef, PyNone, PyNotImplemented, PyStr, - PyStrInterned, PyTuple, PyTupleRef, PyType, PyTypeRef, + PyByteArray, PyBytes, PyComplex, PyDict, PyDictRef, PyEllipsis, PyFloat, PyFrozenSet, + PyInt, PyIntRef, PyList, PyListRef, PyNone, PyNotImplemented, PyStr, PyStrInterned, + PyTuple, PyTupleRef, PyType, PyTypeRef, + bool_::PyBool, code::{self, PyCode}, descriptor::{ MemberGetter, MemberKind, MemberSetter, MemberSetterFunc, PyDescriptorOwned, @@ -13,7 +14,7 @@ use crate::{ object, pystr, type_::PyAttributes, }, - class::{PyClassImpl, StaticType}, + class::StaticType, common::rc::PyRc, exceptions, function::{ @@ -31,8 +32,8 @@ use rustpython_common::lock::PyRwLock; #[derive(Debug)] pub struct Context { - pub true_value: PyIntRef, - pub false_value: PyIntRef, + pub true_value: PyRef, + pub false_value: PyRef, pub none: PyRef, pub empty_tuple: PyTupleRef, pub empty_frozenset: PyRef, @@ -279,10 +280,7 @@ impl Context { let exceptions = exceptions::ExceptionZoo::init(); #[inline] - fn create_object( - payload: T, - cls: &'static Py, - ) -> PyRef { + fn create_object(payload: T, cls: &'static Py) -> PyRef { PyRef::new_ref(payload, cls.to_owned(), None) } @@ -305,8 +303,8 @@ impl Context { }) .collect(); - let true_value = create_object(PyInt::from(1), types.bool_type); - let false_value = create_object(PyInt::from(0), types.bool_type); + let true_value = create_object(PyBool(PyInt::from(1)), types.bool_type); + let false_value = create_object(PyBool(PyInt::from(0)), types.bool_type); let empty_tuple = create_object( PyTuple::new_unchecked(Vec::new().into_boxed_slice()), @@ -449,13 +447,13 @@ impl Context { } #[inline(always)] - pub fn new_bool(&self, b: bool) -> PyIntRef { + pub fn new_bool(&self, b: bool) -> PyRef { let value = if b { &self.true_value } else { &self.false_value }; - value.clone() + value.to_owned() } #[inline(always)] @@ -510,14 +508,17 @@ impl Context { attrs.insert(identifier!(self, __module__), self.new_str(module).into()); let interned_name = self.intern_str(name); + let slots = PyTypeSlots { + name: interned_name.as_str(), + basicsize: 0, + flags: PyTypeFlags::heap_type_flags() | PyTypeFlags::HAS_DICT, + ..PyTypeSlots::default() + }; PyType::new_heap( name, bases, attrs, - PyTypeSlots { - name: interned_name.as_str(), - ..PyBaseException::make_slots() - }, + slots, self.types.type_type.to_owned(), self, ) diff --git a/crates/vm/src/vm/vm_new.rs b/crates/vm/src/vm/vm_new.rs index 1054ba9b313..36481e5dbe3 100644 --- a/crates/vm/src/vm/vm_new.rs +++ b/crates/vm/src/vm/vm_new.rs @@ -1,8 +1,8 @@ use crate::{ - AsObject, Py, PyObject, PyObjectRef, PyRef, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, builtins::{ - PyBaseException, PyBaseExceptionRef, PyBytesRef, PyDictRef, PyModule, PyStrRef, PyType, - PyTypeRef, + PyBaseException, PyBaseExceptionRef, PyBytesRef, PyDictRef, PyModule, PyOSError, PyStrRef, + PyType, PyTypeRef, builtin_func::PyNativeFunction, descriptor::PyMethodDescriptor, tuple::{IntoPyTuple, PyTupleRef}, @@ -10,6 +10,7 @@ use crate::{ convert::{ToPyException, ToPyObject}, function::{IntoPyNativeFn, PyMethodFlags}, scope::Scope, + types::Constructor, vm::VirtualMachine, }; use rustpython_compiler_core::SourceLocation; @@ -92,18 +93,54 @@ impl VirtualMachine { /// [`vm.invoke_exception()`][Self::invoke_exception] or /// [`exceptions::ExceptionCtor`][crate::exceptions::ExceptionCtor] instead. pub fn new_exception(&self, exc_type: PyTypeRef, args: Vec) -> PyBaseExceptionRef { - // TODO: add repr of args into logging? + debug_assert_eq!( + exc_type.slots.basicsize, + std::mem::size_of::(), + "vm.new_exception() is only for exception types without additional payload. The given type '{}' is not allowed.", + exc_type.class().name() + ); PyRef::new_ref( - // TODO: this constructor might be invalid, because multiple - // exception (even builtin ones) are using custom constructors, - // see `OSError` as an example: PyBaseException::new(args, self), exc_type, Some(self.ctx.new_dict()), ) } + pub fn new_os_error(&self, msg: impl ToPyObject) -> PyRef { + self.new_os_subtype_error(self.ctx.exceptions.os_error.to_owned(), None, msg) + .upcast() + } + + pub fn new_os_subtype_error( + &self, + exc_type: PyTypeRef, + errno: Option, + msg: impl ToPyObject, + ) -> PyRef { + debug_assert_eq!(exc_type.slots.basicsize, std::mem::size_of::()); + let msg = msg.to_pyobject(self); + + fn new_os_subtype_error_impl( + vm: &VirtualMachine, + exc_type: PyTypeRef, + errno: Option, + msg: PyObjectRef, + ) -> PyRef { + let args = match errno { + Some(e) => vec![vm.new_pyobj(e), msg], + None => vec![msg], + }; + let payload = + PyOSError::py_new(&exc_type, args.into(), vm).expect("new_os_error usage error"); + payload + .into_ref_with_type(vm, exc_type) + .expect("new_os_error usage error") + } + + new_os_subtype_error_impl(self, exc_type, errno, msg) + } + /// Instantiate an exception with no arguments. /// This function should only be used with builtin exception types; if a user-defined exception /// type is passed in, it may not be fully initialized; try using @@ -220,16 +257,11 @@ impl VirtualMachine { err.to_pyexception(self) } - pub fn new_errno_error(&self, errno: i32, msg: impl Into) -> PyBaseExceptionRef { - let vm = self; - let exc_type = - crate::exceptions::errno_to_exc_type(errno, vm).unwrap_or(vm.ctx.exceptions.os_error); + pub fn new_errno_error(&self, errno: i32, msg: impl ToPyObject) -> PyRef { + let exc_type = crate::exceptions::errno_to_exc_type(errno, self) + .unwrap_or(self.ctx.exceptions.os_error); - let errno_obj = vm.new_pyobj(errno); - vm.new_exception( - exc_type.to_owned(), - vec![errno_obj, vm.new_pyobj(msg.into())], - ) + self.new_os_subtype_error(exc_type.to_owned(), Some(errno), msg) } pub fn new_unicode_decode_error_real( @@ -565,7 +597,6 @@ impl VirtualMachine { define_exception_fn!(fn new_eof_error, eof_error, EOFError); define_exception_fn!(fn new_attribute_error, attribute_error, AttributeError); define_exception_fn!(fn new_type_error, type_error, TypeError); - define_exception_fn!(fn new_os_error, os_error, OSError); define_exception_fn!(fn new_system_error, system_error, SystemError); // TODO: remove & replace with new_unicode_decode_error_real diff --git a/extra_tests/snippets/stdlib_ctypes.py b/extra_tests/snippets/stdlib_ctypes.py index 0ec7568a839..0a5d1387a8d 100644 --- a/extra_tests/snippets/stdlib_ctypes.py +++ b/extra_tests/snippets/stdlib_ctypes.py @@ -397,3 +397,5 @@ def get_win_folder_via_ctypes(csidl_name: str) -> str: return buf.value # print(get_win_folder_via_ctypes("CSIDL_DOWNLOADS")) + +print("done")