From 8d7a1ea1f99fb1036f3b828967ce787da1aa1c50 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 3 Jan 2021 18:39:32 +1100 Subject: [PATCH] Optimisations for number handling DRY number classes Implement __iadd__, etc. for better performance --- pyRCV2/numbers/base.py | 242 ++++++++++++++++++++++++++++++++++ pyRCV2/numbers/fixed_js.py | 68 +++++----- pyRCV2/numbers/fixed_py.py | 75 ++--------- pyRCV2/numbers/native_js.py | 70 ++++++---- pyRCV2/numbers/native_py.py | 78 ++--------- pyRCV2/numbers/rational_js.py | 81 ++++++------ pyRCV2/numbers/rational_py.py | 83 ++---------- 7 files changed, 387 insertions(+), 310 deletions(-) create mode 100644 pyRCV2/numbers/base.py diff --git a/pyRCV2/numbers/base.py b/pyRCV2/numbers/base.py new file mode 100644 index 0000000..4c01801 --- /dev/null +++ b/pyRCV2/numbers/base.py @@ -0,0 +1,242 @@ +# pyRCV2: Preferential vote counting +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +__pragma__ = lambda x: None +is_py = False +__pragma__('skip') +is_py = True +import functools +import math +__pragma__('noskip') + +def compatible_types(f): + if is_py: + __pragma__('skip') + import os + if 'PYTEST_CURRENT_TEST' in os.environ: + @functools.wraps(f) + def wrapper(self, other): + if not isinstance(other, self.__class__): + raise ValueError('Attempt to operate on incompatible types') + return f(self, other) + return wrapper + else: + return f + __pragma__('noskip') + else: + # FIXME: Do we need to perform type checking in JS? + return f + +class BaseNum: + __slots__ = ['impl'] # Optimisation to reduce overhead of initialising new object + + # These enum values may be overridden in subclasses depending on underlying library + ROUND_DOWN = 0 + ROUND_HALF_UP = 1 + ROUND_HALF_EVEN = 2 + ROUND_UP = 3 + + def __init__(self, value): + if isinstance(value, self.__class__): + self.impl = value.impl + else: + self.impl = self._to_impl(value) + + @classmethod + def _to_impl(cls, value): + """ + Internal use: Convert the given value to an impl + Subclasses must override this method + """ + raise NotImplementedError('Method not implemented') + + @classmethod + def _from_impl(cls, impl): + """Internal use: Return an instance directly from the given impl without performing checks""" + if is_py: + obj = cls.__new__(cls) + else: + # Transcrypt's __new__ (incorrectly) calls the constructor + obj = __pragma__('js', '{}', 'Object.create (cls, {__class__: {value: cls, enumerable: true}})') + obj.impl = impl + return obj + + def pp(self, dp): + """ + Pretty print to specified number of decimal places + Subclasses must override this method + """ + raise NotImplementedError('Method not implemented') + + # Implementation of arithmetic on impls + # Subclasses must override these functions: + + @classmethod + def _add_impl(cls, i1, i2): + raise NotImplementedError('Method not implemented') + @classmethod + def _sub_impl(cls, i1, i2): + raise NotImplementedError('Method not implemented') + @classmethod + def _mul_impl(cls, i1, i2): + raise NotImplementedError('Method not implemented') + @classmethod + def _truediv_impl(cls, i1, i2): + raise NotImplementedError('Method not implemented') + + @compatible_types + def __eq__(self, other): + raise NotImplementedError('Method not implemented') + @compatible_types + def __ne__(self, other): + return not (self.__eq__(other)) + @compatible_types + def __gt__(self, other): + raise NotImplementedError('Method not implemented') + @compatible_types + def __ge__(self, other): + raise NotImplementedError('Method not implemented') + @compatible_types + def __lt__(self, other): + raise NotImplementedError('Method not implemented') + @compatible_types + def __le__(self, other): + raise NotImplementedError('Method not implemented') + + def round(self, dps, mode): + """ + Round to the specified number of decimal places, using the ROUND_* mode specified + Subclasses must override this method + """ + raise NotImplementedError('Method not implemented') + + # Implement various data model functions based on _*_impl + + @compatible_types + def __add__(self, other): + return self._from_impl(self._add_impl(self.impl, other.impl)) + @compatible_types + def __sub__(self, other): + return self._from_impl(self._sub_impl(self.impl, other.impl)) + @compatible_types + def __mul__(self, other): + return self._from_impl(self._mul_impl(self.impl, other.impl)) + @compatible_types + def __truediv__(self, other): + return self._from_impl(self._truediv_impl(self.impl, other.impl)) + + @compatible_types + def __iadd__(self, other): + self.impl = self._add_impl(self.impl, other.impl) + return self + @compatible_types + def __isub__(self, other): + self.impl = self._sub_impl(self.impl, other.impl) + return self + @compatible_types + def __imul__(self, other): + self.impl = self._mul_impl(self.impl, other.impl) + return self + @compatible_types + def __itruediv__(self, other): + self.impl = self._truediv_impl(self.impl, other.impl) + return self + + def __floor__(self): + return self.round(0, self.ROUND_DOWN) + +class BasePyNum(BaseNum): + """Helper class for Num wrappers of Python objects that already implement overloading""" + + _py_class = None # Subclasses must specify + + @classmethod + def _to_impl(cls, value): + """Implements BaseNum._to_impl""" + return cls._py_class(value) + + def pp(self, dp): + """Implements BaseNum.pp""" + return format(self.impl, '.{}f'.format(dp)) + + @classmethod + def _add_impl(cls, i1, i2): + """Implements BaseNum._add_impl""" + return i1 + i2 + @classmethod + def _sub_impl(cls, i1, i2): + """Implements BaseNum._sub_impl""" + return i1 - i2 + @classmethod + def _mul_impl(cls, i1, i2): + """Implements BaseNum._mul_impl""" + return i1 * i2 + @classmethod + def _truediv_impl(cls, i1, i2): + """Implements BaseNum._truediv_impl""" + return i1 / i2 + + @compatible_types + def __eq__(self, other): + """Implements BaseNum.__eq__""" + return self.impl == other.impl + @compatible_types + def __ne__(self, other): + """Overrides BaseNum.__ne__""" + return self.impl != other.impl + @compatible_types + def __gt__(self, other): + """Implements BaseNum.__gt__""" + return self.impl > other.impl + @compatible_types + def __ge__(self, other): + """Implements BaseNum.__ge__""" + return self.impl >= other.impl + @compatible_types + def __lt__(self, other): + """Implements BaseNum.__lt__""" + return self.impl < other.impl + @compatible_types + def __le__(self, other): + """Implements BaseNum.__le__""" + return self.impl <= other.impl + + @compatible_types + def __iadd__(self, other): + """Overrides BaseNum.__iadd__""" + self.impl += other.impl + return self + @compatible_types + def __isub__(self, other): + """Overrides BaseNum.__isub__""" + self.impl -= other.impl + return self + @compatible_types + def __imul__(self, other): + """Overrides BaseNum.__imul__""" + self.impl *= other.impl + return self + @compatible_types + def __itruediv__(self, other): + """Overrides BaseNum.__itruediv__""" + self.impl /= other.impl + return self + + def __floor__(self): + return self._from_impl(math.floor(self.impl)) + + def __repr__(self): + return '<{} {}>'.format(self.__class__.__name__, str(self.impl)) diff --git a/pyRCV2/numbers/fixed_js.py b/pyRCV2/numbers/fixed_js.py index 4384af4..c0a4ccf 100644 --- a/pyRCV2/numbers/fixed_js.py +++ b/pyRCV2/numbers/fixed_js.py @@ -1,5 +1,5 @@ # pyRCV2: Preferential vote counting -# Copyright © 2020 Lee Yingtong Li (RunasSudo) +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,61 +14,65 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from pyRCV2.numbers.base import BaseNum, compatible_types + Big.DP = 6 def set_dps(dps): Big.DP = dps -class Fixed: +class Fixed(BaseNum): """ Wrapper for big.js (fixed-point arithmetic) """ - ROUND_DOWN = 0 - ROUND_HALF_UP = 1 - ROUND_HALF_EVEN = 2 - ROUND_UP = 3 - - def __init__(self, val): - if isinstance(val, Fixed): - self.impl = val.impl - else: - self.impl = Big(val).round(Big.DP) + @classmethod + def _to_impl(cls, value): + """Implements BaseNum._to_impl""" + return Big(value).round(Big.DP) def pp(self, dp): - """Pretty print to specified number of decimal places""" + """Implements BaseNum.pp""" return self.impl.toFixed(dp) - def to_rational(self): - """Convert to an instance of Rational""" - from pyRCV2.numbers import Rational - return Rational(self.impl.toString()) - - def __add__(self, other): - return Fixed(self.impl.plus(other.impl)) - def __sub__(self, other): - return Fixed(self.impl.minus(other.impl)) - def __mul__(self, other): - return Fixed(self.impl.times(other.impl)) - def __div__(self, other): - return Fixed(self.impl.div(other.impl)) + @classmethod + def _add_impl(cls, i1, i2): + """Implements BaseNum._add_impl""" + return i1.plus(i2) + @classmethod + def _sub_impl(cls, i1, i2): + """Implements BaseNum._sub_impl""" + return i1.minus(i2) + @classmethod + def _mul_impl(cls, i1, i2): + """Implements BaseNum._mul_impl""" + return i1.times(i2) + @classmethod + def _truediv_impl(cls, i1, i2): + """Implements BaseNum._truediv_impl""" + return i1.div(i2) + @compatible_types def __eq__(self, other): + """Implements BaseNum.__eq__""" return self.impl.eq(other.impl) - def __ne__(self, other): - return not self.impl.eq(other.impl) + @compatible_types def __gt__(self, other): + """Implements BaseNum.__gt__""" return self.impl.gt(other.impl) + @compatible_types def __ge__(self, other): + """Implements BaseNum.__ge__""" return self.impl.gte(other.impl) + @compatible_types def __lt__(self, other): + """Implements BaseNum.__lt__""" return self.impl.lt(other.impl) + @compatible_types def __le__(self, other): + """Implements BaseNum.__le__""" return self.impl.lte(other.impl) - def __floor__(self): - return self.round(0, Fixed.ROUND_DOWN) - def round(self, dps, mode): - """Round to the specified number of decimal places, using the ROUND_* mode specified""" + """Implements BaseNum.round""" return Fixed(self.impl.round(dps, mode)) diff --git a/pyRCV2/numbers/fixed_py.py b/pyRCV2/numbers/fixed_py.py index 68fea8e..0a23efc 100644 --- a/pyRCV2/numbers/fixed_py.py +++ b/pyRCV2/numbers/fixed_py.py @@ -1,5 +1,5 @@ # pyRCV2: Preferential vote counting -# Copyright © 2020 Lee Yingtong Li (RunasSudo) +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,9 +14,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from pyRCV2.numbers.base import BasePyNum, compatible_types + import decimal -import functools -import math _quantize_exp = 6 @@ -24,76 +24,23 @@ def set_dps(dps): global _quantize_exp _quantize_exp = decimal.Decimal('10') ** -dps -def compatible_types(f): - @functools.wraps(f) - def wrapper(self, other): - if not isinstance(other, Fixed): - raise ValueError('Attempt to operate on incompatible types') - return f(self, other) - return wrapper - -class Fixed: +class Fixed(BasePyNum): """ Wrapper for Python Decimal (for fixed-point arithmetic) """ + _py_class = decimal.Decimal # For BasePyNum + ROUND_DOWN = decimal.ROUND_DOWN ROUND_HALF_UP = decimal.ROUND_HALF_UP ROUND_HALF_EVEN = decimal.ROUND_HALF_EVEN ROUND_UP = decimal.ROUND_UP - def __init__(self, val): - if isinstance(val, Fixed): - self.impl = val.impl - else: - self.impl = decimal.Decimal(val).quantize(_quantize_exp) - - def __repr__(self): - return ''.format(str(self.impl)) - def pp(self, dp): - """Pretty print to specified number of decimal places""" - return format(self.impl, '.{}f'.format(dp)) - - def to_rational(self): - """Convert to an instance of Rational""" - from pyRCV2.numbers import Rational - return Rational(self.impl) - - @compatible_types - def __add__(self, other): - return Fixed(self.impl + other.impl) - @compatible_types - def __sub__(self, other): - return Fixed(self.impl - other.impl) - @compatible_types - def __mul__(self, other): - return Fixed(self.impl * other.impl) - @compatible_types - def __truediv__(self, other): - return Fixed(self.impl / other.impl) - - @compatible_types - def __eq__(self, other): - return self.impl == other.impl - @compatible_types - def __ne__(self, other): - return self.impl != other.impl - @compatible_types - def __gt__(self, other): - return self.impl > other.impl - @compatible_types - def __ge__(self, other): - return self.impl >= other.impl - @compatible_types - def __lt__(self, other): - return self.impl < other.impl - @compatible_types - def __le__(self, other): - return self.impl <= other.impl - - def __floor__(self): - return Fixed(math.floor(self.impl)) + @classmethod + def _to_impl(cls, value): + """Overrides BasePyNum._to_impl""" + return decimal.Decimal(value).quantize(_quantize_exp) def round(self, dps, mode): - """Round to the specified number of decimal places, using the ROUND_* mode specified""" + """Implements BaseNum.round""" return Fixed(self.impl.quantize(decimal.Decimal('10') ** -dps, mode)) diff --git a/pyRCV2/numbers/native_js.py b/pyRCV2/numbers/native_js.py index 341c7bd..54a0b6c 100644 --- a/pyRCV2/numbers/native_js.py +++ b/pyRCV2/numbers/native_js.py @@ -1,5 +1,5 @@ # pyRCV2: Preferential vote counting -# Copyright © 2020 Lee Yingtong Li (RunasSudo) +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,65 +14,77 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class Native: +from pyRCV2.numbers.base import BaseNum, compatible_types + +class Native(BaseNum): """ Wrapper for JS numbers (naive floating-point arithmetic) """ - ROUND_DOWN = 0 - ROUND_HALF_UP = 1 - ROUND_HALF_EVEN = 2 - ROUND_UP = 3 - - def __init__(self, val): - if isinstance(val, Native): - self.impl = val.impl - else: - self.impl = parseFloat(val) + @classmethod + def _to_impl(cls, value): + """Implements BaseNum._to_impl""" + return parseFloat(value) def pp(self, dp): - """Pretty print to specified number of decimal places""" + """Implements BaseNum.pp""" return self.impl.toFixed(dp) - def to_rational(self): - """Convert to an instance of Rational""" - from pyRCV2.numbers import Rational - return Rational(self.impl) - - def __add__(self, other): - return Native(self.impl + other.impl) - def __sub__(self, other): - return Native(self.impl - other.impl) - def __mul__(self, other): - return Native(self.impl * other.impl) - def __div__(self, other): - return Native(self.impl / other.impl) + @classmethod + def _add_impl(cls, i1, i2): + """Implements BaseNum._add_impl""" + return i1 + i2 + @classmethod + def _sub_impl(cls, i1, i2): + """Implements BaseNum._sub_impl""" + return i1 - i2 + @classmethod + def _mul_impl(cls, i1, i2): + """Implements BaseNum._mul_impl""" + return i1 * i2 + @classmethod + def _truediv_impl(cls, i1, i2): + """Implements BaseNum._truediv_impl""" + return i1 / i2 + @compatible_types def __eq__(self, other): + """Implements BaseNum.__eq__""" return self.impl == other.impl + @compatible_types def __ne__(self, other): + """Overrides BaseNum.__ne__""" return self.impl != other.impl + @compatible_types def __gt__(self, other): + """Implements BaseNum.__gt__""" return self.impl > other.impl + @compatible_types def __ge__(self, other): + """Implements BaseNum.__ge__""" return self.impl >= other.impl + @compatible_types def __lt__(self, other): + """Implements BaseNum.__lt__""" return self.impl < other.impl + @compatible_types def __le__(self, other): + """Implements BaseNum.__le__""" return self.impl <= other.impl def __floor__(self): + """Overrides BaseNum.__floor__""" return Native(Math.floor(self.impl)) def round(self, dps, mode): - """Round to the specified number of decimal places, using the ROUND_* mode specified""" + """Implements BaseNum.round""" if mode == Native.ROUND_DOWN: return Native(Math.floor(self.impl * Math.pow(10, dps)) / Math.pow(10, dps)) elif mode == Native.ROUND_HALF_UP: return Native(Math.round(self.impl * Math.pow(10, dps)) / Math.pow(10, dps)) elif mode == Native.ROUND_HALF_EVEN: - raise Exception('ROUND_HALF_EVEN is not implemented in JS Native context') + raise NotImplementedError('ROUND_HALF_EVEN is not implemented in JS Native context') elif mode == Native.ROUND_UP: return Native(Math.ceil(self.impl * Math.pow(10, dps)) / Math.pow(10, dps)) else: - raise Exception('Invalid rounding mode') + raise ValueError('Invalid rounding mode') diff --git a/pyRCV2/numbers/native_py.py b/pyRCV2/numbers/native_py.py index 778e9d0..d4c9e07 100644 --- a/pyRCV2/numbers/native_py.py +++ b/pyRCV2/numbers/native_py.py @@ -1,5 +1,5 @@ # pyRCV2: Preferential vote counting -# Copyright © 2020 Lee Yingtong Li (RunasSudo) +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,89 +14,27 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import functools +from pyRCV2.numbers.base import BasePyNum, compatible_types + import math -def compatible_types(f): - @functools.wraps(f) - def wrapper(self, other): - if not isinstance(other, Native): - raise ValueError('Attempt to operate on incompatible types') - return f(self, other) - return wrapper - -class Native: +class Native(BasePyNum): """ Wrapper for Python float (naive floating-point arithmetic) """ - ROUND_DOWN = 0 - ROUND_HALF_UP = 1 - ROUND_HALF_EVEN = 2 - ROUND_UP = 3 - - def __init__(self, val): - if isinstance(val, Native): - self.impl = val.impl - else: - self.impl = float(val) - - def __repr__(self): - return ''.format(str(self.impl)) - def pp(self, dp): - """Pretty print to specified number of decimal places""" - return format(self.impl, '.{}f'.format(dp)) - - def to_rational(self): - """Convert to an instance of Rational""" - from pyRCV2.numbers import Rational - return Rational(self.impl) - - @compatible_types - def __add__(self, other): - return Native(self.impl + other.impl) - @compatible_types - def __sub__(self, other): - return Native(self.impl - other.impl) - @compatible_types - def __mul__(self, other): - return Native(self.impl * other.impl) - @compatible_types - def __truediv__(self, other): - return Native(self.impl / other.impl) - - @compatible_types - def __eq__(self, other): - return self.impl == other.impl - @compatible_types - def __ne__(self, other): - return self.impl != other.impl - @compatible_types - def __gt__(self, other): - return self.impl > other.impl - @compatible_types - def __ge__(self, other): - return self.impl >= other.impl - @compatible_types - def __lt__(self, other): - return self.impl < other.impl - @compatible_types - def __le__(self, other): - return self.impl <= other.impl - - def __floor__(self): - return Native(math.floor(self.impl)) + _py_class = float # For BasePyNum def round(self, dps, mode): - """Round to the specified number of decimal places, using the ROUND_* mode specified""" + """Implements BaseNum.round""" factor = 10 ** dps if mode == Native.ROUND_DOWN: return Native(math.floor(self.impl * factor) / factor) elif mode == Native.ROUND_HALF_UP: - raise Exception('ROUND_HALF_UP is not implemented in Python Native context') + raise NotImplementedError('ROUND_HALF_UP is not implemented in Python Native context') elif mode == Native.ROUND_HALF_EVEN: return Native(round(self.impl * factor) / factor) elif mode == Native.ROUND_UP: return Native(math.ceil(self.impl * factor) / factor) else: - raise Exception('Invalid rounding mode') + raise ValueError('Invalid rounding mode') diff --git a/pyRCV2/numbers/rational_js.py b/pyRCV2/numbers/rational_js.py index b935510..2f76fc7 100644 --- a/pyRCV2/numbers/rational_js.py +++ b/pyRCV2/numbers/rational_js.py @@ -1,5 +1,5 @@ # pyRCV2: Preferential vote counting -# Copyright © 2020 Lee Yingtong Li (RunasSudo) +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,76 +14,75 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -class Rational: +from pyRCV2.numbers.base import BaseNum, compatible_types + +class Rational(BaseNum): """ Wrapper for BigRational.js (rational arithmetic) """ - ROUND_DOWN = 0 - ROUND_HALF_UP = 1 - ROUND_HALF_EVEN = 2 - ROUND_UP = 3 - - def __init__(self, val): - if isinstance(val, Rational): - self.impl = val.impl - else: - self.impl = bigRat(val) + @classmethod + def _to_impl(cls, value): + """Implements BaseNum._to_impl""" + return bigRat(value) def pp(self, dp): - """ - Pretty print to specified number of decimal places - This will fail for numbers which cannot be represented as a JavaScript number - """ + """Implements BaseNum.pp""" + # FIXME: This will fail for numbers which cannot be represented as a JavaScript number return self.impl.valueOf().toFixed(dp) - def to_rational(self): - return self - - def to_num(self): - """ - Convert to an instance of Num - """ - from pyRCV2.numbers import Num - __pragma__('opov') - return Num(self.impl.numerator.toString()) / Num(self.impl.denominator.toString()) - __pragma__('noopov') - - def __add__(self, other): - return Rational(self.impl.add(other.impl)) - def __sub__(self, other): - return Rational(self.impl.subtract(other.impl)) - def __mul__(self, other): - return Rational(self.impl.multiply(other.impl)) - def __div__(self, other): - return Rational(self.impl.divide(other.impl)) + @classmethod + def _add_impl(cls, i1, i2): + """Implements BaseNum._add_impl""" + return i1.add(i2) + @classmethod + def _sub_impl(cls, i1, i2): + """Implements BaseNum._sub_impl""" + return i1.subtract(i2) + @classmethod + def _mul_impl(cls, i1, i2): + """Implements BaseNum._mul_impl""" + return i1.multiply(i2) + @classmethod + def _truediv_impl(cls, i1, i2): + """Implements BaseNum._truediv_impl""" + return i1.divide(i2) + @compatible_types def __eq__(self, other): + """Implements BaseNum.__eq__""" return self.impl.equals(other.impl) - def __ne__(self, other): - return not self.impl.equals(other.impl) + @compatible_types def __gt__(self, other): + """Implements BaseNum.__gt__""" return self.impl.greater(other.impl) + @compatible_types def __ge__(self, other): + """Implements BaseNum.__ge__""" return self.impl.greaterOrEquals(other.impl) + @compatible_types def __lt__(self, other): + """Implements BaseNum.__lt__""" return self.impl.lesser(other.impl) + @compatible_types def __le__(self, other): + """Implements BaseNum.__le__""" return self.impl.lesserOrEquals(other.impl) def __floor__(self): + """Overrides BaseNum.__floor__""" return Rational(self.impl.floor()) def round(self, dps, mode): - """Round to the specified number of decimal places, using the ROUND_* mode specified""" + """Implements BaseNum.round""" factor = bigRat(10).pow(dps) if mode == Rational.ROUND_DOWN: return Rational(self.impl.multiply(factor).floor().divide(factor)) elif mode == Rational.ROUND_HALF_UP: return Rational(self.impl.multiply(factor).round().divide(factor)) elif mode == Rational.ROUND_HALF_EVEN: - raise Exception('ROUND_HALF_EVEN is not implemented in JS Native context') + raise NotImplementedError('ROUND_HALF_EVEN is not implemented in JS Native context') elif mode == Rational.ROUND_UP: return Rational(self.impl.multiply(factor).ceil().divide(factor)) else: - raise Exception('Invalid rounding mode') + raise ValueError('Invalid rounding mode') diff --git a/pyRCV2/numbers/rational_py.py b/pyRCV2/numbers/rational_py.py index 2155250..c7b8dce 100644 --- a/pyRCV2/numbers/rational_py.py +++ b/pyRCV2/numbers/rational_py.py @@ -1,5 +1,5 @@ # pyRCV2: Preferential vote counting -# Copyright © 2020 Lee Yingtong Li (RunasSudo) +# Copyright © 2020–2021 Lee Yingtong Li (RunasSudo) # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,98 +14,33 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from pyRCV2.numbers.base import BasePyNum, compatible_types + from fractions import Fraction -import functools import math -def compatible_types(f): - @functools.wraps(f) - def wrapper(self, other): - if not isinstance(other, Rational): - raise ValueError('Attempt to operate on incompatible types') - return f(self, other) - return wrapper - -class Rational: +class Rational(BasePyNum): """ Wrapper for Python Fraction (rational arithmetic) """ - ROUND_DOWN = 0 - ROUND_HALF_UP = 1 - ROUND_HALF_EVEN = 2 - ROUND_UP = 3 + _py_class = Fraction # For BasePyNum - def __init__(self, val): - if isinstance(val, Rational): - self.impl = val.impl - else: - self.impl = Fraction(val) - - def __repr__(self): - return ''.format(str(self.impl)) def pp(self, dp): - """ - Pretty print to specified number of decimal places - """ + """Overrides BasePyNum.pp""" # TODO: Work out if there is a better way of doing this return format(float(self.impl), '.{}f'.format(dp)) - def to_rational(self): - return self - - def to_num(self): - """ - Convert to an instance of Num - """ - from pyRCV2.numbers import Num - return Num(self.impl.numerator) / Num(self.impl.denominator) - - @compatible_types - def __add__(self, other): - return Rational(self.impl + other.impl) - @compatible_types - def __sub__(self, other): - return Rational(self.impl - other.impl) - @compatible_types - def __mul__(self, other): - return Rational(self.impl * other.impl) - @compatible_types - def __truediv__(self, other): - return Rational(self.impl / other.impl) - - @compatible_types - def __eq__(self, other): - return self.impl == other.impl - @compatible_types - def __ne__(self, other): - return self.impl != other.impl - @compatible_types - def __gt__(self, other): - return self.impl > other.impl - @compatible_types - def __ge__(self, other): - return self.impl >= other.impl - @compatible_types - def __lt__(self, other): - return self.impl < other.impl - @compatible_types - def __le__(self, other): - return self.impl <= other.impl - - def __floor__(self): - return Rational(math.floor(self.impl)) - def round(self, dps, mode): - """Round to the specified number of decimal places, using the ROUND_* mode specified""" + """Implements BaseNum.round""" factor = Fraction(10) ** dps if mode == Rational.ROUND_DOWN: return Rational(math.floor(self.impl * factor) / factor) elif mode == Rational.ROUND_HALF_UP: - raise Exception('ROUND_HALF_UP is not implemented in Python Rational context') + raise NotImplementedError('ROUND_HALF_UP is not implemented in Python Rational context') elif mode == Rational.ROUND_HALF_EVEN: return Rational(round(self.impl * factor) / factor) elif mode == Rational.ROUND_UP: return Rational(math.ceil(self.impl * factor) / factor) else: - raise Exception('Invalid rounding mode') + raise ValueError('Invalid rounding mode')