base: Refactor device filtering

Related #2273
This commit is contained in:
MattHag 2024-12-29 23:05:43 +01:00 committed by Peter F. Patel-Schneider
parent 1e6af7fa7d
commit 3186d880fc
2 changed files with 54 additions and 35 deletions

View File

@ -56,7 +56,7 @@ else:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HIDAPI(typing.Protocol): class HIDProtocol(typing.Protocol):
def find_paired_node_wpid(self, receiver_path: str, index: int): def find_paired_node_wpid(self, receiver_path: str, index: int):
... ...
@ -106,7 +106,7 @@ _DEVICE_REQUEST_TIMEOUT = DEFAULT_TIMEOUT
# when pinging, be extra patient (no longer) # when pinging, be extra patient (no longer)
_PING_TIMEOUT = DEFAULT_TIMEOUT _PING_TIMEOUT = DEFAULT_TIMEOUT
hidapi = typing.cast(HIDAPI, hidapi) hidapi = typing.cast(HIDProtocol, hidapi)
request_lock = threading.Lock() # serialize all requests request_lock = threading.Lock() # serialize all requests
handles_lock = {} handles_lock = {}
@ -156,7 +156,7 @@ def product_information(usb_id: int) -> dict[str, Any]:
def receivers(): def receivers():
"""Enumerate all the receivers attached to the machine.""" """Enumerate all the receivers attached to the machine."""
yield from hidapi.enumerate(_filter_receivers) yield from hidapi.enumerate(get_known_receiver_info)
def filter_products_of_interest( def filter_products_of_interest(
@ -164,34 +164,53 @@ def filter_products_of_interest(
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Check that this product is of interest and if so return the device record for further checking""" """Check that this product is of interest and if so return the device record for further checking"""
def _other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None: recv = get_known_receiver_info(bus_id, vendor_id, product_id, hidpp_short, hidpp_long)
"""Check whether product is a Logitech USB-connected or Bluetooth device based on bus, vendor, and product IDs if recv: # known or unknown receiver
This allows Solaar to support receiverless HID++ 2.0 devices that it knows nothing about""" return recv
if vendor_id != LOGITECH_VENDOR_ID:
return
device_info = None device = get_known_device_info(bus_id, vendor_id, product_id)
if bus_id == BusID.USB and (0xC07D <= product_id <= 0xC094 or 0xC32B <= product_id <= 0xC344): if device:
device_info = _usb_device(product_id, 2) return device
elif bus_id == BusID.BLUETOOTH and (0xB012 <= product_id <= 0xB0FF or 0xB317 <= product_id <= 0xB3FF):
device_info = _bluetooth_device(product_id)
return device_info
record = _filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long) if hidpp_short or hidpp_long:
if record: # known or unknown receiver return get_unknown_hid_device_info(bus_id, vendor_id, product_id)
return record
for record in KNOWN_DEVICE_IDS: if hidpp_short is None and hidpp_long is None:
if _match(record, bus_id, vendor_id, product_id): return get_unknown_logitech_device_info(bus_id, vendor_id, product_id)
return record
if hidpp_short or hidpp_long: # unknown devices that use HID++
return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True}
elif hidpp_short is None and hidpp_long is None: # unknown devices in correct range of IDs
return _other_device_check(bus_id, vendor_id, product_id)
return None return None
def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int): def get_known_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any]:
for recv in KNOWN_DEVICE_IDS:
if _match_device(recv, bus_id, vendor_id, product_id):
return recv
def get_unknown_hid_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any]:
return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True}
def get_unknown_logitech_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None:
"""Get info from unknown device in Logitech product range.
Check whether product is a Logitech USB-connected or Bluetooth
device based on bus, vendor, and product ID. This allows Solaar to
support receiverless HID++ 2.0 devices that it knows nothing about.
"""
if vendor_id != LOGITECH_VENDOR_ID:
return None
if bus_id == BusID.USB.value and (0xC07D <= product_id <= 0xC094 or 0xC32B <= product_id <= 0xC344):
device_info = _usb_device(product_id, 2)
return device_info
elif bus_id == BusID.BLUETOOTH.value and (0xB012 <= product_id <= 0xB0FF or 0xB317 <= product_id <= 0xB3FF):
device_info = _bluetooth_device(product_id)
return device_info
return None
def _match_device(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int):
return ( return (
(record.get("bus_id") is None or record.get("bus_id") == bus_id) (record.get("bus_id") is None or record.get("bus_id") == bus_id)
and (record.get("vendor_id") is None or record.get("vendor_id") == vendor_id) and (record.get("vendor_id") is None or record.get("vendor_id") == vendor_id)
@ -199,7 +218,7 @@ def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int)
) )
def _filter_receivers( def get_known_receiver_info(
bus_id: int, vendor_id: int, product_id: int, _hidpp_short: bool = False, _hidpp_long: bool = False bus_id: int, vendor_id: int, product_id: int, _hidpp_short: bool = False, _hidpp_long: bool = False
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Check that this product is a Logitech receiver and return it. """Check that this product is a Logitech receiver and return it.
@ -210,7 +229,7 @@ def _filter_receivers(
""" """
try: try:
record = base_usb.get_receiver_info(product_id) record = base_usb.get_receiver_info(product_id)
if _match(record, bus_id, vendor_id, product_id): if _match_device(record, bus_id, vendor_id, product_id):
return record return record
except ValueError: except ValueError:
pass pass
@ -507,7 +526,7 @@ def request(
ihandle = int(handle) ihandle = int(handle)
notifications_hook = getattr(handle, "notifications_hook", None) notifications_hook = getattr(handle, "notifications_hook", None)
try: try:
_skip_incoming(handle, ihandle, notifications_hook) _read_input_buffer(handle, ihandle, notifications_hook)
except exceptions.NoReceiver: except exceptions.NoReceiver:
logger.warning("device or receiver disconnected") logger.warning("device or receiver disconnected")
return None return None
@ -604,7 +623,7 @@ def ping(handle, devnumber, long_message: bool = False):
with acquire_timeout(handle_lock(handle), handle, 10.0): with acquire_timeout(handle_lock(handle), handle, 10.0):
notifications_hook = getattr(handle, "notifications_hook", None) notifications_hook = getattr(handle, "notifications_hook", None)
try: try:
_skip_incoming(handle, int(handle), notifications_hook) _read_input_buffer(handle, int(handle), notifications_hook)
except exceptions.NoReceiver: except exceptions.NoReceiver:
logger.warning("device or receiver disconnected") logger.warning("device or receiver disconnected")
return return
@ -651,8 +670,8 @@ def ping(handle, devnumber, long_message: bool = False):
logger.warning("(%s) timeout (%0.2f/%0.2f) on device %d ping", handle, delta, _PING_TIMEOUT, devnumber) logger.warning("(%s) timeout (%0.2f/%0.2f) on device %d ping", handle, delta, _PING_TIMEOUT, devnumber)
def _skip_incoming(handle, ihandle, notifications_hook): def _read_input_buffer(handle, ihandle, notifications_hook):
"""Read anything already in the input buffer. """Consume anything already in the input buffer.
Used by request() and ping() before their write. Used by request() and ping() before their write.
""" """

View File

@ -40,7 +40,7 @@ def test_filter_receivers_known():
bus_id = 2 bus_id = 2
product_id = 0xC548 product_id = 0xC548
receiver_info = base._filter_receivers(bus_id, LOGITECH_VENDOR_ID, product_id) receiver_info = base.get_known_receiver_info(bus_id, LOGITECH_VENDOR_ID, product_id)
assert receiver_info["name"] == "Bolt Receiver" assert receiver_info["name"] == "Bolt Receiver"
assert receiver_info["receiver_kind"] == "bolt" assert receiver_info["receiver_kind"] == "bolt"
@ -50,7 +50,7 @@ def test_filter_receivers_unknown():
bus_id = 1 bus_id = 1
product_id = 0xC500 product_id = 0xC500
receiver_info = base._filter_receivers(bus_id, LOGITECH_VENDOR_ID, product_id) receiver_info = base.get_known_receiver_info(bus_id, LOGITECH_VENDOR_ID, product_id)
assert receiver_info["bus_id"] == bus_id assert receiver_info["bus_id"] == bus_id
assert receiver_info["product_id"] == product_id assert receiver_info["product_id"] == product_id
@ -90,7 +90,7 @@ def test_filter_products_of_interest(product_id, bus, hidpp_short, hidpp_long, e
def test_match(): def test_match():
record = {"vendor_id": LOGITECH_VENDOR_ID} record = {"vendor_id": LOGITECH_VENDOR_ID}
res = base._match(record, 0, LOGITECH_VENDOR_ID, 0) res = base._match_device(record, 0, LOGITECH_VENDOR_ID, 0)
assert res is True assert res is True
@ -152,7 +152,7 @@ def test_request_errors(
with mock.patch( with mock.patch(
"logitech_receiver.base._read", "logitech_receiver.base._read",
return_value=(HIDPP_SHORT_MESSAGE_ID, device_number, prefix + reply_data_sw_id + struct.pack("B", error_code)), return_value=(HIDPP_SHORT_MESSAGE_ID, device_number, prefix + reply_data_sw_id + struct.pack("B", error_code)),
), mock.patch("logitech_receiver.base._skip_incoming", return_value=None), mock.patch( ), mock.patch("logitech_receiver.base._read_input_buffer"), mock.patch(
"logitech_receiver.base.write", return_value=None "logitech_receiver.base.write", return_value=None
), mock.patch("logitech_receiver.base._get_next_sw_id", return_value=next_sw_id): ), mock.patch("logitech_receiver.base._get_next_sw_id", return_value=next_sw_id):
if raise_exception: if raise_exception: