refactor: use IntEnum for firmware and cidgroup constances

* Refactor: test_named_ints_flag_names

Shorten test and clarify behavior using binary numbers.

* Introduce plain flag_names function

This replicates the NamedInts functionality as plain function.

* Refactor FeatureFlag to use IntFlag

Replace NamedInts implementation with IntFlag enum and plain flag_names
function.

Related #2273

* Refactor FirmwareKind to use IntEnum

- Move general FirmwareKind to common module.
- Replace NamedInts implementation with IntEnum.
- Harden related HIDPP 1.0 get_firmware test.

Related #2273

* Refactor CID_GROUP, CID_GROUP_BIT to use IntEnum

Related #2273
This commit is contained in:
MattHag 2024-10-23 22:25:35 +02:00 committed by GitHub
parent 79ffbda903
commit 1afcfe4b57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 186 additions and 43 deletions

View File

@ -18,8 +18,11 @@ from __future__ import annotations
import binascii import binascii
import dataclasses import dataclasses
import typing
from enum import IntEnum from enum import IntEnum
from typing import Generator
from typing import Iterable
from typing import Optional from typing import Optional
from typing import Union from typing import Union
@ -27,6 +30,9 @@ import yaml
from solaar.i18n import _ from solaar.i18n import _
if typing.TYPE_CHECKING:
from logitech_receiver.hidpp20_constants import FirmwareKind
LOGITECH_VENDOR_ID = 0x046D LOGITECH_VENDOR_ID = 0x046D
@ -502,6 +508,31 @@ class NamedInts:
return isinstance(other, self.__class__) and self._values == other._values return isinstance(other, self.__class__) and self._values == other._values
def flag_names(enum_class: Iterable, value: int) -> Generator[str]:
"""Extracts single bit flags from a (binary) number.
Parameters
----------
enum_class
Enum class to extract flags from.
value
Number to extract binary flags from.
"""
indexed = {item.value: item.name for item in enum_class}
unknown_bits = value
for k in indexed:
# Ensure that the key (flag value) is a power of 2 (a single bit flag)
assert bin(k).count("1") == 1
if k & value == k:
unknown_bits &= ~k
yield indexed[k].lower()
# Yield any remaining unknown bits
if unknown_bits != 0:
yield f"unknown:{unknown_bits:06X}"
class UnsortedNamedInts(NamedInts): class UnsortedNamedInts(NamedInts):
def _sort_values(self): def _sort_values(self):
pass pass
@ -543,9 +574,16 @@ class KwException(Exception):
return self.args[0].get(k) # was self.args[0][k] return self.args[0].get(k) # was self.args[0][k]
class FirmwareKind(IntEnum):
Firmware = 0x00
Bootloader = 0x01
Hardware = 0x02
Other = 0x03
@dataclasses.dataclass @dataclasses.dataclass
class FirmwareInfo: class FirmwareInfo:
kind: str kind: FirmwareKind
name: str name: str
version: str version: str
extras: str | None extras: str | None

View File

@ -14,10 +14,14 @@
## You should have received a copy of the GNU General Public License along ## You should have received a copy of the GNU General Public License along
## with this program; if not, write to the Free Software Foundation, Inc., ## with this program; if not, write to the Free Software Foundation, Inc.,
## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. ## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from __future__ import annotations
import errno import errno
import logging import logging
import threading import threading
import time import time
import typing
from typing import Any from typing import Any
from typing import Callable from typing import Callable
@ -37,6 +41,9 @@ from .common import Alert
from .common import Battery from .common import Battery
from .hidpp20_constants import SupportedFeature from .hidpp20_constants import SupportedFeature
if typing.TYPE_CHECKING:
from logitech_receiver import common
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_hidpp10 = hidpp10.Hidpp10() _hidpp10 = hidpp10.Hidpp10()
@ -265,7 +272,7 @@ class Device:
return self._kind or "?" return self._kind or "?"
@property @property
def firmware(self): def firmware(self) -> tuple[common.FirmwareInfo]:
if self._firmware is None and self.online: if self._firmware is None and self.online:
if self.protocol >= 2.0: if self.protocol >= 2.0:
self._firmware = _hidpp20.get_firmware(self) self._firmware = _hidpp20.get_firmware(self)

