From 1f85ec01e71fd36f11134bd7c30aa7879d7b7909 Mon Sep 17 00:00:00 2001 From: MattHag <16444067+MattHag@users.noreply.github.com> Date: Wed, 2 Oct 2024 22:07:28 +0200 Subject: [PATCH] base: Add more unit tests Make internal functions private. --- lib/logitech_receiver/base.py | 36 +++++++++--------- tests/logitech_receiver/test_base.py | 56 +++++++++++++++++++++++----- 2 files changed, 66 insertions(+), 26 deletions(-) diff --git a/lib/logitech_receiver/base.py b/lib/logitech_receiver/base.py index b3319c5b..bb000013 100644 --- a/lib/logitech_receiver/base.py +++ b/lib/logitech_receiver/base.py @@ -137,7 +137,7 @@ for _ignore, d in descriptors.DEVICES.items(): KNOWN_DEVICE_IDS.append(_bluetooth_device(d.btid)) -def other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None: +def _other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None: """Check whether product is a Logitech USB-connected or Bluetooth device based on bus, vendor, and product IDs This allows Solaar to support receiverless HID++ 2.0 devices that it knows nothing about""" if vendor_id != LOGITECH_VENDOR_ID: @@ -164,10 +164,10 @@ def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int) ) -def filter_receivers( +def _filter_receivers( bus_id: int, vendor_id: int, product_id: int, _hidpp_short: bool = False, _hidpp_long: bool = False ) -> dict[str, Any]: - """Check that this product is a Logitech receiver. + """Check that this product is a Logitech receiver and return it. Filters based on bus_id, vendor_id and product_id. @@ -182,18 +182,19 @@ def filter_receivers( if vendor_id == LOGITECH_VENDOR_ID and 0xC500 <= product_id <= 0xC5FF: # unknown receiver return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": False} + return None def receivers(): """Enumerate all the receivers attached to the machine.""" - yield from hidapi.enumerate(filter_receivers) + yield from hidapi.enumerate(_filter_receivers) -def filter_products_of_interest( +def _filter_products_of_interest( bus_id: int, vendor_id: int, product_id: int, hidpp_short: bool = False, hidpp_long: bool = False ) -> dict[str, Any] | None: """Check that this product is of interest and if so return the device record for further checking""" - record = filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long) + record = _filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long) if record: # known or unknown receiver return record @@ -203,12 +204,13 @@ def filter_products_of_interest( if hidpp_short or hidpp_long: # unknown devices that use HID++ return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True} elif hidpp_short is None and hidpp_long is None: # unknown devices in correct range of IDs - return other_device_check(bus_id, vendor_id, product_id) + return _other_device_check(bus_id, vendor_id, product_id) + return None def receivers_and_devices(): """Enumerate all the receivers and devices directly attached to the machine.""" - yield from hidapi.enumerate(filter_products_of_interest) + yield from hidapi.enumerate(_filter_products_of_interest) def notify_on_receivers_glib(glib: GLib, callback: Callable): @@ -219,7 +221,7 @@ def notify_on_receivers_glib(glib: GLib, callback: Callable): glib GLib instance. """ - return hidapi.monitor_glib(glib, callback, filter_products_of_interest) + return hidapi.monitor_glib(glib, callback, _filter_products_of_interest) def open_path(path): @@ -321,7 +323,7 @@ def read(handle, timeout=DEFAULT_TIMEOUT): return reply -def is_relevant_message(data: bytes) -> bool: +def _is_relevant_message(data: bytes) -> bool: """Checks if given id is a HID++ or DJ message. Applies sanity checks on message report ID and message size. @@ -363,7 +365,7 @@ def _read(handle, timeout): close(handle) raise exceptions.NoReceiver(reason=reason) from reason - if data and is_relevant_message(data): # ignore messages that fail check + if data and _is_relevant_message(data): # ignore messages that fail check report_id = ord(data[:1]) devnumber = ord(data[1:2]) @@ -398,7 +400,7 @@ def _skip_incoming(handle, ihandle, notifications_hook): raise exceptions.NoReceiver(reason=reason) from reason if data: - if is_relevant_message(data): # only process messages that pass check + if _is_relevant_message(data): # only process messages that pass check # report_id = ord(data[:1]) if notifications_hook: n = make_notification(ord(data[:1]), ord(data[1:2]), data[2:]) @@ -409,23 +411,23 @@ def _skip_incoming(handle, ihandle, notifications_hook): return -def make_notification(report_id, devnumber, data) -> HIDPPNotification | None: +def make_notification(report_id: int, devnumber: int, data: bytes) -> HIDPPNotification | None: """Guess if this is a notification (and not just a request reply), and return a Notification if it is.""" sub_id = ord(data[:1]) if sub_id & 0x80 == 0x80: # this is either a HID++1.0 register r/w, or an error reply - return + return None # DJ input records are not notifications if report_id == DJ_MESSAGE_ID and (sub_id < 0x10): - return + return None address = ord(data[1:2]) if sub_id == 0x00 and (address & 0x0F == 0x00): # this is a no-op notification - don't do anything with it - return + return None if ( # standard HID++ 1.0 notification, SubId may be 0x40 - 0x7F @@ -441,6 +443,7 @@ def make_notification(report_id, devnumber, data) -> HIDPPNotification | None: (address & 0x0F == 0x00) ): # noqa: E129 return HIDPPNotification(report_id, devnumber, sub_id, address, data[2:]) + return None request_lock = threading.Lock() # serialize all requests @@ -554,7 +557,6 @@ def request( while delta < timeout: reply = _read(handle, timeout) - if reply: report_id, reply_devnumber, reply_data = reply if reply_devnumber == devnumber or reply_devnumber == devnumber ^ 0xFF: # BT device returning 0x00 diff --git a/tests/logitech_receiver/test_base.py b/tests/logitech_receiver/test_base.py index c84afb44..3908c1e7 100644 --- a/tests/logitech_receiver/test_base.py +++ b/tests/logitech_receiver/test_base.py @@ -29,7 +29,7 @@ def test_filter_receivers_known(): vendor_id = 0x046D product_id = 0xC548 - receiver_info = base.filter_receivers(bus_id, vendor_id, product_id) + receiver_info = base._filter_receivers(bus_id, vendor_id, product_id) assert receiver_info["name"] == "Bolt Receiver" assert receiver_info["receiver_kind"] == "bolt" @@ -40,22 +40,28 @@ def test_filter_receivers_unknown(): vendor_id = 0x046D product_id = 0xC500 - receiver_info = base.filter_receivers(bus_id, vendor_id, product_id) + receiver_info = base._filter_receivers(bus_id, vendor_id, product_id) assert receiver_info["bus_id"] == bus_id assert receiver_info["product_id"] == product_id @pytest.mark.parametrize( - "hidpp_short, hidpp_long", - [(True, False), (False, True), (False, False)], + "product_id, hidpp_short, hidpp_long", + [ + (0xC548, True, False), + (0xC07E, True, False), + (0xC07E, False, True), + (0xA07E, False, True), + (0xA07E, None, None), + (0xA07C, False, False), + ], ) -def test_filter_products_of_interest(hidpp_short, hidpp_long): +def test_filter_products_of_interest(product_id, hidpp_short, hidpp_long): bus_id = 3 vendor_id = 0x046D - product_id = 0xC07E - receiver_info = base.filter_products_of_interest( + receiver_info = base._filter_products_of_interest( bus_id, vendor_id, product_id, @@ -63,8 +69,40 @@ def test_filter_products_of_interest(hidpp_short, hidpp_long): hidpp_long=hidpp_long, ) - assert receiver_info["bus_id"] == bus_id - assert receiver_info["product_id"] == product_id + if not hidpp_short and not hidpp_long: + assert receiver_info is None + else: + assert isinstance(receiver_info["vendor_id"], int) + assert receiver_info["product_id"] == product_id + + +@pytest.mark.parametrize( + "report_id, sub_id, address, valid_notification", + [ + (0x1, 0x72, 0x57, True), + (0x1, 0x40, 0x63, True), + (0x1, 0x40, 0x71, True), + (0x1, 0x80, 0x71, False), + (0x1, 0x00, 0x70, False), + (0x20, 0x09, 0x71, False), + (0x1, 0x37, 0x71, False), + ], +) +def test_make_notification(report_id, sub_id, address, valid_notification): + devnumber = 123 + data = bytes([sub_id, address, 0x02, 0x03, 0x04]) + + result = base.make_notification(report_id, devnumber, data) + + if valid_notification: + assert isinstance(result, base.HIDPPNotification) + assert result.report_id == report_id + assert result.devnumber == devnumber + assert result.sub_id == sub_id + assert result.address == address + assert result.data == bytes([0x02, 0x03, 0x04]) + else: + assert result is None def test_get_next_sw_id():