From 614a5dc633013e6946dc392a65050545ceb622fa Mon Sep 17 00:00:00 2001 From: MattHag <16444067+MattHag@users.noreply.github.com> Date: Sat, 28 Sep 2024 14:10:14 +0200 Subject: [PATCH] Add type hints and clean up --- lib/hidapi/hidapi_impl.py | 10 ++--- lib/hidapi/udev_impl.py | 2 +- lib/logitech_receiver/base.py | 77 +++++++++++++++++++---------------- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/lib/hidapi/hidapi_impl.py b/lib/hidapi/hidapi_impl.py index 3512ea47..13d834a7 100644 --- a/lib/hidapi/hidapi_impl.py +++ b/lib/hidapi/hidapi_impl.py @@ -263,10 +263,10 @@ def _match(action, device, filterfn): if not device["hidpp_short"] and not device["hidpp_long"]: return None - filter = filterfn(bus_id, vid, pid, device["hidpp_short"], device["hidpp_long"]) - if not filter: + filter_func = filterfn(bus_id, vid, pid, device["hidpp_short"], device["hidpp_long"]) + if not filter_func: return - isDevice = filter.get("isDevice") + isDevice = filter_func.get("isDevice") if action == "add": d_info = DeviceInfo( @@ -305,12 +305,12 @@ def _match(action, device, filterfn): return d_info -def find_paired_node(receiver_path, index, timeout): +def find_paired_node(receiver_path: str, index: int, timeout: int): """Find the node of a device paired with a receiver""" return None -def find_paired_node_wpid(receiver_path, index): +def find_paired_node_wpid(receiver_path: str, index: int): """Find the node of a device paired with a receiver, get wpid from udev""" return None diff --git a/lib/hidapi/udev_impl.py b/lib/hidapi/udev_impl.py index 39a43edf..58dc6644 100644 --- a/lib/hidapi/udev_impl.py +++ b/lib/hidapi/udev_impl.py @@ -176,7 +176,7 @@ def _match(action, device, filter_func: typing.Callable[[int, int, int, bool, bo return d_info -def find_paired_node(receiver_path, index, timeout): +def find_paired_node(receiver_path: str, index: int, timeout: int): """Find the node of a device paired with a receiver""" context = pyudev.Context() receiver_phys = pyudev.Devices.from_device_file(context, receiver_path).find_parent("hid").get("HID_PHYS") diff --git a/lib/logitech_receiver/base.py b/lib/logitech_receiver/base.py index b5faa4b4..da3ac43b 100644 --- a/lib/logitech_receiver/base.py +++ b/lib/logitech_receiver/base.py @@ -37,7 +37,6 @@ from . import common from . import descriptors from . import exceptions from . import hidpp10_constants -from . import hidpp20 from . import hidpp20_constants from .common import LOGITECH_VENDOR_ID from .common import BusID @@ -53,7 +52,25 @@ else: logger = logging.getLogger(__name__) -_hidpp20 = hidpp20.Hidpp20() + +_SHORT_MESSAGE_SIZE = 7 +_LONG_MESSAGE_SIZE = 20 +_MEDIUM_MESSAGE_SIZE = 15 +_MAX_READ_SIZE = 32 + +HIDPP_SHORT_MESSAGE_ID = 0x10 +HIDPP_LONG_MESSAGE_ID = 0x11 +DJ_MESSAGE_ID = 0x20 + + +"""Default timeout on read (in seconds).""" +DEFAULT_TIMEOUT = 4 +# the receiver itself should reply very fast, within 500ms +_RECEIVER_REQUEST_TIMEOUT = 0.9 +# devices may reply a lot slower, as the call has to go wireless to them and come back +_DEVICE_REQUEST_TIMEOUT = DEFAULT_TIMEOUT +# when pinging, be extra patient (no longer) +_PING_TIMEOUT = DEFAULT_TIMEOUT @dataclasses.dataclass @@ -112,32 +129,6 @@ def product_information(usb_id: int) -> dict[str, Any]: return base_usb.get_receiver_info(usb_id) -_SHORT_MESSAGE_SIZE = 7 -_LONG_MESSAGE_SIZE = 20 -_MEDIUM_MESSAGE_SIZE = 15 -_MAX_READ_SIZE = 32 - -HIDPP_SHORT_MESSAGE_ID = 0x10 -HIDPP_LONG_MESSAGE_ID = 0x11 -DJ_MESSAGE_ID = 0x20 - -# mapping from report_id to message length -report_lengths = { - HIDPP_SHORT_MESSAGE_ID: _SHORT_MESSAGE_SIZE, - HIDPP_LONG_MESSAGE_ID: _LONG_MESSAGE_SIZE, - DJ_MESSAGE_ID: _MEDIUM_MESSAGE_SIZE, - 0x21: _MAX_READ_SIZE, -} -"""Default timeout on read (in seconds).""" -DEFAULT_TIMEOUT = 4 -# the receiver itself should reply very fast, within 500ms -_RECEIVER_REQUEST_TIMEOUT = 0.9 -# devices may reply a lot slower, as the call has to go wireless to them and come back -_DEVICE_REQUEST_TIMEOUT = DEFAULT_TIMEOUT -# when pinging, be extra patient (no longer) -_PING_TIMEOUT = DEFAULT_TIMEOUT - - def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int): return ( (record.get("bus_id") is None or record.get("bus_id") == bus_id) @@ -171,7 +162,9 @@ def receivers(): yield from hidapi.enumerate(filter_receivers) -def filter(bus_id: int, vendor_id: int, product_id: int, hidpp_short: bool = False, hidpp_long: bool = False): +def filter_products_of_interest( + bus_id: int, vendor_id: int, product_id: int, hidpp_short: bool = False, hidpp_long: bool = False +) -> dict[str, Any] | None: """Check that this product is of interest and if so return the device record for further checking""" record = filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long) if record: # known or unknown receiver @@ -188,7 +181,7 @@ def filter(bus_id: int, vendor_id: int, product_id: int, hidpp_short: bool = Fal def receivers_and_devices(): """Enumerate all the receivers and devices directly attached to the machine.""" - yield from hidapi.enumerate(filter) + yield from hidapi.enumerate(filter_products_of_interest) def notify_on_receivers_glib(glib: GLib, callback): @@ -199,7 +192,7 @@ def notify_on_receivers_glib(glib: GLib, callback): glib GLib instance. """ - return hidapi.monitor_glib(glib, callback, filter) + return hidapi.monitor_glib(glib, callback, filter_products_of_interest) def open_path(path): @@ -301,11 +294,23 @@ def read(handle, timeout=DEFAULT_TIMEOUT): return reply -# sanity checks on message report id and size -def check_message(data): +def is_relevant_message(data: bytes) -> bool: + """Checks if given id is a HID++ or DJ message. + + Applies sanity checks on message report ID and message size. + """ assert isinstance(data, bytes), (repr(data), type(data)) + + # mapping from report_id to message length + report_lengths = { + HIDPP_SHORT_MESSAGE_ID: _SHORT_MESSAGE_SIZE, + HIDPP_LONG_MESSAGE_ID: _LONG_MESSAGE_SIZE, + DJ_MESSAGE_ID: _MEDIUM_MESSAGE_SIZE, + 0x21: _MAX_READ_SIZE, + } + report_id = ord(data[:1]) - if report_id in report_lengths: # is this an HID++ or DJ message? + if report_id in report_lengths: if report_lengths.get(report_id) == len(data): return True else: @@ -331,7 +336,7 @@ def _read(handle, timeout): close(handle) raise exceptions.NoReceiver(reason=reason) from reason - if data and check_message(data): # ignore messages that fail check + if data and is_relevant_message(data): # ignore messages that fail check report_id = ord(data[:1]) devnumber = ord(data[1:2]) @@ -361,7 +366,7 @@ def _skip_incoming(handle, ihandle, notifications_hook): raise exceptions.NoReceiver(reason=reason) from reason if data: - if check_message(data): # only process messages that pass check + if is_relevant_message(data): # only process messages that pass check # report_id = ord(data[:1]) if notifications_hook: n = make_notification(ord(data[:1]), ord(data[1:2]), data[2:])