View File

@ -25,8 +25,8 @@ from . import common
from .common import Battery from .common import Battery
from .common import BatteryLevelApproximation from .common import BatteryLevelApproximation
from .common import BatteryStatus from .common import BatteryStatus
from .common import FirmwareKind
from .hidpp10_constants import Registers from .hidpp10_constants import Registers
from .hidpp20_constants import FIRMWARE_KIND
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,7 +110,7 @@ class Hidpp10:
device.registers.append(Registers.BATTERY_STATUS) device.registers.append(Registers.BATTERY_STATUS)
return parse_battery_status(Registers.BATTERY_STATUS, reply) return parse_battery_status(Registers.BATTERY_STATUS, reply)
def get_firmware(self, device: Device): def get_firmware(self, device: Device) -> tuple[common.FirmwareInfo] | None:
assert device is not None assert device is not None
firmware = [None, None, None] firmware = [None, None, None]
@ -125,21 +125,21 @@ class Hidpp10:
reply = read_register(device, Registers.FIRMWARE, 0x02) reply = read_register(device, Registers.FIRMWARE, 0x02)
if reply: if reply:
fw_version += ".B" + common.strhex(reply[1:3]) fw_version += ".B" + common.strhex(reply[1:3])
fw = common.FirmwareInfo(FIRMWARE_KIND.Firmware, "", fw_version, None) fw = common.FirmwareInfo(FirmwareKind.Firmware, "", fw_version, None)
firmware[0] = fw firmware[0] = fw
reply = read_register(device, Registers.FIRMWARE, 0x04) reply = read_register(device, Registers.FIRMWARE, 0x04)
if reply: if reply:
bl_version = common.strhex(reply[1:3]) bl_version = common.strhex(reply[1:3])
bl_version = f"{bl_version[0:2]}.{bl_version[2:4]}" bl_version = f"{bl_version[0:2]}.{bl_version[2:4]}"
bl = common.FirmwareInfo(FIRMWARE_KIND.Bootloader, "", bl_version, None) bl = common.FirmwareInfo(FirmwareKind.Bootloader, "", bl_version, None)
firmware[1] = bl firmware[1] = bl
reply = read_register(device, Registers.FIRMWARE, 0x03) reply = read_register(device, Registers.FIRMWARE, 0x03)
if reply: if reply:
o_version = common.strhex(reply[1:3]) o_version = common.strhex(reply[1:3])
o_version = f"{o_version[0:2]}.{o_version[2:4]}" o_version = f"{o_version[0:2]}.{o_version[2:4]}"
o = common.FirmwareInfo(FIRMWARE_KIND.Other, "", o_version, None) o = common.FirmwareInfo(FirmwareKind.Other, "", o_version, None)
firmware[2] = o firmware[2] = o
if any(firmware): if any(firmware):

View File

