Skip to content

Make Fraction more subclassing-friendly #136096

Open
@skirpichev

Description

@skirpichev

Feature or enhancement

Proposal:

Currently, most methods of Fraction subclasses return instances of the base Fraction class rather than of the subclass. This applies to the arithmetic methods as well:

>>> from fractions import Fraction
>>> class MyFraction(Fraction):
...     pass
... a = MyFraction(1, 2)
... b = MyFraction(2, 3)
... 
>>> a+b
Fraction(7, 6)

I would guess this was intentional.

On the other hand, this makes subclassing Fraction less useful. For example, what if we want to use something other than builtin ints for the components of a fraction? Currently that's possible, but... not quite:

>>> from gmpy2 import mpz
>>> from fractions import Fraction
>>> class mpq(Fraction):
...     def __new__(cls, numerator=0, denominator=None):
...         self = super(mpq, cls).__new__(cls, numerator, denominator)
...         self._numerator = mpz(numerator)
...         self._denominator = mpz(denominator)
...         return self
...         
>>> a, b = mpq(1, 2), mpq(3, 4)
>>> c = a + b
>>> c._numerator  # subclass instances use fast integer arithmetic
mpz(5)
>>> c._denominator
mpz(4)
>>> c  # but it's still an instance of the Fraction!
Fraction(5, 4)

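The patch below fixes this. (Review note: the `_divmod` hunk adds `cls = a.__class__` but still returns `Fraction(n_mod, da * db)`. For consistency with `_mod`, it should return `cls(n_mod, da * db)`. Also, the `__pos__` docstring "Coerces a subclass instance to Fraction" is no longer accurate with this change and should be updated.)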

To better support such subclasses, I think we could also add a `Fraction.gcd` class attribute that subclasses can override to replace `math.gcd()`.

Of course, this is a compatibility break (the patch breaks our CI tests). However, judging from a GitHub search for subclasses of Fraction, I don't think this will break much real-world code.

A quick patch:
diff --git a/Lib/fractions.py b/Lib/fractions.py
index cb05ae7c20..d81be2f90e 100644
--- a/Lib/fractions.py
+++ b/Lib/fractions.py
@@ -411,8 +411,9 @@ def limit_denominator(self, max_denominator=1000000):
 
         if max_denominator < 1:
             raise ValueError("max_denominator should be at least 1")
+        cls = self.__class__
         if self._denominator <= max_denominator:
-            return Fraction(self)
+            return cls(self)
 
         p0, q0, p1, q1 = 0, 1, 1, 0
         n, d = self._numerator, self._denominator
@@ -430,9 +431,9 @@ def limit_denominator(self, max_denominator=1000000):
         # the distance from p1/q1 to self is d/(q1*self._denominator). So we
         # need to compare 2*(q0+k*q1) with self._denominator/d.
         if 2*d*(q0+k*q1) <= self._denominator:
-            return Fraction._from_coprime_ints(p1, q1)
+            return cls._from_coprime_ints(p1, q1)
         else:
-            return Fraction._from_coprime_ints(p0+k*p1, q0+k*q1)
+            return cls._from_coprime_ints(p0+k*p1, q0+k*q1)
 
     @property
     def numerator(a):
@@ -782,38 +783,41 @@ def reverse(b, a):
 
     def _add(a, b):
         """a + b"""
+        cls = a.__class__
         na, da = a._numerator, a._denominator
         nb, db = b._numerator, b._denominator
         g = math.gcd(da, db)
         if g == 1:
-            return Fraction._from_coprime_ints(na * db + da * nb, da * db)
+            return cls._from_coprime_ints(na * db + da * nb, da * db)
         s = da // g
         t = na * (db // g) + nb * s
         g2 = math.gcd(t, g)
         if g2 == 1:
-            return Fraction._from_coprime_ints(t, s * db)
-        return Fraction._from_coprime_ints(t // g2, s * (db // g2))
+            return cls._from_coprime_ints(t, s * db)
+        return cls._from_coprime_ints(t // g2, s * (db // g2))
 
     __add__, __radd__ = _operator_fallbacks(_add, operator.add)
 
     def _sub(a, b):
         """a - b"""
+        cls = a.__class__
         na, da = a._numerator, a._denominator
         nb, db = b._numerator, b._denominator
         g = math.gcd(da, db)
         if g == 1:
-            return Fraction._from_coprime_ints(na * db - da * nb, da * db)
+            return cls._from_coprime_ints(na * db - da * nb, da * db)
         s = da // g
         t = na * (db // g) - nb * s
         g2 = math.gcd(t, g)
         if g2 == 1:
-            return Fraction._from_coprime_ints(t, s * db)
-        return Fraction._from_coprime_ints(t // g2, s * (db // g2))
+            return cls._from_coprime_ints(t, s * db)
+        return cls._from_coprime_ints(t // g2, s * (db // g2))
 
     __sub__, __rsub__ = _operator_fallbacks(_sub, operator.sub)
 
     def _mul(a, b):
         """a * b"""
+        cls = a.__class__
         na, da = a._numerator, a._denominator
         nb, db = b._numerator, b._denominator
         g1 = math.gcd(na, db)
@@ -824,13 +828,14 @@ def _mul(a, b):
         if g2 > 1:
             nb //= g2
             da //= g2
-        return Fraction._from_coprime_ints(na * nb, db * da)
+        return cls._from_coprime_ints(na * nb, db * da)
 
     __mul__, __rmul__ = _operator_fallbacks(_mul, operator.mul)
 
     def _div(a, b):
         """a / b"""
         # Same as _mul(), with inversed b.
+        cls = a.__class__
         nb, db = b._numerator, b._denominator
         if nb == 0:
             raise ZeroDivisionError('Fraction(%s, 0)' % db)
@@ -846,7 +851,7 @@ def _div(a, b):
         n, d = na * db, nb * da
         if d < 0:
             n, d = -n, -d
-        return Fraction._from_coprime_ints(n, d)
+        return cls._from_coprime_ints(n, d)
 
     __truediv__, __rtruediv__ = _operator_fallbacks(_div, operator.truediv)
 
@@ -858,6 +863,7 @@ def _floordiv(a, b):
 
     def _divmod(a, b):
         """(a // b, a % b)"""
+        cls = a.__class__
         da, db = a.denominator, b.denominator
         div, n_mod = divmod(a.numerator * db, da * b.numerator)
         return div, Fraction(n_mod, da * db)
@@ -866,8 +872,9 @@ def _divmod(a, b):
 
     def _mod(a, b):
         """a % b"""
+        cls = a.__class__
         da, db = a.denominator, b.denominator
-        return Fraction((a.numerator * db) % (b.numerator * da), da * db)
+        return cls((a.numerator * db) % (b.numerator * da), da * db)
 
     __mod__, __rmod__ = _operator_fallbacks(_mod, operator.mod, False)
 
@@ -881,21 +888,22 @@ def __pow__(a, b, modulo=None):
         """
         if modulo is not None:
             return NotImplemented
+        cls = a.__class__
         if isinstance(b, numbers.Rational):
             if b.denominator == 1:
                 power = b.numerator
                 if power >= 0:
-                    return Fraction._from_coprime_ints(a._numerator ** power,
-                                                       a._denominator ** power)
+                    return cls._from_coprime_ints(a._numerator ** power,
+                                                  a._denominator ** power)
                 elif a._numerator > 0:
-                    return Fraction._from_coprime_ints(a._denominator ** -power,
-                                                       a._numerator ** -power)
+                    return cls._from_coprime_ints(a._denominator ** -power,
+                                                  a._numerator ** -power)
                 elif a._numerator == 0:
                     raise ZeroDivisionError('Fraction(%s, 0)' %
                                             a._denominator ** -power)
                 else:
-                    return Fraction._from_coprime_ints((-a._denominator) ** -power,
-                                                       (-a._numerator) ** -power)
+                    return cls._from_coprime_ints((-a._denominator) ** -power,
+                                                  (-a._numerator) ** -power)
             else:
                 # A fractional power will generally produce an
                 # irrational number.
@@ -923,15 +931,18 @@ def __rpow__(b, a, modulo=None):
 
     def __pos__(a):
         """+a: Coerces a subclass instance to Fraction"""
-        return Fraction._from_coprime_ints(a._numerator, a._denominator)
+        cls = a.__class__
+        return cls._from_coprime_ints(a._numerator, a._denominator)
 
     def __neg__(a):
         """-a"""
-        return Fraction._from_coprime_ints(-a._numerator, a._denominator)
+        cls = a.__class__
+        return cls._from_coprime_ints(-a._numerator, a._denominator)
 
     def __abs__(a):
         """abs(a)"""
-        return Fraction._from_coprime_ints(abs(a._numerator), a._denominator)
+        cls = a.__class__
+        return cls._from_coprime_ints(abs(a._numerator), a._denominator)
 
     def __int__(a, _index=operator.index):
         """int(a)"""
@@ -977,10 +988,11 @@ def __round__(self, ndigits=None):
         # See _operator_fallbacks.forward to check that the results of
         # these operations will always be Fraction and therefore have
         # round().
+        cls = self.__class__
         if ndigits > 0:
-            return Fraction(round(self * shift), shift)
+            return cls(round(self * shift), shift)
         else:
-            return Fraction(round(self / shift) * shift)
+            return cls(round(self / shift) * shift)
 
     def __hash__(self):
         """hash(self)"""
diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py
index d1d2739856..0e57478f1d 100644
--- a/Lib/test/test_fractions.py
+++ b/Lib/test/test_fractions.py
@@ -851,7 +851,7 @@ def testMixedMultiplication(self):
         self.assertTypedEquals(0.1 + 0j, (1.0 + 0j) * F(1, 10))
 
         self.assertTypedEquals(F(3, 2) * DummyFraction(5, 3), F(5, 2))
-        self.assertTypedEquals(DummyFraction(5, 3) * F(3, 2), F(5, 2))
+        self.assertTypedEquals(DummyFraction(5, 3) * F(3, 2), DummyFraction(5, 2))
         self.assertTypedEquals(F(3, 2) * Rat(5, 3), Rat(15, 6))
         self.assertTypedEquals(Rat(5, 3) * F(3, 2), F(5, 2))
 
@@ -881,7 +881,7 @@ def testMixedDivision(self):
         self.assertTypedEquals(10.0 + 0j, (1.0 + 0j) / F(1, 10))
 
         self.assertTypedEquals(F(3, 2) / DummyFraction(3, 5), F(5, 2))
-        self.assertTypedEquals(DummyFraction(5, 3) / F(2, 3), F(5, 2))
+        self.assertTypedEquals(DummyFraction(5, 3) / F(2, 3), DummyFraction(5, 2))
         self.assertTypedEquals(F(3, 2) / Rat(3, 5), Rat(15, 6))
         self.assertTypedEquals(Rat(5, 3) / F(2, 3), F(5, 2))
 
@@ -927,7 +927,7 @@ def testMixedIntegerDivision(self):
         self.assertTypedTupleEquals(divmod(-0.1, float('-inf')), divmod(F(-1, 10), float('-inf')))
 
         self.assertTypedEquals(F(3, 2) % DummyFraction(3, 5), F(3, 10))
-        self.assertTypedEquals(DummyFraction(5, 3) % F(2, 3), F(1, 3))
+        self.assertTypedEquals(DummyFraction(5, 3) % F(2, 3), DummyFraction(1, 3))
         self.assertTypedEquals(F(3, 2) % Rat(3, 5), Rat(3, 6))
         self.assertTypedEquals(Rat(5, 3) % F(2, 3), F(1, 3))

Has this already been discussed elsewhere?

No response given

Links to previous discussion of this feature:

No response

Metadata

Metadata

Assignees

Labels

pending (The issue will be closed if no feedback is provided), stdlib (Python modules in the Lib dir), type-feature (A feature request or enhancement)

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions