Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
GenericAlias works
  • Loading branch information
youknowone committed Jun 30, 2025
commit 685e4132a24ef0817d7869ae81f07d220b4f0340
31 changes: 5 additions & 26 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,6 @@ class A(Generic[T, Unpack[Ts]]): ...
self.assertEqual(A[float, range].__args__, (float, range))
self.assertEqual(A[float, *tuple[int, ...]].__args__, (float, *tuple[int, ...]))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typevar_and_typevartuple_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
Expand Down Expand Up @@ -738,8 +736,6 @@ class A(Generic[T, P]): ...
self.assertEqual(A[float].__args__, (float, (str, int)))
self.assertEqual(A[float, [range]].__args__, (float, (range,)))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typevar_and_paramspec_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
Expand All @@ -750,8 +746,6 @@ class A(Generic[T, U, P]): ...
self.assertEqual(A[float, int].__args__, (float, int, (str, int)))
self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,)))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_and_typevar_specialization(self):
T = TypeVar("T")
P = ParamSpec('P', default=[str, int])
Expand Down Expand Up @@ -2539,8 +2533,6 @@ def __call__(self):
self.assertIs(a().__class__, C1)
self.assertEqual(a().__orig_class__, C1[[int], T])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec(self):
Callable = self.Callable
fullname = f"{Callable.__module__}.Callable"
Expand Down Expand Up @@ -2575,8 +2567,6 @@ def test_paramspec(self):
self.assertEqual(repr(C2), f"{fullname}[~P, int]")
self.assertEqual(repr(C2[int, str]), f"{fullname}[[int, str], int]")

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_concatenate(self):
Callable = self.Callable
fullname = f"{Callable.__module__}.Callable"
Expand Down Expand Up @@ -2604,8 +2594,6 @@ def test_concatenate(self):
Callable[Concatenate[int, str, P2], int])
self.assertEqual(C[...], Callable[Concatenate[int, ...], int])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_nested_paramspec(self):
# Since Callable has some special treatment, we want to be sure
# that substituion works correctly, see gh-103054
Expand Down Expand Up @@ -2648,8 +2636,6 @@ class My(Generic[P, T]):
self.assertEqual(C4[bool, bytes, float],
My[[Callable[[int, bool, bytes, str], float], float], float])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_errors(self):
Callable = self.Callable
alias = Callable[[int, str], float]
Expand Down Expand Up @@ -2678,6 +2664,11 @@ def test_consistency(self):
class CollectionsCallableTests(BaseCallableTests, BaseTestCase):
Callable = collections.abc.Callable

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_errors(self):
super().test_errors()


class LiteralTests(BaseTestCase):
def test_basics(self):
Expand Down Expand Up @@ -4627,8 +4618,6 @@ class Base(Generic[T_co]):
class Sub(Base, Generic[T]):
...

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_parameter_detection(self):
self.assertEqual(List[T].__parameters__, (T,))
self.assertEqual(List[List[T]].__parameters__, (T,))
Expand All @@ -4646,8 +4635,6 @@ class A:
# C version of GenericAlias
self.assertEqual(list[A()].__parameters__, (T,))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_non_generic_subscript(self):
T = TypeVar('T')
class G(Generic[T]):
Expand Down Expand Up @@ -8854,8 +8841,6 @@ def test_bad_var_substitution(self):
with self.assertRaises(TypeError):
collections.abc.Callable[P, T][arg, str]

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_type_var_subst_for_other_type_vars(self):
T = TypeVar('T')
T2 = TypeVar('T2')
Expand Down Expand Up @@ -8977,8 +8962,6 @@ class PandT(Generic[P, T]):
self.assertEqual(C3.__args__, ((int, *Ts), T))
self.assertEqual(C3[str, bool, bytes], PandT[[int, str, bool], bytes])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_in_nested_generics(self):
# Although ParamSpec should not be found in __parameters__ of most
# generics, they probably should be found when nested in
Expand All @@ -8997,8 +8980,6 @@ def test_paramspec_in_nested_generics(self):
self.assertEqual(G2[[int, str], float], list[C])
self.assertEqual(G3[[int, str], float], list[C] | int)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_gets_copied(self):
# bpo-46581
P = ParamSpec('P')
Expand Down Expand Up @@ -9086,8 +9067,6 @@ def test_invalid_uses(self):
):
Concatenate[int]

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_var_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
Expand Down
133 changes: 103 additions & 30 deletions vm/src/builtins/genericalias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,17 @@ impl PyGenericAlias {
}

