Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Lib/test/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2401,8 +2401,6 @@ def test_type_nokwargs(self):
with self.assertRaises(TypeError):
type('a', (), dict={})

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_type_name(self):
for name in 'A', '\xc4', '\U0001f40d', 'B.A', '42', '':
with self.subTest(name=name):
Expand Down Expand Up @@ -2452,8 +2450,6 @@ def test_type_qualname(self):
A.__qualname__ = b'B'
self.assertEqual(A.__qualname__, 'D.E')

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_type_doc(self):
for doc in 'x', '\xc4', '\U0001f40d', 'x\x00y', b'x', 42, None:
A = type('A', (), {'__doc__': doc})
Expand Down
2 changes: 1 addition & 1 deletion crates/vm/src/builtins/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ impl PyStr {
self.data.as_str()
}

fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
pub(crate) fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.is_utf8() {
Ok(())
} else {
Expand Down
22 changes: 22 additions & 0 deletions crates/vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ impl PyType {
if name.as_bytes().contains(&0) {
return Err(vm.new_value_error("type name must not contain null characters"));
}
name.ensure_valid_utf8(vm)?;

// Use std::mem::replace to swap the new value in and get the old value out,
// then drop the old value after releasing the lock (similar to CPython's Py_SETREF)
Expand Down Expand Up @@ -1254,6 +1255,7 @@ impl Constructor for PyType {
if name.as_bytes().contains(&0) {
return Err(vm.new_value_error("type name must not contain null characters"));
}
name.ensure_valid_utf8(vm)?;

let (metatype, base, bases, base_is_type) = if bases.is_empty() {
let base = vm.ctx.types.object_type.to_owned();
Expand Down Expand Up @@ -1306,6 +1308,13 @@ impl Constructor for PyType {
});
let mut attributes = dict.to_attributes(vm);

// Check __doc__ for surrogates - raises UnicodeEncodeError during type creation
if let Some(doc) = attributes.get(identifier!(vm, __doc__))
&& let Some(doc_str) = doc.downcast_ref::<PyStr>()
{
doc_str.ensure_valid_utf8(vm)?;
}

if let Some(f) = attributes.get_mut(identifier!(vm, __init_subclass__))
&& f.class().is(vm.ctx.types.function_type)
{
Expand Down Expand Up @@ -1340,6 +1349,13 @@ impl Constructor for PyType {

let (heaptype_slots, add_dict): (Option<PyRef<PyTuple<PyStrRef>>>, bool) =
if let Some(x) = attributes.get(identifier!(vm, __slots__)) {
// Check if __slots__ is bytes - not allowed
if x.class().is(vm.ctx.types.bytes_type) {
return Err(vm.new_type_error(
"__slots__ items must be strings, not 'bytes'".to_owned(),
));
}

let slots = if x.class().is(vm.ctx.types.str_type) {
let x = unsafe { x.downcast_unchecked_ref::<PyStr>() };
PyTuple::new_ref_typed(vec![x.to_owned()], &vm.ctx)
Expand All @@ -1348,6 +1364,12 @@ impl Constructor for PyType {
let elements = {
let mut elements = Vec::new();
while let PyIterReturn::Return(element) = iter.next(vm)? {
// Check if any slot item is bytes
if element.class().is(vm.ctx.types.bytes_type) {
return Err(vm.new_type_error(
"__slots__ items must be strings, not 'bytes'".to_owned(),
));
}
elements.push(element);
}
elements
Expand Down
Loading