diff --git a/lib/logitech/unifying_receiver/common.py b/lib/logitech/unifying_receiver/common.py index 3bf889e3..7afa06bb 100644 --- a/lib/logitech/unifying_receiver/common.py +++ b/lib/logitech/unifying_receiver/common.py @@ -11,8 +11,8 @@ from struct import pack as _pack class NamedInt(int): """An reqular Python integer with an attached name. - Careful when using this, because - """ + Caution: comparison with strings will also match this NamedInt's name + (case-insensitive).""" def __new__(cls, value, name): obj = int.__new__(cls, value) @@ -23,45 +23,18 @@ class NamedInt(int): value = int(self) if value.bit_length() > count * 8: raise ValueError('cannot fit %X into %d bytes' % (value, count)) - return _pack(b'!L', value)[-count:] - def __hash__(self): - return int(self) - def __eq__(self, other): + if isinstance(other, NamedInt): + return int(self) == int(other) and self.name == other.name if isinstance(other, int): return int(self) == int(other) - if isinstance(other, basestring): return self.name.lower() == other.lower() def __ne__(self, other): - if isinstance(other, int): - return int(self) != int(other) - - if isinstance(other, basestring): - return self.name.lower() != other.lower() - - def __lt__(self, other): - if not isinstance(other, int): - raise TypeError('unorderable types: %s < %s' % (type(self), type(other))) - return int(self) < int(other) - - def __le__(self, other): - if not isinstance(other, int): - raise TypeError('unorderable types: %s <= %s' % (type(self), type(other))) - return int(self) <= int(other) - - def __gt__(self, other): - if not isinstance(other, int): - raise TypeError('unorderable types: %s > %s' % (type(self), type(other))) - return int(self) > int(other) - - def __ge__(self, other): - if not isinstance(other, int): - raise TypeError('unorderable types: %s >= %s' % (type(self), type(other))) - return int(self) >= int(other) + return not self.__eq__(other) def __str__(self): return self.name @@ -72,14 +45,23 @@ class NamedInt(int): class NamedInts(object): + """A collection of NamedInt values. + + Behaves partially like a sorted list (by int value), partially like a dict. + """ __slots__ = ['__dict__', '_values', '_indexed', '_fallback'] def __init__(self, **kwargs): - values = dict((k, NamedInt(v, k.lstrip('_') if k == k.upper() else - k.replace('__', '/').replace('_', ' '))) for (k, v) in kwargs.items()) + def _readable_name(n): + assert isinstance(n, basestring) + if n == n.upper(): + n.lstrip('_') + return n.replace('__', '/').replace('_', ' ') + + values = {k: NamedInt(v, _readable_name(k)) for (k, v) in kwargs.items()} self.__dict__ = values self._values = sorted(list(values.values())) - self._indexed = dict((int(v), v) for v in self._values) + self._indexed = {int(v): v for v in self._values} self._fallback = None def flag_names(self, value): @@ -102,26 +84,47 @@ class NamedInts(object): if isinstance(index, int): if index in self._indexed: return self._indexed[int(index)] - if self._fallback and type(index) == int: value = NamedInt(index, self._fallback(index)) self._indexed[index] = value self._values = sorted(self._values + [value]) return value - elif type(index) == slice: + elif isinstance(index, slice): return self._values[index] + elif isinstance(index, basestring): + if index in self.__dict__: + return self.__dict__[index] + + def __setitem__(self, index, name): + assert isinstance(index, int) + if isinstance(name, NamedInt): + assert int(index) == int(name) + value = name + elif isinstance(name, basestring): + value = NamedInt(index, name) else: - if index in self._values: - index = self._values.index(index) - return self._values[index] + raise TypeError('name must be a basestring') + + if str(value) in self.__dict__: + raise ValueError('%s (%d) already known' % (str(value), int(value))) + if int(value) in self._indexed: + raise ValueError('%d (%s) already known' % (int(value), str(value))) + + self._values = sorted(self._values + [value]) + self.__dict__[str(value)] = value + self._indexed[int(value)] = value def __contains__(self, value): - return value in self._values + if isinstance(value, int): + return int(value) in self._indexed + if isinstance(value, basestring): + return str(value) in self.__dict__ def __iter__(self): - return iter(sorted(self._values)) + for v in self._values: + yield v def __len__(self): return len(self._values) diff --git a/lib/logitech/unifying_receiver/receiver.py b/lib/logitech/unifying_receiver/receiver.py index a72bebb8..c59469ac 100644 --- a/lib/logitech/unifying_receiver/receiver.py +++ b/lib/logitech/unifying_receiver/receiver.py @@ -6,7 +6,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera import errno as _errno from weakref import proxy as _proxy -from collections import defaultdict as _defaultdict from logging import getLogger _log = getLogger('LUR').getChild('receiver') @@ -15,7 +14,7 @@ del getLogger from . import base as _base from . import hidpp10 as _hidpp10 from . import hidpp20 as _hidpp20 -from .common import strhex as _strhex +from .common import strhex as _strhex, NamedInts as _NamedInts from .descriptors import DEVICES as _DEVICES # @@ -144,7 +143,7 @@ class PairedDevice(object): if self._registers is None: descriptor = _DEVICES.get(self.codename) if descriptor is None or descriptor.registers is None: - self._registers = _defaultdict(lambda: None) + self._registers = _NamedInts() else: self._registers = descriptor.registers return self._registers