Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,4 @@
- ([@gpetrou](https://github.com/gpetrou))
- Ehsan Iran-Nejad ([@eirannejad](https://github.com/eirannejad))
- ([@legomanww](https://github.com/legomanww))
- ([@gertdreyer](https://github.com/gertdreyer))
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].

### Fixed

- Fixed RecursionError for reverse operators on C# operable types from python. See #2240

## [3.0.3](https://github.com/pythonnet/pythonnet/releases/tag/v3.0.3) - 2023-10-11

Expand Down
230 changes: 230 additions & 0 deletions src/embed_tests/TestOperator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class TestOperator
public void SetUp()
{
PythonEngine.Initialize();
OwnIntCodec.Setup();
}

[OneTimeTearDown]
Expand All @@ -23,6 +24,120 @@ public void Dispose()
PythonEngine.Shutdown();
}

// Mock Integer class to test math ops on non-native dotnet types
public struct OwnInt
{
private int _value;

public int Num => _value;

public OwnInt()
{
_value = 0;
}

public OwnInt(int value)
{
_value = value;
}

public static OwnInt operator -(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value - p2._value);
}

public static OwnInt operator +(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value + p2._value);
}

public static OwnInt operator *(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value * p2._value);
}

public static OwnInt operator /(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value / p2._value);
}

public static OwnInt operator %(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value % p2._value);
}

public static OwnInt operator ^(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value ^ p2._value);
}

public static bool operator <(OwnInt p1, OwnInt p2)
{
return p1._value < p2._value;
}

public static bool operator >(OwnInt p1, OwnInt p2)
{
return p1._value > p2._value;
}

public static bool operator ==(OwnInt p1, OwnInt p2)
{
return p1._value == p2._value;
}

public static bool operator !=(OwnInt p1, OwnInt p2)
{
return p1._value != p2._value;
}

public static OwnInt operator |(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value | p2._value);
}

public static OwnInt operator &(OwnInt p1, OwnInt p2)
{
return new OwnInt(p1._value & p2._value);
}

public static bool operator <=(OwnInt p1, OwnInt p2)
{
return p1._value <= p2._value;
}

public static bool operator >=(OwnInt p1, OwnInt p2)
{
return p1._value >= p2._value;
}
}

// Codec for mock class above.
public class OwnIntCodec : IPyObjectDecoder
{
public static void Setup()
{
PyObjectConversions.RegisterDecoder(new OwnIntCodec());
}

public bool CanDecode(PyType objectType, Type targetType)
{
return objectType.Name == "int" && targetType == typeof(OwnInt);
}

public bool TryDecode<T>(PyObject pyObj, out T? value)
{
if (pyObj.PyType.Name != "int" || typeof(T) != typeof(OwnInt))
{
value = default(T);
return false;
}

value = (T)(object)new OwnInt(pyObj.As<int>());
return true;
}
}

public class OperableObject
{
public int Num { get; set; }
Expand Down Expand Up @@ -524,6 +639,121 @@ public void ShiftOperatorOverloads()

c = a >> b.Num
assert c.Num == a.Num >> b.Num
");
}

[Test]
public void ReverseOperatorWithCodec()
{
string name = string.Format("{0}.{1}",
typeof(OwnInt).DeclaringType.Name,
typeof(OwnInt).Name);
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;

PythonEngine.Exec($@"
from {module} import *
cls = {name}
a = 2
b = cls(10)

c = a + b
assert c.Num == a + b.Num

c = a - b
assert c.Num == a - b.Num

c = a * b
assert c.Num == a * b.Num

c = a / b
assert c.Num == a // b.Num

c = a % b
assert c.Num == a % b.Num

c = a & b
assert c.Num == a & b.Num

c = a | b
assert c.Num == a | b.Num

c = a ^ b
assert c.Num == a ^ b.Num

c = a == b
assert c == (a == b.Num)

c = a != b
assert c == (a != b.Num)

c = a <= b
assert c == (a <= b.Num)

c = a >= b
assert c == (a >= b.Num)

c = a < b
assert c == (a < b.Num)

c = a > b
assert c == (a > b.Num)
");
}

[Test]
public void ForwardOperatorWithCodec()
{
string name = string.Format("{0}.{1}",
typeof(OwnInt).DeclaringType.Name,
typeof(OwnInt).Name);
string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace;

PythonEngine.Exec($@"
from {module} import *
cls = {name}
a = cls(2)
b = 10
c = a + b
assert c.Num == a.Num + b

c = a - b
assert c.Num == a.Num - b

c = a * b
assert c.Num == a.Num * b

c = a / b
assert c.Num == a.Num // b

c = a % b
assert c.Num == a.Num % b

c = a & b
assert c.Num == a.Num & b

c = a | b
assert c.Num == a.Num | b

c = a ^ b
assert c.Num == a.Num ^ b

c = a == b
assert c == (a.Num == b)

c = a != b
assert c == (a.Num != b)

c = a <= b
assert c == (a.Num <= b)

c = a >= b
assert c == (a.Num >= b)

c = a < b
assert c == (a.Num < b)

c = a > b
assert c == (a.Num > b)
");
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/ClassManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ private static ClassInfo GetClassInfo(Type type, ClassBase impl)
ci.members[pyName] = new MethodObject(type, name, forwardMethods).AllocObject();
// Only methods where only the right operand is the declaring type.
if (reverseMethods.Length > 0)
ci.members[pyNameReverse] = new MethodObject(type, name, reverseMethods).AllocObject();
ci.members[pyNameReverse] = new MethodObject(type, name, reverseMethods, reverse_args: true).AllocObject();
}
}

Expand Down
40 changes: 24 additions & 16 deletions src/runtime/MethodBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@ internal class MethodBinder

[NonSerialized]
public bool init = false;

public const bool DefaultAllowThreads = true;
public bool allow_threads = DefaultAllowThreads;

internal MethodBinder()
public bool args_reversed = false;

internal MethodBinder(bool reverse_args = false)
{
list = new List<MaybeMethodBase>();
args_reversed = reverse_args;
}

internal MethodBinder(MethodInfo mi)
internal MethodBinder(MethodInfo mi, bool reverse_args = false)
{
list = new List<MaybeMethodBase> { new MaybeMethodBase(mi) };
args_reversed = reverse_args;
}

public int Count
Expand Down Expand Up @@ -271,10 +276,11 @@ internal static int ArgPrecedence(Type t)
/// <param name="inst">The Python target of the method invocation.</param>
/// <param name="args">The Python arguments.</param>
/// <param name="kw">The Python keyword arguments.</param>
/// <param name="reverse_args">Reverse arguments of methods. Used for methods such as __radd__, __rsub__, __rmod__ etc</param>
/// <returns>A Binding if successful. Otherwise null.</returns>
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw)
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, bool reverse_args = false)
{
return Bind(inst, args, kw, null, null);
return Bind(inst, args, kw, null, null, reverse_args);
}

