fix for python3

This commit is contained in:
Daniel Pavel 2012-12-14 16:25:46 +02:00
parent 3cd0665166
commit cc6c0ee7df
1 changed files with 21 additions and 24 deletions

View File

@ -15,30 +15,31 @@ class NamedInt(int):
(case-insensitive).""" (case-insensitive)."""
def __new__(cls, value, name): def __new__(cls, value, name):
assert isinstance(name, str) or isinstance(name, unicode)
obj = int.__new__(cls, value) obj = int.__new__(cls, value)
obj.name = str(name) obj.name = name
return obj return obj
def bytes(self, count=2): def bytes(self, count=2):
value = int(self) if self.bit_length() > count * 8:
if value.bit_length() > count * 8: raise ValueError('cannot fit %X into %d bytes' % (self, count))
raise ValueError('cannot fit %X into %d bytes' % (value, count)) return _pack(b'!L', self)[-count:]
return _pack(b'!L', value)[-count:]
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, NamedInt): if isinstance(other, NamedInt):
return int(self) == int(other) and self.name == other.name return int(self) == int(other) and self.name == other.name
if isinstance(other, int): if isinstance(other, int):
return int(self) == int(other) return int(self) == int(other)
if isinstance(other, basestring): if isinstance(other, str) or isinstance(other, unicode):
return self.name.lower() == other.lower() return self.name.lower() == other.lower()
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __str__(self): def __str__(self):
return self.name return str(self.name)
__unicode__ = __str__ def __unicode__(self):
return unicode(self.name)
def __repr__(self): def __repr__(self):
return 'NamedInt(%d, %s)' % (int(self), repr(self.name)) return 'NamedInt(%d, %s)' % (int(self), repr(self.name))
@ -61,7 +62,8 @@ class NamedInts(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
def _readable_name(n): def _readable_name(n):
assert isinstance(n, basestring) if not isinstance(n, str) and not isinstance(n, unicode):
raise TypeError("expected string, got " + type(n))
if n == n.upper(): if n == n.upper():
n.lstrip('_') n.lstrip('_')
return n.replace('__', '/').replace('_', ' ') return n.replace('__', '/').replace('_', ' ')
@ -88,11 +90,6 @@ class NamedInts(object):
if unknown_bits: if unknown_bits:
yield 'unknown:%06X' % unknown_bits yield 'unknown:%06X' % unknown_bits
# def index(self, value):
# if value in self._values:
# return self._values.index(value)
# raise IndexError('%s not found' % value)
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, int): if isinstance(index, int):
if index in self._indexed: if index in self._indexed:
@ -103,7 +100,7 @@ class NamedInts(object):
self._values = sorted(self._values + [value]) self._values = sorted(self._values + [value])
return value return value
elif isinstance(index, basestring): elif isinstance(index, str) or isinstance(index, unicode):
if index in self.__dict__: if index in self.__dict__:
return self.__dict__[index] return self.__dict__[index]
@ -134,19 +131,19 @@ class NamedInts(object):
return self._values[start_index:stop_index] return self._values[start_index:stop_index]
def __setitem__(self, index, name): def __setitem__(self, index, name):
assert isinstance(index, int) assert isinstance(index, int), type(index)
if isinstance(name, NamedInt): if isinstance(name, NamedInt):
assert int(index) == int(name) assert int(index) == int(name), repr(index) + ' ' + repr(name)
value = name value = name
elif isinstance(name, basestring): elif isinstance(name, str) or isinstance(name, unicode):
value = NamedInt(index, name) value = NamedInt(index, name)
else: else:
raise TypeError('name must be a basestring') raise TypeError('name must be a string')
if str(value) in self.__dict__: if str(value) in self.__dict__:
raise ValueError('%s (%d) already known' % (str(value), int(value))) raise ValueError('%s (%d) already known' % (value, int(value)))
if int(value) in self._indexed: if int(value) in self._indexed:
raise ValueError('%d (%s) already known' % (int(value), str(value))) raise ValueError('%d (%s) already known' % (int(value), value))
self._values = sorted(self._values + [value]) self._values = sorted(self._values + [value])
self.__dict__[str(value)] = value self.__dict__[str(value)] = value
@ -154,9 +151,9 @@ class NamedInts(object):
def __contains__(self, value): def __contains__(self, value):
if isinstance(value, int): if isinstance(value, int):
return int(value) in self._indexed return value in self._indexed
if isinstance(value, basestring): if isinstance(value, str) or isinstance(value, unicode):
return str(value) in self.__dict__ return value in self.__dict__
def __iter__(self): def __iter__(self):
for v in self._values: for v in self._values: