diff --git a/lib/logitech_receiver/receiver.py b/lib/logitech_receiver/receiver.py index f260807b..efd85ec3 100644 --- a/lib/logitech_receiver/receiver.py +++ b/lib/logitech_receiver/receiver.py @@ -20,13 +20,12 @@ import logging import time from dataclasses import dataclass +from typing import Any from typing import Callable from typing import Optional from typing import Protocol from typing import cast -import hidapi - from solaar.i18n import _ from solaar.i18n import ngettext @@ -88,8 +87,18 @@ class Receiver: number = 0xFF kind = None - def __init__(self, receiver_kind, product_info, handle, path, product_id, setting_callback=None): + def __init__( + self, + find_paired_node_wpid_func: Callable[[str, int], Any], + receiver_kind, + product_info, + handle, + path, + product_id, + setting_callback=None, + ): assert handle + self._find_paired_node_wpid_func = find_paired_node_wpid_func self.isDevice = False # some devices act as receiver so we need a property to distinguish them self.handle = handle self.path = path @@ -389,8 +398,10 @@ class Receiver: class BoltReceiver(Receiver): """Bolt receivers use a different pairing prototol and have different pairing registers""" - def __init__(self, receiver_kind, product_info, handle, path, product_id, setting_callback=None): - super().__init__(receiver_kind, product_info, handle, path, product_id, setting_callback) + def __init__( + self, find_paired_node_wpid_func, receiver_kind, product_info, handle, path, product_id, setting_callback=None + ): + super().__init__(find_paired_node_wpid_func, receiver_kind, product_info, handle, path, product_id, setting_callback) def initialize(self, product_info: dict): serial_reply = self.read_register(Registers.BOLT_UNIQUE_ID) @@ -437,25 +448,27 @@ class BoltReceiver(Receiver): class UnifyingReceiver(Receiver): - def __init__(self, receiver_kind, product_info, handle, path, product_id, setting_callback=None): - super().__init__(receiver_kind, product_info, handle, path, product_id, setting_callback) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) class NanoReceiver(Receiver): - def __init__(self, receiver_kind, product_info, handle, path, product_id, setting_callback=None): - super().__init__(receiver_kind, product_info, handle, path, product_id, setting_callback) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) class LightSpeedReceiver(Receiver): - def __init__(self, receiver_kind, product_info, handle, path, product_id, setting_callback=None): - super().__init__(receiver_kind, product_info, handle, path, product_id, setting_callback) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) class Ex100Receiver(Receiver): """A very old style receiver, somewhat different from newer receivers""" - def __init__(self, receiver_kind, product_info, handle, path, product_id, setting_callback=None): - super().__init__(receiver_kind, product_info, handle, path, product_id, setting_callback) + def __init__( + self, find_paired_node_wpid_func, receiver_kind, product_info, handle, path, product_id, setting_callback=None + ): + super().__init__(find_paired_node_wpid_func, receiver_kind, product_info, handle, path, product_id, setting_callback) def initialize(self, product_info: dict): self.serial = None @@ -471,7 +484,7 @@ class Ex100Receiver(Receiver): return online, encrypted, wpid, kind def device_pairing_information(self, number: int) -> dict: - wpid = hidapi.find_paired_node_wpid(self.path, number) # extract WPID from udev path + wpid = self._find_paired_node_wpid_func(self.path, number) # extract WPID from udev path if not wpid: logger.error("Unable to get wpid from udev for device %d of %s", number, self) raise exceptions.NoSuchDevice(number=number, receiver=self, error="Not present 27Mhz device") @@ -507,7 +520,9 @@ receiver_class_mapping = { class ReceiverFactory: @staticmethod - def create_receiver(device_info, setting_callback=None) -> Optional[Receiver]: + def create_receiver( + find_paired_node_wpid_func: Callable[[str, int], Any], device_info, setting_callback=None + ) -> Optional[Receiver]: """Opens a Logitech Receiver found attached to the machine, by Linux device path.""" try: @@ -522,7 +537,15 @@ class ReceiverFactory: product_info = {} kind = product_info.get("receiver_kind", "unknown") rclass = receiver_class_mapping.get(kind, Receiver) - return rclass(kind, product_info, handle, device_info.path, device_info.product_id, setting_callback) + return rclass( + find_paired_node_wpid_func, + kind, + product_info, + handle, + device_info.path, + device_info.product_id, + setting_callback, + ) except OSError as e: logger.exception("open %s", device_info) if e.errno == errno.EACCES: diff --git a/lib/solaar/listener.py b/lib/solaar/listener.py index 9dfd53fd..66df8ea6 100644 --- a/lib/solaar/listener.py +++ b/lib/solaar/listener.py @@ -256,7 +256,9 @@ def _start(device_info): assert _status_callback and _setting_callback isDevice = device_info.isDevice if not isDevice: - receiver_ = logitech_receiver.receiver.ReceiverFactory.create_receiver(device_info, _setting_callback) + receiver_ = logitech_receiver.receiver.ReceiverFactory.create_receiver( + hidapi.find_paired_node_wpid, device_info, _setting_callback + ) else: receiver_ = logitech_receiver.device.DeviceFactory.create_device( hidapi.find_paired_node, base, device_info, _setting_callback diff --git a/tests/logitech_receiver/test_receiver.py b/tests/logitech_receiver/test_receiver.py index 98881d47..484a04c4 100644 --- a/tests/logitech_receiver/test_receiver.py +++ b/tests/logitech_receiver/test_receiver.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from functools import partial from unittest import mock +import hidapi import pytest from logitech_receiver import common @@ -115,15 +116,16 @@ c534_info = {"kind": common.NamedInt(0, "unknown"), "polling": "", "power_switch def test_ReceiverFactory_create_receiver(device_info, responses, handle, serial, max_devices, mock_base): mock_base[0].side_effect = fake_hidpp.open_path mock_base[1].side_effect = partial(fake_hidpp.request, responses) + find_paired_node_wpid_func = hidapi.find_paired_node_wpid if handle is False: with pytest.raises(Exception): # noqa: B017 - r = receiver.ReceiverFactory.create_receiver(device_info, lambda x: x) + receiver.ReceiverFactory.create_receiver(find_paired_node_wpid_func, device_info, lambda x: x) elif handle is None: - r = receiver.ReceiverFactory.create_receiver(device_info, lambda x: x) + r = receiver.ReceiverFactory.create_receiver(find_paired_node_wpid_func, device_info, lambda x: x) assert r is None else: - r = receiver.ReceiverFactory.create_receiver(device_info, lambda x: x) + r = receiver.ReceiverFactory.create_receiver(find_paired_node_wpid_func, device_info, lambda x: x) assert r.handle == handle assert r.serial == serial assert r.max_devices == max_devices @@ -142,7 +144,7 @@ def test_ReceiverFactory_props(device_info, responses, firmware, codename, remai mock_base[0].side_effect = fake_hidpp.open_path mock_base[1].side_effect = partial(fake_hidpp.request, responses) - r = receiver.ReceiverFactory.create_receiver(device_info, lambda x: x) + r = receiver.ReceiverFactory.create_receiver(mock.Mock(), device_info, lambda x: x) assert len(r.firmware) == firmware if firmware is not None else firmware is None assert r.device_codename(2) == codename @@ -164,7 +166,7 @@ def test_ReceiverFactory_string(device_info, responses, status_str, strng, mock_ mock_base[0].side_effect = fake_hidpp.open_path mock_base[1].side_effect = partial(fake_hidpp.request, responses) - r = receiver.ReceiverFactory.create_receiver(device_info, lambda x: x) + r = receiver.ReceiverFactory.create_receiver(mock.Mock(), device_info, lambda x: x) assert r.status_string() == status_str assert str(r) == strng @@ -182,7 +184,7 @@ def test_ReceiverFactory_nodevice(device_info, responses, mock_base): mock_base[0].side_effect = fake_hidpp.open_path mock_base[1].side_effect = partial(fake_hidpp.request, responses) - r = receiver.ReceiverFactory.create_receiver(device_info, lambda x: x) + r = receiver.ReceiverFactory.create_receiver(mock.Mock(), device_info, lambda x: x) with pytest.raises(exceptions.NoSuchDevice): r.device_pairing_information(1) diff --git a/tests/solaar/ui/test_pair_window.py b/tests/solaar/ui/test_pair_window.py index 1593f5c5..ec6c3309 100644 --- a/tests/solaar/ui/test_pair_window.py +++ b/tests/solaar/ui/test_pair_window.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from dataclasses import field from typing import Any +from typing import Callable from typing import List from typing import Optional +from unittest import mock import gi import pytest @@ -24,6 +26,7 @@ class Device: @dataclass class Receiver: + find_paired_node_wpid_func: Callable[[str, int], Any] name: str receiver_kind: str _set_lock: bool = True @@ -84,12 +87,12 @@ class Assistant: @pytest.mark.parametrize( "receiver, lock_open, discovering, page_type", [ - (Receiver("unifying", "unifying", True), True, False, Gtk.AssistantPageType.PROGRESS), - (Receiver("unifying", "unifying", False), False, False, Gtk.AssistantPageType.SUMMARY), - (Receiver("nano", "nano", True, _remaining_pairings=5), True, False, Gtk.AssistantPageType.PROGRESS), - (Receiver("nano", "nano", False), False, False, Gtk.AssistantPageType.SUMMARY), - (Receiver("bolt", "bolt", True), False, True, Gtk.AssistantPageType.PROGRESS), - (Receiver("bolt", "bolt", False), False, False, Gtk.AssistantPageType.SUMMARY), + (Receiver(mock.Mock(), "unifying", "unifying", True), True, False, Gtk.AssistantPageType.PROGRESS), + (Receiver(mock.Mock(), "unifying", "unifying", False), False, False, Gtk.AssistantPageType.SUMMARY), + (Receiver(mock.Mock(), "nano", "nano", True, _remaining_pairings=5), True, False, Gtk.AssistantPageType.PROGRESS), + (Receiver(mock.Mock(), "nano", "nano", False), False, False, Gtk.AssistantPageType.SUMMARY), + (Receiver(mock.Mock(), "bolt", "bolt", True), False, True, Gtk.AssistantPageType.PROGRESS), + (Receiver(mock.Mock(), "bolt", "bolt", False), False, False, Gtk.AssistantPageType.SUMMARY), ], ) def test_create(receiver, lock_open, discovering, page_type): @@ -105,10 +108,10 @@ def test_create(receiver, lock_open, discovering, page_type): @pytest.mark.parametrize( "receiver, expected_result, expected_error", [ - (Receiver("unifying", "unifying", True), True, False), - (Receiver("unifying", "unifying", False), False, True), - (Receiver("bolt", "bolt", True), True, False), - (Receiver("bolt", "bolt", False), False, True), + (Receiver(mock.Mock(), "unifying", "unifying", True), True, False), + (Receiver(mock.Mock(), "unifying", "unifying", False), False, True), + (Receiver(mock.Mock(), "bolt", "bolt", True), True, False), + (Receiver(mock.Mock(), "bolt", "bolt", False), False, True), ], ) def test_prepare(receiver, expected_result, expected_error): @@ -120,7 +123,7 @@ def test_prepare(receiver, expected_result, expected_error): @pytest.mark.parametrize("assistant, expected_result", [(Assistant(True), True), (Assistant(False), False)]) def test_check_lock_state_drawable(assistant, expected_result): - r = Receiver("succeed", "unifying", True, receiver.Pairing(lock_open=True)) + r = Receiver(mock.Mock(), "succeed", "unifying", True, receiver.Pairing(lock_open=True)) result = pair_window.check_lock_state(assistant, r, 2) @@ -131,42 +134,68 @@ def test_check_lock_state_drawable(assistant, expected_result): @pytest.mark.parametrize( "receiver, count, expected_result", [ - (Receiver("fail", "unifying", False, receiver.Pairing(lock_open=False)), 2, False), - (Receiver("succeed", "unifying", True, receiver.Pairing(lock_open=True)), 1, True), - (Receiver("error", "unifying", True, receiver.Pairing(error="error")), 0, False), - (Receiver("new device", "unifying", True, receiver.Pairing(new_device=Device())), 2, False), - (Receiver("closed", "unifying", True, receiver.Pairing()), 2, False), - (Receiver("closed", "unifying", True, receiver.Pairing()), 1, False), - (Receiver("closed", "unifying", True, receiver.Pairing()), 0, False), - (Receiver("fail bolt", "bolt", False), 1, False), - (Receiver("succeed bolt", "bolt", True, receiver.Pairing(lock_open=True)), 0, True), - (Receiver("error bolt", "bolt", True, receiver.Pairing(error="error")), 2, False), - (Receiver("new device", "bolt", True, receiver.Pairing(lock_open=True, new_device=Device())), 1, False), - (Receiver("discovering", "bolt", True, receiver.Pairing(lock_open=True)), 1, True), - (Receiver("closed", "bolt", True, receiver.Pairing()), 2, False), - (Receiver("closed", "bolt", True, receiver.Pairing()), 1, False), - (Receiver("closed", "bolt", True, receiver.Pairing()), 0, False), + (Receiver(mock.Mock(), "fail", "unifying", False, receiver.Pairing(lock_open=False)), 2, False), + (Receiver(mock.Mock(), "succeed", "unifying", True, receiver.Pairing(lock_open=True)), 1, True), + (Receiver(mock.Mock(), "error", "unifying", True, receiver.Pairing(error="error")), 0, False), + (Receiver(mock.Mock(), "new device", "unifying", True, receiver.Pairing(new_device=Device())), 2, False), + (Receiver(mock.Mock(), "closed", "unifying", True, receiver.Pairing()), 2, False), + (Receiver(mock.Mock(), "closed", "unifying", True, receiver.Pairing()), 1, False), + (Receiver(mock.Mock(), "closed", "unifying", True, receiver.Pairing()), 0, False), + (Receiver(mock.Mock(), "fail bolt", "bolt", False), 1, False), + (Receiver(mock.Mock(), "succeed bolt", "bolt", True, receiver.Pairing(lock_open=True)), 0, True), + (Receiver(mock.Mock(), "error bolt", "bolt", True, receiver.Pairing(error="error")), 2, False), + (Receiver(mock.Mock(), "new device", "bolt", True, receiver.Pairing(lock_open=True, new_device=Device())), 1, False), + (Receiver(mock.Mock(), "discovering", "bolt", True, receiver.Pairing(lock_open=True)), 1, True), + (Receiver(mock.Mock(), "closed", "bolt", True, receiver.Pairing()), 2, False), + (Receiver(mock.Mock(), "closed", "bolt", True, receiver.Pairing()), 1, False), + (Receiver(mock.Mock(), "closed", "bolt", True, receiver.Pairing()), 0, False), ( - Receiver("pass1", "bolt", True, receiver.Pairing(lock_open=True, device_passkey=50, device_authentication=0x01)), + Receiver( + mock.Mock(), + "pass1", + "bolt", + True, + receiver.Pairing(lock_open=True, device_passkey=50, device_authentication=0x01), + ), 0, True, ), ( - Receiver("pass2", "bolt", True, receiver.Pairing(lock_open=True, device_passkey=50, device_authentication=0x02)), + Receiver( + mock.Mock(), + "pass2", + "bolt", + True, + receiver.Pairing(lock_open=True, device_passkey=50, device_authentication=0x02), + ), 0, True, ), ( - Receiver("adt", "bolt", True, receiver.Pairing(discovering=True, device_address=2, device_name=5), pairable=True), + Receiver( + mock.Mock(), + "adt", + "bolt", + True, + receiver.Pairing(discovering=True, device_address=2, device_name=5), + pairable=True, + ), 2, True, ), ( - Receiver("adf", "bolt", True, receiver.Pairing(discovering=True, device_address=2, device_name=5), pairable=False), + Receiver( + mock.Mock(), + "adf", + "bolt", + True, + receiver.Pairing(discovering=True, device_address=2, device_name=5), + pairable=False, + ), 2, False, ), - (Receiver("add fail", "bolt", False, receiver.Pairing(device_address=2, device_passkey=5)), 2, False), + (Receiver(mock.Mock(), "add fail", "bolt", False, receiver.Pairing(device_address=2, device_passkey=5)), 2, False), ], ) def test_check_lock_state(receiver, count, expected_result): @@ -180,11 +209,23 @@ def test_check_lock_state(receiver, count, expected_result): @pytest.mark.parametrize( "receiver, pair_device, set_lock, discover, error", [ - (Receiver("unifying", "unifying", pairing=receiver.Pairing(lock_open=False, error="error")), 0, 0, 0, None), - (Receiver("unifying", "unifying", pairing=receiver.Pairing(lock_open=True, error="error")), 0, 1, 0, "error"), - (Receiver("bolt", "bolt", pairing=receiver.Pairing(lock_open=False, error="error")), 0, 0, 0, None), - (Receiver("bolt", "bolt", pairing=receiver.Pairing(lock_open=True, error="error")), 1, 0, 0, "error"), - (Receiver("bolt", "bolt", pairing=receiver.Pairing(discovering=True, error="error")), 0, 0, 1, "error"), + ( + Receiver(mock.Mock(), "unifying", "unifying", pairing=receiver.Pairing(lock_open=False, error="error")), + 0, + 0, + 0, + None, + ), + ( + Receiver(mock.Mock(), "unifying", "unifying", pairing=receiver.Pairing(lock_open=True, error="error")), + 0, + 1, + 0, + "error", + ), + (Receiver(mock.Mock(), "bolt", "bolt", pairing=receiver.Pairing(lock_open=False, error="error")), 0, 0, 0, None), + (Receiver(mock.Mock(), "bolt", "bolt", pairing=receiver.Pairing(lock_open=True, error="error")), 1, 0, 0, "error"), + (Receiver(mock.Mock(), "bolt", "bolt", pairing=receiver.Pairing(discovering=True, error="error")), 0, 0, 1, "error"), ], ) def test_finish(receiver, pair_device, set_lock, discover, error, mocker): @@ -206,6 +247,6 @@ def test_finish(receiver, pair_device, set_lock, discover, error, mocker): def test_create_failure_page(error, mocker): spy_create = mocker.spy(pair_window, "_create_page") - pair_window._pairing_failed(Assistant(True), Receiver("nano", "nano"), error) + pair_window._pairing_failed(Assistant(True), Receiver(mock.Mock(), "nano", "nano"), error) assert spy_create.call_count == 1