@ -23,6 +23,7 @@ import threading
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import Generator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
@ -39,13 +40,13 @@ from . import special_keys
from .common import Battery from .common import Battery
from .common import BatteryLevelApproximation from .common import BatteryLevelApproximation
from .common import BatteryStatus from .common import BatteryStatus
from .common import FirmwareKind
from .common import NamedInt from .common import NamedInt
from .hidpp20_constants import CHARGE_LEVEL from .hidpp20_constants import CHARGE_LEVEL
from .hidpp20_constants import CHARGE_STATUS from .hidpp20_constants import CHARGE_STATUS
from .hidpp20_constants import CHARGE_TYPE from .hidpp20_constants import CHARGE_TYPE
from .hidpp20_constants import DEVICE_KIND from .hidpp20_constants import DEVICE_KIND
from .hidpp20_constants import ERROR from .hidpp20_constants import ERROR
from .hidpp20_constants import FIRMWARE_KIND
from .hidpp20_constants import GESTURE from .hidpp20_constants import GESTURE
from .hidpp20_constants import SupportedFeature from .hidpp20_constants import SupportedFeature
@ -241,8 +242,8 @@ class ReprogrammableKeyV4(ReprogrammableKey):
self._mapped_to = None self._mapped_to = None
@property @property
def group_mask(self): def group_mask(self) -> Generator[str]:
return special_keys.CID_GROUP_BIT.flag_names(self._gmask) return common.flag_names(special_keys.CIDGroupBit, self._gmask)
@property @property
def mapped_to(self) -> NamedInt: def mapped_to(self) -> NamedInt:
@ -259,7 +260,7 @@ class ReprogrammableKeyV4(ReprogrammableKey):
if self.group_mask: # only keys with a non-zero gmask are remappable if self.group_mask: # only keys with a non-zero gmask are remappable
ret[self.default_task] = self.default_task # it should always be possible to map the key to itself ret[self.default_task] = self.default_task # it should always be possible to map the key to itself
for g in self.group_mask: for g in self.group_mask:
g = special_keys.CID_GROUP[str(g)] g = special_keys.CidGroup[str(g)]
for tgt_cid in self._device.keys.group_cids[g]: for tgt_cid in self._device.keys.group_cids[g]:
tgt_task = str(special_keys.TASK[self._device.keys.cid_to_tid[tgt_cid]]) tgt_task = str(special_keys.TASK[self._device.keys.cid_to_tid[tgt_cid]])
tgt_task = NamedInt(tgt_cid, tgt_task) tgt_task = NamedInt(tgt_cid, tgt_task)
@ -515,7 +516,7 @@ class KeysArrayV2(KeysArray):
self.cid_to_tid = {} self.cid_to_tid = {}
"""The mapping from Control ID groups to Controls IDs that belong to it. """The mapping from Control ID groups to Controls IDs that belong to it.
A key k can only be remapped to targets in groups within k.group_mask.""" A key k can only be remapped to targets in groups within k.group_mask."""
self.group_cids = {g: [] for g in special_keys.CID_GROUP} self.group_cids = {g: [] for g in special_keys.CidGroup}
def _query_key(self, index: int): def _query_key(self, index: int):
if index < 0 or index >= len(self.keys): if index < 0 or index >= len(self.keys):
@ -543,7 +544,7 @@ class KeysArrayV4(KeysArrayV2):
self.keys[index] = ReprogrammableKeyV4(self.device, index, cid, tid, flags, pos, group, gmask) self.keys[index] = ReprogrammableKeyV4(self.device, index, cid, tid, flags, pos, group, gmask)
self.cid_to_tid[cid] = tid self.cid_to_tid[cid] = tid
if group != 0: # 0 = does not belong to a group if group != 0: # 0 = does not belong to a group
self.group_cids[special_keys.CID_GROUP[group]].append(cid) self.group_cids[special_keys.CidGroup(group)].append(cid)
elif logger.isEnabledFor(logging.WARNING): elif logger.isEnabledFor(logging.WARNING):
logger.warning(f"Key with index {index} was expected to exist but device doesn't report it.") logger.warning(f"Key with index {index} was expected to exist but device doesn't report it.")
@ -1451,7 +1452,7 @@ battery_voltage_remaining = (
class Hidpp20: class Hidpp20:
def get_firmware(self, device): def get_firmware(self, device) -> tuple[common.FirmwareInfo] | None:
"""Reads a device's firmware info. """Reads a device's firmware info.
:returns: a list of FirmwareInfo tuples, ordered by firmware layer. :returns: a list of FirmwareInfo tuples, ordered by firmware layer.
@ -1471,11 +1472,11 @@ class Hidpp20:
if build: if build:
version += f".B{build:04X}" version += f".B{build:04X}"
extras = fw_info[9:].rstrip(b"\x00") or None extras = fw_info[9:].rstrip(b"\x00") or None
fw_info = common.FirmwareInfo(FIRMWARE_KIND[level], name.decode("ascii"), version, extras) fw_info = common.FirmwareInfo(FirmwareKind(level), name.decode("ascii"), version, extras)
elif level == FIRMWARE_KIND.Hardware: elif level == FirmwareKind.Hardware:
fw_info = common.FirmwareInfo(FIRMWARE_KIND.Hardware, "", str(ord(fw_info[1:2])), None) fw_info = common.FirmwareInfo(FirmwareKind.Hardware, "", str(ord(fw_info[1:2])), None)
else: else:
fw_info = common.FirmwareInfo(FIRMWARE_KIND.Other, "", "", None) fw_info = common.FirmwareInfo(FirmwareKind.Other, "", "", None)
fw.append(fw_info) fw.append(fw_info)
return tuple(fw) return tuple(fw)

View File

@ -15,6 +15,7 @@
## with this program; if not, write to the Free Software Foundation, Inc., ## with this program; if not, write to the Free Software Foundation, Inc.,
## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. ## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from enum import IntEnum from enum import IntEnum
from enum import IntFlag
from .common import NamedInts from .common import NamedInts
@ -152,7 +153,13 @@ class SupportedFeature(IntEnum):
return self.name.replace("_", " ") return self.name.replace("_", " ")
FEATURE_FLAG = NamedInts(internal=0x20, hidden=0x40, obsolete=0x80) class FeatureFlag(IntFlag):
"""Single bit flags."""
INTERNAL = 0x20
HIDDEN = 0x40
OBSOLETE = 0x80
DEVICE_KIND = NamedInts( DEVICE_KIND = NamedInts(
keyboard=0x00, keyboard=0x00,
@ -165,7 +172,6 @@ DEVICE_KIND = NamedInts(
receiver=0x07, receiver=0x07,
) )
FIRMWARE_KIND = NamedInts(Firmware=0x00, Bootloader=0x01, Hardware=0x02, Other=0x03)
ONBOARD_MODES = NamedInts(MODE_NO_CHANGE=0x00, MODE_ONBOARD=0x01, MODE_HOST=0x02) ONBOARD_MODES = NamedInts(MODE_NO_CHANGE=0x00, MODE_ONBOARD=0x01, MODE_HOST=0x02)

View File

@ -15,9 +15,12 @@
## with this program; if not, write to the Free Software Foundation, Inc., ## with this program; if not, write to the Free Software Foundation, Inc.,
## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. ## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from __future__ import annotations
import errno import errno
import logging import logging
import time import time
import typing
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable from typing import Callable
@ -35,6 +38,9 @@ from .common import Notification
from .device import Device from .device import Device
from .hidpp10_constants import Registers from .hidpp10_constants import Registers
if typing.TYPE_CHECKING:
from logitech_receiver import common
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_hidpp10 = hidpp10.Hidpp10() _hidpp10 = hidpp10.Hidpp10()
@ -145,7 +151,7 @@ class Receiver:
self.status_callback(self, alert=alert, reason=reason) self.status_callback(self, alert=alert, reason=reason)
@property @property
def firmware(self): def firmware(self) -> tuple[common.FirmwareInfo]:
if self._firmware is None and self.handle: if self._firmware is None and self.handle:
self._firmware = _hidpp10.get_firmware(self) self._firmware = _hidpp10.get_firmware(self)
return self._firmware return self._firmware

View File

@ -19,6 +19,8 @@
import os import os
from enum import IntEnum
import yaml import yaml
from .common import NamedInts from .common import NamedInts
@ -595,8 +597,30 @@ MAPPING_FLAG = NamedInts(
persistently_diverted=0x04, persistently_diverted=0x04,
diverted=0x01, diverted=0x01,
) )
CID_GROUP_BIT = NamedInts(g8=0x80, g7=0x40, g6=0x20, g5=0x10, g4=0x08, g3=0x04, g2=0x02, g1=0x01)
CID_GROUP = NamedInts(g8=8, g7=7, g6=6, g5=5, g4=4, g3=3, g2=2, g1=1)
class CIDGroupBit(IntEnum):
g1 = 0x01
g2 = 0x02
g3 = 0x04
g4 = 0x08
g5 = 0x10
g6 = 0x20
g7 = 0x40
g8 = 0x80
class CidGroup(IntEnum):
g1 = 1
g2 = 2
g3 = 3
g4 = 4
g5 = 5
g6 = 6
g7 = 7
g8 = 8
DISABLE = NamedInts( DISABLE = NamedInts(
Caps_Lock=0x01, Caps_Lock=0x01,
Num_Lock=0x02, Num_Lock=0x02,

View File

@ -14,6 +14,7 @@
## with this program; if not, write to the Free Software Foundation, Inc., ## with this program; if not, write to the Free Software Foundation, Inc.,
## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. ## 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from logitech_receiver import common
from logitech_receiver import exceptions from logitech_receiver import exceptions
from logitech_receiver import hidpp10 from logitech_receiver import hidpp10
from logitech_receiver import hidpp10_constants from logitech_receiver import hidpp10_constants
@ -149,7 +150,7 @@ def _print_device(dev, num=None):
for feature, index in dev.features.enumerate(): for feature, index in dev.features.enumerate():
flags = dev.request(0x0000, feature.bytes(2)) flags = dev.request(0x0000, feature.bytes(2))
flags = 0 if flags is None else ord(flags[1:2]) flags = 0 if flags is None else ord(flags[1:2])
flags = hidpp20_constants.FEATURE_FLAG.flag_names(flags) flags = common.flag_names(hidpp20_constants.FeatureFlag, flags)
version = dev.features.get_feature_version(int(feature)) version = dev.features.get_feature_version(int(feature))
version = version if version else 0 version = version if version else 0
print(" %2d: %-22s {%04X} V%s %s " % (index, feature, feature, version, ", ".join(flags))) print(" %2d: %-22s {%04X} V%s %s " % (index, feature, feature, version, ", ".join(flags)))

View File

@ -1,3 +1,5 @@
from enum import IntFlag
import pytest import pytest
import yaml import yaml
@ -121,18 +123,41 @@ def test_named_ints_range():
assert 6 not in named_ints_range assert 6 not in named_ints_range
def test_named_ints_flag_names(): @pytest.mark.parametrize(
named_ints_flag_bits = common.NamedInts(one=1, two=2, three=4) "code, expected_flags",
[
(0, []),
(0b0010, ["two"]),
(0b0101, ["one", "three"]),
(0b1001, ["one", "unknown:000008"]),
],
)
def test_named_ints_flag_names(code, expected_flags):
named_ints_flag_bits = common.NamedInts(one=0b001, two=0b010, three=0b100)
flags0 = list(named_ints_flag_bits.flag_names(0)) flags = list(named_ints_flag_bits.flag_names(code))
flags2 = list(named_ints_flag_bits.flag_names(2))
flags5 = list(named_ints_flag_bits.flag_names(5))
flags9 = list(named_ints_flag_bits.flag_names(9))
assert flags0 == [] assert flags == expected_flags
assert flags2 == ["two"]
assert flags5 == ["one", "three"]
assert flags9 == ["one", "unknown:000008"] @pytest.mark.parametrize(
"code, expected_flags",
[
(0, []),
(0b0010, ["two"]),
(0b0101, ["one", "three"]),
(0b1001, ["one", "unknown:000008"]),
],
)
def test_flag_names(code, expected_flags):
class ExampleFlag(IntFlag):
one = 0x1
two = 0x2
three = 0x4
flags = common.flag_names(ExampleFlag, code)
assert list(flags) == expected_flags
def test_named_ints_setitem(): def test_named_ints_setitem():

View File

@ -9,7 +9,6 @@ import pytest
from logitech_receiver import common from logitech_receiver import common
from logitech_receiver import hidpp10 from logitech_receiver import hidpp10
from logitech_receiver import hidpp10_constants from logitech_receiver import hidpp10_constants
from logitech_receiver import hidpp20_constants
from logitech_receiver.hidpp10_constants import Registers from logitech_receiver.hidpp10_constants import Registers
_hidpp10 = hidpp10.Hidpp10() _hidpp10 = hidpp10.Hidpp10()
@ -190,18 +189,28 @@ def test_hidpp10_get_battery(device, expected_result, expected_register):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device, expected_length", "device, expected_firmwares",
[ [
(device_offline, 0), (device_offline, []),
(device_standard, 3), (
device_standard,
[
common.FirmwareKind.Firmware,
common.FirmwareKind.Bootloader,
common.FirmwareKind.Other,
],
),
], ],
) )
def test_hidpp10_get_firmware(device, expected_length): def test_hidpp10_get_firmware(device, expected_firmwares):
firmwares = _hidpp10.get_firmware(device) firmwares = _hidpp10.get_firmware(device)
assert len(firmwares) == expected_length if expected_length > 0 else firmwares is None if not expected_firmwares:
for firmware in firmwares if firmwares is not None else []: assert firmwares is None
assert firmware.kind in hidpp20_constants.FIRMWARE_KIND else:
firmware_types = [firmware.kind for firmware in firmwares]
assert firmware_types == expected_firmwares
assert len(firmwares) == len(expected_firmwares)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -192,7 +192,7 @@ def test_ReprogrammableKey_key(device, index, cid, tid, flags, default_task, fla
), ),
], ],
) )
def test_ReprogrammableKeyV4_key(device, index, cid, tid, flags, pos, group, gmask, default_task, flag_names, group_names): def test_reprogrammable_key_v4_key(device, index, cid, tid, flags, pos, group, gmask, default_task, flag_names, group_names):
key = hidpp20.ReprogrammableKeyV4(device, index, cid, tid, flags, pos, group, gmask) key = hidpp20.ReprogrammableKeyV4(device, index, cid, tid, flags, pos, group, gmask)
assert key._device == device assert key._device == device

