From cfc3ccc1864e09c59eb5b2efa9dce27a07b7d704 Mon Sep 17 00:00:00 2001 From: Philippe Proulx Date: Thu, 27 Jun 2019 23:51:24 -0400 Subject: [PATCH] bt2: value.py: refactor value comparison Changed: * Remove _spec_eq() methods: each class implements its own __eq__() method directly. * Do not use native_bt.value_compare(): we never reached that, because container value classes (`ArrayValue` and `MapValue`) implement their own rich, recursive comparison. * In _NumericValue._extract_value(), do not check the `_NumericValue`, `False`, and `True` special case: check for `BoolValue` and `bool` to return a boolean object in those cases. * In NumericValue.__lt__() and NumericValue.__eq__(), do not check that the parameter is a number object: self._extract_value() does this already. * In BoolValue._value_to_bool(), return the boolean value directly, not using int(): it's already an integral number. * In ArrayValue.__eq__(), be more strict: expect that the parameter is a sequence object, not just an iterable object. Before this, it was possible to compare an array value to an ordered dict with keys equal to the array value content, and this seems wrong as: collections.OrderedDict((('A', 23), ('B', 42))) == ['A', 'B'] is false. An ordered dict is not a sequence. * In MapValue.__eq__(), be more strict: expect that the parameter is a mapping object, not just an iterable and indexable object. The reason is similar to the ArrayValue.__eq__() case above. This should be enough to compare to another map value or to a dict (or ordered dict). Signed-off-by: Philippe Proulx Change-Id: I9941d2d82942e2efa8d5380c8ff5a4a2d2cb3a84 Reviewed-on: https://review.lttng.org/c/babeltrace/+/1563 Tested-by: jenkins --- src/bindings/python/bt2/bt2/value.py | 108 +++++++++------------------ 1 file changed, 36 insertions(+), 72 deletions(-) diff --git a/src/bindings/python/bt2/bt2/value.py b/src/bindings/python/bt2/bt2/value.py index b6fb6769..7f003c2b 100644 --- a/src/bindings/python/bt2/bt2/value.py +++ b/src/bindings/python/bt2/bt2/value.py @@ -89,31 +89,9 @@ class _Value(object._SharedObject, metaclass=abc.ABCMeta): _get_ref = staticmethod(native_bt.value_get_ref) _put_ref = staticmethod(native_bt.value_put_ref) - def __eq__(self, other): - if other is None: - # self is never the null value object - return False - - # try type-specific comparison first - spec_eq = self._spec_eq(other) - - if spec_eq is not None: - return spec_eq - - if not isinstance(other, _Value): - # not comparing apples to apples - return False - - # fall back to native comparison function - return native_bt.value_compare(self._ptr, other._ptr) - def __ne__(self, other): return not (self == other) - @abc.abstractmethod - def _spec_eq(self, other): - pass - def _handle_status(self, status): _handle_status(status, self._NAME) @@ -127,11 +105,8 @@ class _Value(object._SharedObject, metaclass=abc.ABCMeta): class _NumericValue(_Value): @staticmethod def _extract_value(other): - if isinstance(other, _NumericValue): - return other._value - - if other is True or other is False: - return other + if isinstance(other, BoolValue) or isinstance(other, bool): + return bool(other) if isinstance(other, numbers.Integral): return int(other) @@ -154,21 +129,14 @@ class _NumericValue(_Value): return repr(self._value) def __lt__(self, other): - if not isinstance(other, numbers.Number): - raise TypeError('unorderable types: {}() < {}()'.format(self.__class__.__name__, - other.__class__.__name__)) - return self._value < self._extract_value(other) - def _spec_eq(self, other): - pass - def __eq__(self, other): - if not isinstance(other, numbers.Number): + try: + return self._value == self._extract_value(other) + except: return False - return self._value == self._extract_value(other) - def __rmod__(self, other): return self._extract_value(other) % self._value @@ -329,9 +297,11 @@ class BoolValue(_Value): self._check_create_status(ptr) super().__init__(ptr) - def _spec_eq(self, other): - if isinstance(other, numbers.Number): - return self._value == bool(other) + def __eq__(self, other): + try: + return self._value == self._value_to_bool(other) + except: + return False def __bool__(self): return self._value @@ -346,7 +316,7 @@ class BoolValue(_Value): if not isinstance(value, bool): raise TypeError("'{}' object is not a 'bool' or 'BoolValue' object".format(value.__class__)) - return int(value) + return value @property def _value(self): @@ -462,11 +432,11 @@ class StringValue(collections.abc.Sequence, _Value): value = property(fset=_set_value) - def _spec_eq(self, other): + def __eq__(self, other): try: return self._value == self._value_to_str(other) except: - return + return False def __lt__(self, other): return self._value < self._value_to_str(other) @@ -515,19 +485,19 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value): for elem in value: self.append(elem) - def _spec_eq(self, other): - try: - if len(self) != len(other): - # early mismatch - return False + def __eq__(self, other): + if not isinstance(other, collections.abc.Sequence): + return False - for self_elem, other_elem in zip(self, other): - if self_elem != other_elem: - return False + if len(self) != len(other): + # early mismatch + return False - return True - except: - return + for self_elem, other_elem in zip(self, other): + if self_elem != other_elem: + return False + + return True def __len__(self): size = native_bt.value_array_get_size(self._ptr) @@ -623,31 +593,25 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value): for key, elem in value.items(): self[key] = elem - def __eq__(self, other): - return _Value.__eq__(self, other) - def __ne__(self, other): return _Value.__ne__(self, other) - def _spec_eq(self, other): - try: - if len(self) != len(other): - # early mismatch - return False + def __eq__(self, other): + if not isinstance(other, collections.abc.Mapping): + return False - for self_key in self: - if self_key not in other: - return False + if len(self) != len(other): + # early mismatch + return False - self_value = self[self_key] - other_value = other[self_key] + for self_key in self: + if self_key not in other: + return False - if self_value != other_value: - return False + if self[self_key] != other[self_key]: + return False - return True - except: - return + return True def __len__(self): size = native_bt.value_map_get_size(self._ptr) -- 2.34.1