/// <summary>
Expand All @@ -287,10 +293,11 @@ internal static int ArgPrecedence(Type t)
/// <param name="args">The Python arguments.</param>
/// <param name="kw">The Python keyword arguments.</param>
/// <param name="info">If not null, only bind to that method.</param>
/// <param name="reverse_args">Reverse arguments of methods. Used for methods such as __radd__, __rsub__, __rmod__ etc</param>
/// <returns>A Binding if successful. Otherwise null.</returns>
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info)
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, bool reverse_args = false)
{
return Bind(inst, args, kw, info, null);
return Bind(inst, args, kw, info, null, reverse_args);
}

private readonly struct MatchedMethod
Expand Down Expand Up @@ -334,8 +341,9 @@ public MismatchedMethod(Exception exception, MethodBase mb)
/// <param name="kw">The Python keyword arguments.</param>
/// <param name="info">If not null, only bind to that method.</param>
/// <param name="methodinfo">If not null, additionally attempt to bind to the generic methods in this array by inferring generic type parameters.</param>
/// <param name="reverse_args">Reverse arguments of methods. Used for methods such as __radd__, __rsub__, __rmod__ etc</param>
/// <returns>A Binding if successful. Otherwise null.</returns>
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo)
internal Binding? Bind(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo, bool reverse_args = false)
{
// loop to find match, return invoker w/ or w/o error
var kwargDict = new Dictionary<string, PyObject>();
Expand Down Expand Up @@ -363,10 +371,10 @@ public MismatchedMethod(Exception exception, MethodBase mb)
_methods = GetMethods();
}

return Bind(inst, args, kwargDict, _methods, matchGenerics: true);
return Bind(inst, args, kwargDict, _methods, matchGenerics: true, reverse_args);
}

static Binding? Bind(BorrowedReference inst, BorrowedReference args, Dictionary<string, PyObject> kwargDict, MethodBase[] methods, bool matchGenerics)
private static Binding? Bind(BorrowedReference inst, BorrowedReference args, Dictionary<string, PyObject> kwargDict, MethodBase[] methods, bool matchGenerics, bool reversed = false)
{
var pynargs = (int)Runtime.PyTuple_Size(args);
var isGeneric = false;
Expand All @@ -386,7 +394,7 @@ public MismatchedMethod(Exception exception, MethodBase mb)
// Binary operator methods will have 2 CLR args but only one Python arg
// (unary operators will have 1 less each), since Python operator methods are bound.
isOperator = isOperator && pynargs == pi.Length - 1;
bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator.
bool isReverse = isOperator && reversed; // Only cast if isOperator.
if (isReverse && OperatorMethod.IsComparisonOp((MethodInfo)mi))
continue; // Comparison operators in Python have no reverse mode.
if (!MatchesArgumentCount(pynargs, pi, kwargDict, out bool paramsArray, out ArrayList? defaultArgList, out int kwargsMatched, out int defaultsNeeded) && !isOperator)
Expand Down Expand Up @@ -809,14 +817,14 @@ static bool MatchesArgumentCount(int positionalArgumentCount, ParameterInfo[] pa
return match;
}

internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw)
internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, bool reverse_args = false)
{
return Invoke(inst, args, kw, null, null);
return Invoke(inst, args, kw, null, null, reverse_args);
}

internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info)
internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, bool reverse_args = false)
{
return Invoke(inst, args, kw, info, null);
return Invoke(inst, args, kw, info, null, reverse_args = false);
}

protected static void AppendArgumentTypes(StringBuilder to, BorrowedReference args)
Expand Down Expand Up @@ -852,7 +860,7 @@ protected static void AppendArgumentTypes(StringBuilder to, BorrowedReference ar
to.Append(')');
}

internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo)
internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference args, BorrowedReference kw, MethodBase? info, MethodBase[]? methodinfo, bool reverse_args = false)
{
// No valid methods, nothing to bind.
if (GetMethods().Length == 0)
Expand All @@ -865,7 +873,7 @@ internal virtual NewReference Invoke(BorrowedReference inst, BorrowedReference a
return Exceptions.RaiseTypeError(msg.ToString());
}

Binding? binding = Bind(inst, args, kw, info, methodinfo);
Binding? binding = Bind(inst, args, kw, info, methodinfo, reverse_args);
object result;
IntPtr ts = IntPtr.Zero;

Expand Down
Loading