View File

@ -18,6 +18,7 @@ import pytest
from logitech_receiver import common from logitech_receiver import common
from logitech_receiver import hidpp20 from logitech_receiver import hidpp20
from logitech_receiver import hidpp20_constants
from logitech_receiver.hidpp20_constants import SupportedFeature from logitech_receiver.hidpp20_constants import SupportedFeature
from . import fake_hidpp from . import fake_hidpp
@ -412,3 +413,28 @@ def test_decipher_adc_measurement():
assert battery.level == 90 assert battery.level == 90
assert battery.status == common.BatteryStatus.RECHARGING assert battery.status == common.BatteryStatus.RECHARGING
assert battery.voltage == 0x1000 assert battery.voltage == 0x1000
@pytest.mark.parametrize(
"code, expected_flags",
[
(0x01, ["unknown:000001"]),
(0x0F, ["unknown:00000F"]),
(0xF0, ["internal", "hidden", "obsolete", "unknown:000010"]),
(0x20, ["internal"]),
(0x33, ["internal", "unknown:000013"]),
(0x3F, ["internal", "unknown:00001F"]),
(0x40, ["hidden"]),
(0x50, ["hidden", "unknown:000010"]),
(0x5F, ["hidden", "unknown:00001F"]),
(0x7F, ["internal", "hidden", "unknown:00001F"]),
(0x80, ["obsolete"]),
(0xA0, ["internal", "obsolete"]),
(0xE0, ["internal", "hidden", "obsolete"]),
(0xFF, ["internal", "hidden", "obsolete", "unknown:00001F"]),
],
)
def test_feature_flag_names(code, expected_flags):
flags = common.flag_names(hidpp20_constants.FeatureFlag, code)
assert list(flags) == expected_flags