#[pymethod]
fn __getitem__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
fn __getitem__(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let new_args = subs_parameters(
|vm| self.repr(vm),
self.args.clone(),
self.parameters.clone(),
zelf.to_owned().into(),
zelf.args.clone(),
zelf.parameters.clone(),
needle,
vm,
)?;

Ok(
PyGenericAlias::new(self.origin.clone(), new_args.to_pyobject(vm), vm)
PyGenericAlias::new(zelf.origin.clone(), new_args.to_pyobject(vm), vm)
.into_pyobject(vm),
)
}
Expand Down Expand Up @@ -278,6 +278,18 @@ fn tuple_index(vec: &[PyObjectRef], item: &PyObjectRef) -> Option<usize> {
vec.iter().position(|element| element.is(item))
}

fn is_unpacked_typevartuple(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
if arg.class().is(vm.ctx.types.type_type) {
return Ok(false);
}

if let Ok(attr) = arg.get_attr(identifier!(vm, __typing_is_unpacked_typevartuple__), vm) {
attr.try_to_bool(vm)
} else {
Ok(false)
}
}

fn subs_tvars(
obj: PyObjectRef,
params: &PyTupleRef,
Expand Down Expand Up @@ -325,22 +337,40 @@ fn subs_tvars(
}

// _Py_subs_parameters
pub fn subs_parameters<F: Fn(&VirtualMachine) -> PyResult<String>>(
repr: F,
pub fn subs_parameters(
alias: PyObjectRef, // The GenericAlias object itself
args: PyTupleRef,
parameters: PyTupleRef,
needle: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyTupleRef> {
let num_params = parameters.len();
if num_params == 0 {
return Err(vm.new_type_error(format!("There are no type variables left in {}", repr(vm)?)));
return Err(vm.new_type_error(format!("{} is not a generic class", alias.repr(vm)?)));
}

// Handle __typing_prepare_subst__ for each parameter
// Following CPython: each prepare function transforms the args
let mut prepared_args = needle.clone();

// Ensure args is a tuple
if prepared_args.try_to_ref::<PyTuple>(vm).is_err() {
prepared_args = PyTuple::new_ref(vec![prepared_args], &vm.ctx).into();
}

let items = needle.try_to_ref::<PyTuple>(vm);
for param in parameters.iter() {
if let Ok(prepare) = param.get_attr(identifier!(vm, __typing_prepare_subst__), vm) {
if !prepare.is(&vm.ctx.none) {
// Call prepare(cls, args) where cls is the GenericAlias
prepared_args = prepare.call((alias.clone(), prepared_args), vm)?;
}
}
}

let items = prepared_args.try_to_ref::<PyTuple>(vm);
let arg_items = match items {
Ok(tuple) => tuple.as_slice(),
Err(_) => std::slice::from_ref(&needle),
Err(_) => std::slice::from_ref(&prepared_args),
};

let num_items = arg_items.len();
Expand All @@ -363,40 +393,82 @@ pub fn subs_parameters<F: Fn(&VirtualMachine) -> PyResult<String>>(

let min_required = num_params - params_with_defaults;
if num_items < min_required {
let repr_str = alias.repr(vm)?;
return Err(vm.new_type_error(format!(
"Too few arguments for {}; actual {}, expected at least {}",
repr(vm)?,
num_items,
min_required
"Too few arguments for {repr_str}; actual {num_items}, expected at least {min_required}"
)));
}
} else if num_items > num_params {
let repr_str = alias.repr(vm)?;
return Err(vm.new_type_error(format!(
"Too many arguments for {}; actual {}, expected {}",
repr(vm)?,
num_items,
num_params
"Too many arguments for {repr_str}; actual {num_items}, expected {num_params}"
)));
}

let mut new_args = Vec::new();
let mut new_args = Vec::with_capacity(args.len());

for arg in args.iter() {
// Skip bare Python classes
if arg.class().is(vm.ctx.types.type_type) {
new_args.push(arg.clone());
continue;
}

// Check if this is an unpacked TypeVarTuple
let unpack = is_unpacked_typevartuple(arg, vm)?;

// Check for __typing_subst__ attribute directly (like CPython)
if let Ok(subst) = arg.get_attr(identifier!(vm, __typing_subst__), vm) {
let idx = tuple_index(parameters.as_slice(), arg).unwrap();
if idx < num_items {
// Call __typing_subst__ with the argument
let substituted = subst.call((arg_items[idx].clone(),), vm)?;
new_args.push(substituted);
if let Some(idx) = tuple_index(parameters.as_slice(), arg) {
if idx < num_items {
// Call __typing_subst__ with the argument
let substituted = subst.call((arg_items[idx].clone(),), vm)?;

if unpack {
// Unpack the tuple if it's a TypeVarTuple
if let Ok(tuple) = substituted.try_to_ref::<PyTuple>(vm) {
for elem in tuple.iter() {
new_args.push(elem.clone());
}
} else {
new_args.push(substituted);
}
} else {
new_args.push(substituted);
}
} else {
// Use default value if available
if let Ok(default_val) = vm.call_method(arg, "__default__", ()) {
if !default_val.is(&vm.ctx.typing_no_default) {
new_args.push(default_val);
} else {
return Err(vm.new_type_error(format!(
"No argument provided for parameter at index {idx}"
)));
}
} else {
return Err(vm.new_type_error(format!(
"No argument provided for parameter at index {idx}"
)));
}
}
} else {
// CPython doesn't support default values in this context
return Err(
vm.new_type_error(format!("No argument provided for parameter at index {idx}"))
);
new_args.push(arg.clone());
}
} else {
new_args.push(subs_tvars(arg.clone(), &parameters, arg_items, vm)?);
let subst_arg = subs_tvars(arg.clone(), &parameters, arg_items, vm)?;
if unpack {
// Unpack the tuple if it's a TypeVarTuple
if let Ok(tuple) = subst_arg.try_to_ref::<PyTuple>(vm) {
for elem in tuple.iter() {
new_args.push(elem.clone());
}
} else {
new_args.push(subst_arg);
}
} else {
new_args.push(subst_arg);
}
}
}

Expand All @@ -407,7 +479,8 @@ impl AsMapping for PyGenericAlias {
fn as_mapping() -> &'static PyMappingMethods {
static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
subscript: atomic_func!(|mapping, needle, vm| {
PyGenericAlias::mapping_downcast(mapping).__getitem__(needle.to_owned(), vm)
let zelf = PyGenericAlias::mapping_downcast(mapping);
PyGenericAlias::__getitem__(zelf.to_owned(), needle.to_owned(), vm)
}),
..PyMappingMethods::NOT_IMPLEMENTED
});
Expand Down
11 changes: 6 additions & 5 deletions vm/src/builtins/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ pub fn make_union(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyObjectRef {
}

impl PyUnion {
fn getitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
fn getitem(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let new_args = genericalias::subs_parameters(
|vm| self.repr(vm),
self.args.clone(),
self.parameters.clone(),
zelf.to_owned().into(),
zelf.args.clone(),
zelf.parameters.clone(),
needle,
vm,
)?;
Expand All @@ -232,7 +232,8 @@ impl AsMapping for PyUnion {
fn as_mapping() -> &'static PyMappingMethods {
static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
subscript: atomic_func!(|mapping, needle, vm| {
PyUnion::mapping_downcast(mapping).getitem(needle.to_owned(), vm)
let zelf = PyUnion::mapping_downcast(mapping);
PyUnion::getitem(zelf.to_owned(), needle.to_owned(), vm)
}),
..PyMappingMethods::NOT_IMPLEMENTED
});
Expand Down
Loading