diversion: Introduce protocols to unite Action and Condition classes

Enforce a common interface for all Action and Condition related classes
and connect them to a common protocol class to support isinstance
checks.

Related #2659
This commit is contained in:
MattHag 2024-12-31 18:31:53 +01:00
parent 9af34b33e8
commit a70366a786
4 changed files with 60 additions and 41 deletions

View File

@ -512,25 +512,25 @@ MOUSE_GESTURE_TESTS = {
} }
def compile_component(c): def compile_component(c) -> Rule | type[ConditionProtocol] | type[ActionProtocol]:
if isinstance(c, Rule) or isinstance(c, Condition) or isinstance(c, Action): if isinstance(c, Rule) or isinstance(c, ConditionProtocol) or isinstance(c, ActionProtocol):
return c return c
elif isinstance(c, dict) and len(c) == 1: elif isinstance(c, dict) and len(c) == 1:
k, v = next(iter(c.items())) k, v = next(iter(c.items()))
if k in COMPONENTS: if k in COMPONENTS:
cls = COMPONENTS[k] cls: Rule | type[ConditionProtocol] | type[ActionProtocol] = COMPONENTS[k]
return cls(v) return cls(v)
logger.warning("illegal component in rule: %s", c) logger.warning("illegal component in rule: %s", c)
return FalllbackCondition() return FallbackCondition()
def _evaluate(components, feature, notification: HIDPPNotification, device, result) -> Any: def _evaluate(components, feature, notification: HIDPPNotification, device, result) -> Any:
res = True res = True
for component in components: for component in components:
res = component.evaluate(feature, notification, device, result) res = component.evaluate(feature, notification, device, result)
if not isinstance(component, Action) and res is None: if not isinstance(component, ActionProtocol) and res is None:
return None return None
if isinstance(component, Condition) and not res: if isinstance(component, ConditionProtocol) and not res:
return res return res
return res return res
@ -557,7 +557,22 @@ class Rule:
return {"Rule": [c.data() for c in self.components]} return {"Rule": [c.data() for c in self.components]}
class Condition: @typing.runtime_checkable
class ConditionProtocol(typing.Protocol):
def __init__(self, args: Any, warn: bool) -> None:
...
def __str__(self) -> str:
...
def evaluate(self, feature, notification: HIDPPNotification, device, last_result) -> bool:
...
def data(self) -> dict[str, Any]:
...
class FallbackCondition(ConditionProtocol):
def __init__(self, *args): def __init__(self, *args):
pass pass
@ -570,7 +585,7 @@ class Condition:
return False return False
class Not(Condition): class Not(ConditionProtocol):
def __init__(self, op, warn=True): def __init__(self, op, warn=True):
if isinstance(op, list) and len(op) == 1: if isinstance(op, list) and len(op) == 1:
op = op[0] op = op[0]
@ -590,7 +605,7 @@ class Not(Condition):
return {"Not": self.component.data()} return {"Not": self.component.data()}
class Or(Condition): class Or(ConditionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
self.components = [compile_component(a) for a in args] self.components = [compile_component(a) for a in args]
@ -603,9 +618,9 @@ class Or(Condition):
result = False result = False
for component in self.components: for component in self.components:
result = component.evaluate(feature, notification, device, last_result) result = component.evaluate(feature, notification, device, last_result)
if not isinstance(component, Action) and result is None: if not isinstance(component, ActionProtocol) and result is None:
return None return None
if isinstance(component, Condition) and result: if isinstance(component, ConditionProtocol) and result:
return result return result
return result return result
@ -613,7 +628,7 @@ class Or(Condition):
return {"Or": [c.data() for c in self.components]} return {"Or": [c.data() for c in self.components]}
class And(Condition): class And(ConditionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
self.components = [compile_component(a) for a in args] self.components = [compile_component(a) for a in args]
@ -675,7 +690,7 @@ def gnome_dbus_pointer_prog():
return (wm_class,) if wm_class else None return (wm_class,) if wm_class else None
class Process(Condition): class Process(ConditionProtocol):
def __init__(self, process, warn=True): def __init__(self, process, warn=True):
self.process = process self.process = process
if (not wayland and not x11_setup()) or (wayland and not gnome_dbus_interface_setup()): if (not wayland and not x11_setup()) or (wayland and not gnome_dbus_interface_setup()):
@ -706,7 +721,7 @@ class Process(Condition):
return {"Process": str(self.process)} return {"Process": str(self.process)}
class MouseProcess(Condition): class MouseProcess(ConditionProtocol):
def __init__(self, process, warn=True): def __init__(self, process, warn=True):
self.process = process self.process = process
if (not wayland and not x11_setup()) or (wayland and not gnome_dbus_interface_setup()): if (not wayland and not x11_setup()) or (wayland and not gnome_dbus_interface_setup()):
@ -737,7 +752,7 @@ class MouseProcess(Condition):
return {"MouseProcess": str(self.process)} return {"MouseProcess": str(self.process)}
class Feature(Condition): class Feature(ConditionProtocol):
def __init__(self, feature: str, warn: bool = True): def __init__(self, feature: str, warn: bool = True):
try: try:
self.feature = SupportedFeature[feature] self.feature = SupportedFeature[feature]
@ -758,7 +773,7 @@ class Feature(Condition):
return {"Feature": str(self.feature)} return {"Feature": str(self.feature)}
class Report(Condition): class Report(ConditionProtocol):
def __init__(self, report, warn=True): def __init__(self, report, warn=True):
if not (isinstance(report, int)): if not (isinstance(report, int)):
if warn: if warn:
@ -780,7 +795,7 @@ class Report(Condition):
# Setting(device, setting, [key], value...) # Setting(device, setting, [key], value...)
class Setting(Condition): class Setting(ConditionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
if not (isinstance(args, list) and len(args) > 2): if not (isinstance(args, list) and len(args) > 2):
if warn: if warn:
@ -827,7 +842,7 @@ MODIFIERS = {
MODIFIER_MASK = MODIFIERS["Shift"] + MODIFIERS["Control"] + MODIFIERS["Alt"] + MODIFIERS["Super"] MODIFIER_MASK = MODIFIERS["Shift"] + MODIFIERS["Control"] + MODIFIERS["Alt"] + MODIFIERS["Super"]
class Modifiers(Condition): class Modifiers(ConditionProtocol):
def __init__(self, modifiers, warn=True): def __init__(self, modifiers, warn=True):
modifiers = [modifiers] if isinstance(modifiers, str) else modifiers modifiers = [modifiers] if isinstance(modifiers, str) else modifiers
self.desired = 0 self.desired = 0
@ -857,7 +872,7 @@ class Modifiers(Condition):
return {"Modifiers": [str(m) for m in self.modifiers]} return {"Modifiers": [str(m) for m in self.modifiers]}
class Key(Condition): class Key(ConditionProtocol):
DOWN = "pressed" DOWN = "pressed"
UP = "released" UP = "released"
@ -912,7 +927,7 @@ class Key(Condition):
return {"Key": [str(self.key), self.action]} return {"Key": [str(self.key), self.action]}
class KeyIsDown(Condition): class KeyIsDown(ConditionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
default_key = 0 default_key = 0
@ -956,7 +971,7 @@ def range_test(start, end, min, max):
return range_test_helper return range_test_helper
class Test(Condition): class Test(ConditionProtocol):
def __init__(self, test, warn=True): def __init__(self, test, warn=True):
self.test = "" self.test = ""
self.parameter = None self.parameter = None
@ -998,7 +1013,7 @@ class Test(Condition):
return {"Test": ([self.test, self.parameter] if self.parameter is not None else [self.test])} return {"Test": ([self.test, self.parameter] if self.parameter is not None else [self.test])}
class TestBytes(Condition): class TestBytes(ConditionProtocol):
def __init__(self, test, warn=True): def __init__(self, test, warn=True):
self.test = test self.test = test
if ( if (
@ -1026,7 +1041,7 @@ class TestBytes(Condition):
return {"TestBytes": self.test[:]} return {"TestBytes": self.test[:]}
class MouseGesture(Condition): class MouseGesture(ConditionProtocol):
MOVEMENTS = [ MOVEMENTS = [
"Mouse Up", "Mouse Up",
"Mouse Down", "Mouse Down",
@ -1081,7 +1096,7 @@ class MouseGesture(Condition):
return {"MouseGesture": [str(m) for m in self.movements]} return {"MouseGesture": [str(m) for m in self.movements]}
class Active(Condition): class Active(ConditionProtocol):
def __init__(self, devID, warn=True): def __init__(self, devID, warn=True):
if not (isinstance(devID, str)): if not (isinstance(devID, str)):
if warn: if warn:
@ -1102,7 +1117,7 @@ class Active(Condition):
return {"Active": self.devID} return {"Active": self.devID}
class Device(Condition): class Device(ConditionProtocol):
def __init__(self, devID, warn=True): def __init__(self, devID, warn=True):
if not (isinstance(devID, str)): if not (isinstance(devID, str)):
if warn: if warn:
@ -1122,7 +1137,7 @@ class Device(Condition):
return {"Device": self.devID} return {"Device": self.devID}
class Host(Condition): class Host(ConditionProtocol):
def __init__(self, host, warn=True): def __init__(self, host, warn=True):
if not (isinstance(host, str)): if not (isinstance(host, str)):
if warn: if warn:
@ -1143,12 +1158,16 @@ class Host(Condition):
return {"Host": self.host} return {"Host": self.host}
class Action: @typing.runtime_checkable
def __init__(self, *args): class ActionProtocol(typing.Protocol):
pass def __init__(self, args: Any, warn: bool) -> None:
...
def evaluate(self, feature, notification: HIDPPNotification, device, last_result): def evaluate(self, feature, notification: HIDPPNotification, device, last_result) -> None:
return None ...
def data(self) -> dict[str, Any]:
...
def keysym_to_keycode(keysym, _modifiers) -> Tuple[int, int]: # maybe should take shift into account def keysym_to_keycode(keysym, _modifiers) -> Tuple[int, int]: # maybe should take shift into account
@ -1177,7 +1196,7 @@ def keysym_to_keycode(keysym, _modifiers) -> Tuple[int, int]: # maybe should ta
return keycode, level return keycode, level
class KeyPress(Action): class KeyPress(ActionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
self.key_names, self.action = self.regularize_args(args) self.key_names, self.action = self.regularize_args(args)
if not isinstance(self.key_names, list): if not isinstance(self.key_names, list):
@ -1267,7 +1286,7 @@ class KeyPress(Action):
# super().keyUp(self.keys, current_key_modifiers) # super().keyUp(self.keys, current_key_modifiers)
class MouseScroll(Action): class MouseScroll(ActionProtocol):
def __init__(self, amounts, warn=True): def __init__(self, amounts, warn=True):
if len(amounts) == 1 and isinstance(amounts[0], list): if len(amounts) == 1 and isinstance(amounts[0], list):
amounts = amounts[0] amounts = amounts[0]
@ -1295,7 +1314,7 @@ class MouseScroll(Action):
return {"MouseScroll": self.amounts[:]} return {"MouseScroll": self.amounts[:]}
class MouseClick(Action): class MouseClick(ActionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
if len(args) == 1 and isinstance(args[0], list): if len(args) == 1 and isinstance(args[0], list):
args = args[0] args = args[0]
@ -1334,7 +1353,7 @@ class MouseClick(Action):
return {"MouseClick": [self.button, self.count]} return {"MouseClick": [self.button, self.count]}
class Set(Action): class Set(ActionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
if not (isinstance(args, list) and len(args) > 2): if not (isinstance(args, list) and len(args) > 2):
if warn: if warn:
@ -1380,7 +1399,7 @@ class Set(Action):
return {"Set": self.args[:]} return {"Set": self.args[:]}
class Execute(Action): class Execute(ActionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
if isinstance(args, str): if isinstance(args, str):
args = [args] args = [args]
@ -1404,7 +1423,7 @@ class Execute(Action):
return {"Execute": self.args[:]} return {"Execute": self.args[:]}
class Later(Action): class Later(ActionProtocol):
def __init__(self, args, warn=True): def __init__(self, args, warn=True):
self.delay = 0 self.delay = 0
self.rule = Rule([]) self.rule = Rule([])
@ -1439,7 +1458,7 @@ class Later(Action):
return {"Later": data} return {"Later": data}
COMPONENTS = { COMPONENTS: dict[str, Rule | ConditionProtocol | ActionProtocol] = {
"Rule": Rule, "Rule": Rule,
"Not": Not, "Not": Not,
"Or": Or, "Or": Or,

View File

@ -1205,7 +1205,7 @@ class NotUI(RuleComponentUI):
class ActionUI(RuleComponentUI): class ActionUI(RuleComponentUI):
CLASS = diversion.Action CLASS = diversion.ActionProtocol
@classmethod @classmethod
def icon_name(cls): def icon_name(cls):

View File

@ -36,7 +36,7 @@ class GtkSignal(Enum):
class ActionUI(RuleComponentUI): class ActionUI(RuleComponentUI):
CLASS = diversion.Action CLASS = diversion.ActionProtocol
@classmethod @classmethod
def icon_name(cls): def icon_name(cls):

View File

@ -36,7 +36,7 @@ class GtkSignal(Enum):
class ConditionUI(RuleComponentUI): class ConditionUI(RuleComponentUI):
CLASS = diversion.Condition CLASS = diversion.ConditionProtocol
@classmethod @classmethod
def icon_name(cls): def icon_name(cls):