base: Refactor device filtering

Related #2273
This commit is contained in:
MattHag 2024-12-29 23:05:43 +01:00 committed by Peter F. Patel-Schneider
parent 1e6af7fa7d
commit 3186d880fc
2 changed files with 54 additions and 35 deletions

View File

@ -56,7 +56,7 @@ else:
logger = logging.getLogger(__name__)
class HIDAPI(typing.Protocol):
class HIDProtocol(typing.Protocol):
def find_paired_node_wpid(self, receiver_path: str, index: int):
...
@ -106,7 +106,7 @@ _DEVICE_REQUEST_TIMEOUT = DEFAULT_TIMEOUT
# when pinging, be extra patient (no longer)
_PING_TIMEOUT = DEFAULT_TIMEOUT
hidapi = typing.cast(HIDAPI, hidapi)
hidapi = typing.cast(HIDProtocol, hidapi)
request_lock = threading.Lock() # serialize all requests
handles_lock = {}
@ -156,7 +156,7 @@ def product_information(usb_id: int) -> dict[str, Any]:
def receivers():
"""Enumerate all the receivers attached to the machine."""
yield from hidapi.enumerate(_filter_receivers)
yield from hidapi.enumerate(get_known_receiver_info)
def filter_products_of_interest(
@ -164,34 +164,53 @@ def filter_products_of_interest(
) -> dict[str, Any] | None:
"""Check that this product is of interest and if so return the device record for further checking"""
def _other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None:
"""Check whether product is a Logitech USB-connected or Bluetooth device based on bus, vendor, and product IDs
This allows Solaar to support receiverless HID++ 2.0 devices that it knows nothing about"""
if vendor_id != LOGITECH_VENDOR_ID:
return
recv = get_known_receiver_info(bus_id, vendor_id, product_id, hidpp_short, hidpp_long)
if recv: # known or unknown receiver
return recv
device_info = None
if bus_id == BusID.USB and (0xC07D <= product_id <= 0xC094 or 0xC32B <= product_id <= 0xC344):
device_info = _usb_device(product_id, 2)
elif bus_id == BusID.BLUETOOTH and (0xB012 <= product_id <= 0xB0FF or 0xB317 <= product_id <= 0xB3FF):
device_info = _bluetooth_device(product_id)
return device_info
device = get_known_device_info(bus_id, vendor_id, product_id)
if device:
return device
record = _filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long)
if record: # known or unknown receiver
return record
if hidpp_short or hidpp_long:
return get_unknown_hid_device_info(bus_id, vendor_id, product_id)
for record in KNOWN_DEVICE_IDS:
if _match(record, bus_id, vendor_id, product_id):
return record
if hidpp_short or hidpp_long: # unknown devices that use HID++
return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True}
elif hidpp_short is None and hidpp_long is None: # unknown devices in correct range of IDs
return _other_device_check(bus_id, vendor_id, product_id)
if hidpp_short is None and hidpp_long is None:
return get_unknown_logitech_device_info(bus_id, vendor_id, product_id)
return None
def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int):
def get_known_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any]:
for recv in KNOWN_DEVICE_IDS:
if _match_device(recv, bus_id, vendor_id, product_id):
return recv
def get_unknown_hid_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any]:
return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True}
def get_unknown_logitech_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None:
"""Get info from unknown device in Logitech product range.
Check whether product is a Logitech USB-connected or Bluetooth
device based on bus, vendor, and product ID. This allows Solaar to
support receiverless HID++ 2.0 devices that it knows nothing about.
"""
if vendor_id != LOGITECH_VENDOR_ID:
return None
if bus_id == BusID.USB.value and (0xC07D <= product_id <= 0xC094 or 0xC32B <= product_id <= 0xC344):
device_info = _usb_device(product_id, 2)
return device_info
elif bus_id == BusID.BLUETOOTH.value and (0xB012 <= product_id <= 0xB0FF or 0xB317 <= product_id <= 0xB3FF):
device_info = _bluetooth_device(product_id)
return device_info
return None
def _match_device(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int):
return (
(record.get("bus_id") is None or record.get("bus_id") == bus_id)
and (record.get("vendor_id") is None or record.get("vendor_id") == vendor_id)
@ -199,7 +218,7 @@ def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int)
)
def _filter_receivers(
def get_known_receiver_info(
bus_id: int, vendor_id: int, product_id: int, _hidpp_short: bool = False, _hidpp_long: bool = False
) -> dict[str, Any]:
"""Check that this product is a Logitech receiver and return it.
@ -210,7 +229,7 @@ def _filter_receivers(
"""
try:
record = base_usb.get_receiver_info(product_id)
if _match(record, bus_id, vendor_id, product_id):
if _match_device(record, bus_id, vendor_id, product_id):
return record
except ValueError:
pass
@ -507,7 +526,7 @@ def request(
ihandle = int(handle)
notifications_hook = getattr(handle, "notifications_hook", None)
try:
_skip_incoming(handle, ihandle, notifications_hook)
_read_input_buffer(handle, ihandle, notifications_hook)
except exceptions.NoReceiver:
logger.warning("device or receiver disconnected")
return None
@ -604,7 +623,7 @@ def ping(handle, devnumber, long_message: bool = False):
with acquire_timeout(handle_lock(handle), handle, 10.0):
notifications_hook = getattr(handle, "notifications_hook", None)
try:
_skip_incoming(handle, int(handle), notifications_hook)
_read_input_buffer(handle, int(handle), notifications_hook)
except exceptions.NoReceiver:
logger.warning("device or receiver disconnected")
return
@ -651,8 +670,8 @@ def ping(handle, devnumber, long_message: bool = False):
logger.warning("(%s) timeout (%0.2f/%0.2f) on device %d ping", handle, delta, _PING_TIMEOUT, devnumber)
def _skip_incoming(handle, ihandle, notifications_hook):
"""Read anything already in the input buffer.
def _read_input_buffer(handle, ihandle, notifications_hook):
"""Consume anything already in the input buffer.
Used by request() and ping() before their write.
"""

View File

@ -40,7 +40,7 @@ def test_filter_receivers_known():
bus_id = 2
product_id = 0xC548
receiver_info = base._filter_receivers(bus_id, LOGITECH_VENDOR_ID, product_id)
receiver_info = base.get_known_receiver_info(bus_id, LOGITECH_VENDOR_ID, product_id)
assert receiver_info["name"] == "Bolt Receiver"
assert receiver_info["receiver_kind"] == "bolt"
@ -50,7 +50,7 @@ def test_filter_receivers_unknown():
bus_id = 1
product_id = 0xC500
receiver_info = base._filter_receivers(bus_id, LOGITECH_VENDOR_ID, product_id)
receiver_info = base.get_known_receiver_info(bus_id, LOGITECH_VENDOR_ID, product_id)
assert receiver_info["bus_id"] == bus_id
assert receiver_info["product_id"] == product_id
@ -90,7 +90,7 @@ def test_filter_products_of_interest(product_id, bus, hidpp_short, hidpp_long, e
def test_match():
record = {"vendor_id": LOGITECH_VENDOR_ID}
res = base._match(record, 0, LOGITECH_VENDOR_ID, 0)
res = base._match_device(record, 0, LOGITECH_VENDOR_ID, 0)
assert res is True
@ -152,7 +152,7 @@ def test_request_errors(
with mock.patch(
"logitech_receiver.base._read",
return_value=(HIDPP_SHORT_MESSAGE_ID, device_number, prefix + reply_data_sw_id + struct.pack("B", error_code)),
), mock.patch("logitech_receiver.base._skip_incoming", return_value=None), mock.patch(
), mock.patch("logitech_receiver.base._read_input_buffer"), mock.patch(
"logitech_receiver.base.write", return_value=None
), mock.patch("logitech_receiver.base._get_next_sw_id", return_value=next_sw_id):
if raise_exception: