base: Add more unit tests

Make internal functions private.
This commit is contained in:
MattHag 2024-10-02 22:07:28 +02:00 committed by Peter F. Patel-Schneider
parent 58ddb0d6cd
commit 1f85ec01e7
2 changed files with 66 additions and 26 deletions

View File

@ -137,7 +137,7 @@ for _ignore, d in descriptors.DEVICES.items():
KNOWN_DEVICE_IDS.append(_bluetooth_device(d.btid))
def other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None:
def _other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None:
"""Check whether product is a Logitech USB-connected or Bluetooth device based on bus, vendor, and product IDs
This allows Solaar to support receiverless HID++ 2.0 devices that it knows nothing about"""
if vendor_id != LOGITECH_VENDOR_ID:
@ -164,10 +164,10 @@ def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int)
)
def filter_receivers(
def _filter_receivers(
bus_id: int, vendor_id: int, product_id: int, _hidpp_short: bool = False, _hidpp_long: bool = False
) -> dict[str, Any]:
"""Check that this product is a Logitech receiver.
"""Check that this product is a Logitech receiver and return it.
Filters based on bus_id, vendor_id and product_id.
@ -182,18 +182,19 @@ def filter_receivers(
if vendor_id == LOGITECH_VENDOR_ID and 0xC500 <= product_id <= 0xC5FF: # unknown receiver
return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": False}
return None
def receivers():
"""Enumerate all the receivers attached to the machine."""
yield from hidapi.enumerate(filter_receivers)
yield from hidapi.enumerate(_filter_receivers)
def filter_products_of_interest(
def _filter_products_of_interest(
bus_id: int, vendor_id: int, product_id: int, hidpp_short: bool = False, hidpp_long: bool = False
) -> dict[str, Any] | None:
"""Check that this product is of interest and if so return the device record for further checking"""
record = filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long)
record = _filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long)
if record: # known or unknown receiver
return record
@ -203,12 +204,13 @@ def filter_products_of_interest(
if hidpp_short or hidpp_long: # unknown devices that use HID++
return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True}
elif hidpp_short is None and hidpp_long is None: # unknown devices in correct range of IDs
return other_device_check(bus_id, vendor_id, product_id)
return _other_device_check(bus_id, vendor_id, product_id)
return None
def receivers_and_devices():
"""Enumerate all the receivers and devices directly attached to the machine."""
yield from hidapi.enumerate(filter_products_of_interest)
yield from hidapi.enumerate(_filter_products_of_interest)
def notify_on_receivers_glib(glib: GLib, callback: Callable):
@ -219,7 +221,7 @@ def notify_on_receivers_glib(glib: GLib, callback: Callable):
glib
GLib instance.
"""
return hidapi.monitor_glib(glib, callback, filter_products_of_interest)
return hidapi.monitor_glib(glib, callback, _filter_products_of_interest)
def open_path(path):
@ -321,7 +323,7 @@ def read(handle, timeout=DEFAULT_TIMEOUT):
return reply
def is_relevant_message(data: bytes) -> bool:
def _is_relevant_message(data: bytes) -> bool:
"""Checks if given id is a HID++ or DJ message.
Applies sanity checks on message report ID and message size.
@ -363,7 +365,7 @@ def _read(handle, timeout):
close(handle)
raise exceptions.NoReceiver(reason=reason) from reason
if data and is_relevant_message(data): # ignore messages that fail check
if data and _is_relevant_message(data): # ignore messages that fail check
report_id = ord(data[:1])
devnumber = ord(data[1:2])
@ -398,7 +400,7 @@ def _skip_incoming(handle, ihandle, notifications_hook):
raise exceptions.NoReceiver(reason=reason) from reason
if data:
if is_relevant_message(data): # only process messages that pass check
if _is_relevant_message(data): # only process messages that pass check
# report_id = ord(data[:1])
if notifications_hook:
n = make_notification(ord(data[:1]), ord(data[1:2]), data[2:])
@ -409,23 +411,23 @@ def _skip_incoming(handle, ihandle, notifications_hook):
return
def make_notification(report_id, devnumber, data) -> HIDPPNotification | None:
def make_notification(report_id: int, devnumber: int, data: bytes) -> HIDPPNotification | None:
"""Guess if this is a notification (and not just a request reply), and
return a Notification if it is."""
sub_id = ord(data[:1])
if sub_id & 0x80 == 0x80:
# this is either a HID++1.0 register r/w, or an error reply
return
return None
# DJ input records are not notifications
if report_id == DJ_MESSAGE_ID and (sub_id < 0x10):
return
return None
address = ord(data[1:2])
if sub_id == 0x00 and (address & 0x0F == 0x00):
# this is a no-op notification - don't do anything with it
return
return None
if (
# standard HID++ 1.0 notification, SubId may be 0x40 - 0x7F
@ -441,6 +443,7 @@ def make_notification(report_id, devnumber, data) -> HIDPPNotification | None:
(address & 0x0F == 0x00)
): # noqa: E129
return HIDPPNotification(report_id, devnumber, sub_id, address, data[2:])
return None
request_lock = threading.Lock() # serialize all requests
@ -554,7 +557,6 @@ def request(
while delta < timeout:
reply = _read(handle, timeout)
if reply:
report_id, reply_devnumber, reply_data = reply
if reply_devnumber == devnumber or reply_devnumber == devnumber ^ 0xFF: # BT device returning 0x00

View File

@ -29,7 +29,7 @@ def test_filter_receivers_known():
vendor_id = 0x046D
product_id = 0xC548
receiver_info = base.filter_receivers(bus_id, vendor_id, product_id)
receiver_info = base._filter_receivers(bus_id, vendor_id, product_id)
assert receiver_info["name"] == "Bolt Receiver"
assert receiver_info["receiver_kind"] == "bolt"
@ -40,22 +40,28 @@ def test_filter_receivers_unknown():
vendor_id = 0x046D
product_id = 0xC500
receiver_info = base.filter_receivers(bus_id, vendor_id, product_id)
receiver_info = base._filter_receivers(bus_id, vendor_id, product_id)
assert receiver_info["bus_id"] == bus_id
assert receiver_info["product_id"] == product_id
@pytest.mark.parametrize(
"hidpp_short, hidpp_long",
[(True, False), (False, True), (False, False)],
"product_id, hidpp_short, hidpp_long",
[
(0xC548, True, False),
(0xC07E, True, False),
(0xC07E, False, True),
(0xA07E, False, True),
(0xA07E, None, None),
(0xA07C, False, False),
],
)
def test_filter_products_of_interest(hidpp_short, hidpp_long):
def test_filter_products_of_interest(product_id, hidpp_short, hidpp_long):
bus_id = 3
vendor_id = 0x046D
product_id = 0xC07E
receiver_info = base.filter_products_of_interest(
receiver_info = base._filter_products_of_interest(
bus_id,
vendor_id,
product_id,
@ -63,8 +69,40 @@ def test_filter_products_of_interest(hidpp_short, hidpp_long):
hidpp_long=hidpp_long,
)
assert receiver_info["bus_id"] == bus_id
assert receiver_info["product_id"] == product_id
if not hidpp_short and not hidpp_long:
assert receiver_info is None
else:
assert isinstance(receiver_info["vendor_id"], int)
assert receiver_info["product_id"] == product_id
@pytest.mark.parametrize(
"report_id, sub_id, address, valid_notification",
[
(0x1, 0x72, 0x57, True),
(0x1, 0x40, 0x63, True),
(0x1, 0x40, 0x71, True),
(0x1, 0x80, 0x71, False),
(0x1, 0x00, 0x70, False),
(0x20, 0x09, 0x71, False),
(0x1, 0x37, 0x71, False),
],
)
def test_make_notification(report_id, sub_id, address, valid_notification):
devnumber = 123
data = bytes([sub_id, address, 0x02, 0x03, 0x04])
result = base.make_notification(report_id, devnumber, data)
if valid_notification:
assert isinstance(result, base.HIDPPNotification)
assert result.report_id == report_id
assert result.devnumber == devnumber
assert result.sub_id == sub_id
assert result.address == address
assert result.data == bytes([0x02, 0x03, 0x04])
else:
assert result is None
def test_get_next_sw_id():