From 76ab9482e9f6822aa0a626eade5367fd54534826 Mon Sep 17 00:00:00 2001 From: Daniel Girtler Date: Sat, 1 Nov 2025 23:55:58 +1100 Subject: [PATCH] Wifi connection menu with textual (#3879) * Wifi connector * Update --------- Co-authored-by: Daniel Girtler --- .pre-commit-config.yaml | 2 + PKGBUILD | 1 + archinstall/__init__.py | 29 +- archinstall/lib/args.py | 7 + archinstall/lib/models/network.py | 102 ++++- archinstall/lib/network/__init__.py | 0 archinstall/lib/network/wifi_handler.py | 280 ++++++++++++ archinstall/lib/network/wpa_supplicant.py | 136 ++++++ archinstall/tui/menu_item.py | 12 +- archinstall/tui/ui/__init__.py | 0 archinstall/tui/ui/components.py | 509 ++++++++++++++++++++++ archinstall/tui/ui/result.py | 26 ++ pyproject.toml | 2 + 13 files changed, 1098 insertions(+), 8 deletions(-) create mode 100644 archinstall/lib/network/__init__.py create mode 100644 archinstall/lib/network/wifi_handler.py create mode 100644 archinstall/lib/network/wpa_supplicant.py create mode 100644 archinstall/tui/ui/__init__.py create mode 100644 archinstall/tui/ui/components.py create mode 100644 archinstall/tui/ui/result.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ecad241a..eff6fee8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,9 @@ repos: additional_dependencies: - pydantic - pytest + - pytest-mock - cryptography + - textual - repo: local hooks: - id: pylint diff --git a/PKGBUILD b/PKGBUILD index 94fbe75a..7bfa8260 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -28,6 +28,7 @@ depends=( 'python-cryptography' 'python-pydantic' 'python-pyparted' + 'python-textual' 'systemd' 'util-linux' 'xfsprogs' diff --git a/archinstall/__init__.py b/archinstall/__init__.py index d8069d96..0cd69865 100644 --- a/archinstall/__init__.py +++ b/archinstall/__init__.py @@ -8,7 +8,10 @@ import traceback from archinstall.lib.args import arch_config_handler from archinstall.lib.disk.utils import disk_layouts +from archinstall.lib.network.wifi_handler import wifi_handler +from archinstall.lib.networking import ping from archinstall.lib.packages.packages import check_package_upgrade +from archinstall.tui.ui.components import tui as ttui from .lib.hardware import SysInfo from .lib.output import FormattedOutput, debug, error, info, log, warn @@ -36,6 +39,17 @@ def _log_sys_info() -> None: debug(f'Disk states before installing:\n{disk_layouts()}') +def _check_online() -> None: + try: + ping('1.1.1.1') + except OSError as ex: + if 'Network is unreachable' in str(ex): + if not arch_config_handler.args.skip_wifi_check: + success = not wifi_handler.setup() + if not success: + exit(0) + + def _fetch_arch_db() -> None: info('Fetching Arch Linux package database...') try: @@ -44,13 +58,14 @@ def _fetch_arch_db() -> None: error('Failed to sync Arch Linux package database.') if 'could not resolve host' in str(e).lower(): error('Most likely due to a missing network connection or DNS issue.') + error('Run archinstall --debug and check /var/log/archinstall/install.log for details.') debug(f'Failed to sync Arch Linux package database: {e}') exit(1) -def _check_new_version() -> None: +def check_version_upgrade() -> str | None: info('Checking version...') upgrade = None @@ -62,7 +77,7 @@ def _check_new_version() -> None: text = tr('New version available') + f': {upgrade}' info(text) - time.sleep(3) + return text def main() -> int: @@ -81,11 +96,19 @@ def main() -> int: _log_sys_info() + ttui.global_header = 'Archinstall' + if not arch_config_handler.args.offline: + _check_online() _fetch_arch_db() if not arch_config_handler.args.skip_version_check: - _check_new_version() + new_version = check_version_upgrade() + + if new_version: + ttui.global_header = f'{ttui.global_header} {new_version}' + info(new_version) + time.sleep(3) script = arch_config_handler.get_script() diff --git a/archinstall/lib/args.py b/archinstall/lib/args.py index 9adc04c2..48a379b0 100644 --- a/archinstall/lib/args.py +++ b/archinstall/lib/args.py @@ -49,6 +49,7 @@ class Arguments: no_pkg_lookups: bool = False plugin: str | None = None skip_version_check: bool = False + skip_wifi_check: bool = False advanced: bool = False verbose: bool = False @@ -410,6 +411,12 @@ class ArchConfigHandler: default=False, help='Skip the version check when running archinstall', ) + parser.add_argument( + '--skip-wifi-check', + action='store_true', + default=False, + help='Skip wifi check when running archinstall', + ) parser.add_argument( '--advanced', action='store_true', diff --git a/archinstall/lib/models/network.py b/archinstall/lib/models/network.py index 993169d8..88f52fd6 100644 --- a/archinstall/lib/models/network.py +++ b/archinstall/lib/models/network.py @@ -1,9 +1,11 @@ from __future__ import annotations +import re from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, NotRequired, TypedDict +from typing import TYPE_CHECKING, NotRequired, TypedDict, override +from archinstall.lib.output import debug from archinstall.lib.translationhandler import tr from ..models.profile import ProfileConfiguration @@ -157,3 +159,101 @@ class NetworkConfiguration: installation.enable_service('systemd-networkd') installation.enable_service('systemd-resolved') + + +@dataclass +class WifiNetwork: + bssid: str + frequency: str + signal_level: str + flags: str + ssid: str + + @override + def __hash__(self) -> int: + return hash((self.bssid, self.frequency, self.signal_level, self.flags, self.ssid)) + + def table_data(self) -> dict[str, str | int]: + """Format WiFi data for table display""" + return { + 'SSID': self.ssid, + 'Signal': f'{self.signal_level} dBm', + 'Frequency': f'{self.frequency} MHz', + 'Security': self.flags, + 'BSSID': self.bssid, + } + + @staticmethod + def from_wpa(results: str) -> list[WifiNetwork]: + entries: list[WifiNetwork] = [] + + for line in results.splitlines(): + line = line.strip() + if not line: + continue + + parts = line.split() + if len(parts) != 5: + continue + + wifi = WifiNetwork(bssid=parts[0], frequency=parts[1], signal_level=parts[2], flags=parts[3], ssid=parts[4]) + entries.append(wifi) + + return entries + + +@dataclass +class WifiConfiguredNetwork: + network_id: int + ssid: str + bssid: str + flags: list[str] + + @classmethod + def from_wpa_cli_output(cls, list_networks: str) -> list[WifiConfiguredNetwork]: + """ + Example output from 'wpa_cli list_networks' + + Selected interface 'wlan0' + network id / ssid / bssid / flags + 0 WifiGuest any [CURRENT] + 1 any [DISABLED] + 2 any [DISABLED] + """ + + lines = list_networks.strip().splitlines() + lines = lines[1:] # remove the header row from the wpa_cli output + + networks = [] + + for line in lines: + line = line.strip() + parts = line.split('\t') + + if len(parts) < 3: + continue + + try: + # flags = cls._extract_flags(parts[3]) + flags: list[str] = [] + + networks.append( + WifiConfiguredNetwork( + network_id=int(parts[0]), + ssid=parts[1], + bssid=parts[2], + flags=flags, + ) + ) + except (ValueError, IndexError): + debug('Parsing error for network output') + + return networks + + @classmethod + def _extract_flags(cls, flag_string: str) -> list[str]: + pattern = r'\[([^\]]+)\]' + + extracted_values = re.findall(pattern, flag_string) + + return extracted_values diff --git a/archinstall/lib/network/__init__.py b/archinstall/lib/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/archinstall/lib/network/wifi_handler.py b/archinstall/lib/network/wifi_handler.py new file mode 100644 index 00000000..13fc3e43 --- /dev/null +++ b/archinstall/lib/network/wifi_handler.py @@ -0,0 +1,280 @@ +from asyncio import sleep +from dataclasses import dataclass +from pathlib import Path +from typing import Any, assert_never + +from archinstall.lib.exceptions import SysCallError +from archinstall.lib.general import SysCommand +from archinstall.lib.models.network import WifiConfiguredNetwork, WifiNetwork +from archinstall.lib.network.wpa_supplicant import WpaSupplicantConfig +from archinstall.lib.output import debug +from archinstall.lib.translationhandler import tr +from archinstall.tui.menu_item import MenuItemGroup +from archinstall.tui.ui.components import ConfirmationScreen, InputScreen, LoadingScreen, NotifyScreen, TableSelectionScreen, tui +from archinstall.tui.ui.result import ResultType + + +@dataclass +class WpaCliResult: + success: bool + response: str | None = None + error: str | None = None + + +class WifiHandler: + def __init__(self) -> None: + tui.set_main(self) + self._wpa_config = WpaSupplicantConfig() + + def setup(self) -> Any: + result = tui.run() + return result + + async def run(self) -> None: + """ + This is the entry point that is called by components.TApp + """ + wifi_iface = self._find_wifi_interface() + + if not wifi_iface: + debug('No wifi interface found') + tui.exit(False) + return None + + prompt = tr('No network connection found') + '\n\n' + prompt += tr('Would you like to connect to a Wifi?') + '\n' + + result = await ConfirmationScreen[bool]( + MenuItemGroup.yes_no(), + header=prompt, + allow_skip=True, + allow_reset=True, + ).run() + + match result.type_: + case ResultType.Selection: + if result.value() is False: + tui.exit(False) + return None + case ResultType.Skip | ResultType.Reset: + tui.exit(False) + return None + case _: + assert_never(result) + + setup_result = await self._setup_wifi(wifi_iface) + tui.exit(setup_result) + + async def _enable_supplicant(self, wifi_iface: str) -> bool: + self._wpa_config.load_config() + + result = self._wpa_cli('status') # if it it's running it will blow up + + if result.success: + debug('wpa_supplicant already running') + return True + + if result.error and 'failed to connect to non-global ctrl_ifname'.lower() not in result.error.lower(): + debug('Unexpected wpa_cli failure') + return False + + debug('wpa_supplicant not running, trying to enable') + + try: + SysCommand(f'wpa_supplicant -B -i {wifi_iface} -c {self._wpa_config.config_file}') + result = self._wpa_cli('status') # if it it's running it will blow up + + if result.success: + debug('successfully enabled wpa_supplicant') + return True + else: + debug(f'failed to enable wpa_supplicant: {result.error}') + return False + except SysCallError as err: + debug(f'failed to enable wpa_supplicant: {err}') + return False + + def _find_wifi_interface(self) -> str | None: + net_path = Path('/sys/class/net') + + for iface in net_path.iterdir(): + maybe_wireless_path = net_path / iface / 'wireless' + if maybe_wireless_path.is_dir(): + return iface.name + + return None + + async def _setup_wifi(self, wifi_iface: str) -> bool: + debug('Setting up wifi') + + if not await self._enable_supplicant(wifi_iface): + debug('Failed to enable wpa_supplicant') + return False + + if not wifi_iface: + debug('No wifi interface found') + await NotifyScreen(header=tr('No wifi interface found')).run() + return False + + debug(f'Found wifi interface: {wifi_iface}') + + async def get_wifi_networks() -> list[WifiNetwork]: + debug('Scanning Wifi networks') + result = self._wpa_cli('scan', wifi_iface) + + if not result.success: + debug(f'Failed to scan wifi networks: {result.error}') + return [] + + await sleep(5) + return self._get_scan_results(wifi_iface) + + result = await TableSelectionScreen[WifiNetwork]( + header=tr('Select wifi network to connect to'), + loading_header=tr('Scanning wifi networks...'), + data_callback=get_wifi_networks, + allow_skip=True, + allow_reset=True, + ).run() + + match result.type_: + case ResultType.Selection: + if not result.has_data(): + debug('No networks found') + await NotifyScreen(header=tr('No wifi networks found')).run() + tui.exit(False) + return False + + network = result.value() + case ResultType.Skip | ResultType.Reset: + tui.exit(False) + return False + case _: + assert_never(result.type_) + + existing_network = self._wpa_config.get_existing_network(network.ssid) + existing_psk = existing_network.psk if existing_network else None + psk = await self._prompt_psk(existing_psk) + + if not psk: + debug('No password specified') + return False + + self._wpa_config.set_network(network, psk) + self._wpa_config.write_config() + + wpa_result = self._wpa_cli('reconfigure') + + if not wpa_result.success: + debug(f'Failed to reconfigure wpa_supplicant: {wpa_result.error}') + await self._notify_failure() + return False + + await LoadingScreen(3, 'Setting up wifi...').run() + + network_id = self._find_network_id(network.ssid, wifi_iface) + + if not network_id: + debug('Failed to find network id') + await self._notify_failure() + return False + + wpa_result = self._wpa_cli(f'enable {network_id}', wifi_iface) + + if not wpa_result.success: + debug(f'Failed to enable network: {wpa_result.error}') + await self._notify_failure() + return False + + await LoadingScreen(5, 'Connecting wifi...').run() + + return True + + async def _notify_failure(self) -> None: + await NotifyScreen(header=tr('Failed setting up wifi')).run() + + def _wpa_cli(self, command: str, iface: str | None = None) -> WpaCliResult: + cmd = 'wpa_cli' + + if iface: + cmd += f' -i {iface}' + + cmd += f' {command}' + + try: + result = SysCommand(cmd).decode() + + if 'FAIL' in result: + debug(f'wpa_cli returned FAIL: {result}') + return WpaCliResult( + success=False, + error=f'wpa_cli returned a failure: {result}', + ) + + return WpaCliResult(success=True, response=result) + except SysCallError as err: + debug(f'error running wpa_cli command: {err}') + return WpaCliResult( + success=False, + error=f'Error running wpa_cli command: {err}', + ) + + def _find_network_id(self, ssid: str, iface: str) -> int | None: + result = self._wpa_cli('list_networks', iface) + + if not result.success: + debug(f'Failed to list networks: {result.error}') + return None + + list_networks = result.response + + if not list_networks: + debug('No networks found') + return None + + existing_networks = WifiConfiguredNetwork.from_wpa_cli_output(list_networks) + + for network in existing_networks: + if network.ssid == ssid: + return network.network_id + + return None + + async def _prompt_psk(self, existing: str | None = None) -> str | None: + result = await InputScreen( + header=tr('Enter wifi password'), + password=True, + allow_skip=True, + allow_reset=True, + default_value=existing, + ).run() + + if result.type_ != ResultType.Selection: + debug('No password provided, aborting connection') + return None + + return result.value() + + def _get_scan_results(self, iface: str) -> list[WifiNetwork]: + debug(f'Retrieving scan results: {iface}') + + try: + result = self._wpa_cli('scan_results', iface) + + if not result.success: + debug(f'Failed to retrieve scan results: {result.error}') + return [] + + if not result.response: + debug('No wifi networks found') + return [] + + networks = WifiNetwork.from_wpa(result.response) + + return networks + except SysCallError as err: + debug('Unable to retrieve wifi results') + raise err + + +wifi_handler = WifiHandler() diff --git a/archinstall/lib/network/wpa_supplicant.py b/archinstall/lib/network/wpa_supplicant.py new file mode 100644 index 00000000..935e9b4d --- /dev/null +++ b/archinstall/lib/network/wpa_supplicant.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, field +from pathlib import Path + +from archinstall.lib.models.network import WifiNetwork +from archinstall.lib.output import debug + + +@dataclass +class WpaSupplicantNetwork: + mappings: dict[str, str] = field(default_factory=dict) + + @property + def psk(self) -> str: + return self.mappings['psk'].strip('"') + + @property + def ssid(self) -> str: + return self.mappings['ssid'].strip('"') + + def to_config_entry(self) -> str: + wpa_net_config = '\n' + wpa_net_config += 'network={\n' + + for key, value in self.mappings.items(): + wpa_net_config += f'\t{key}={value}\n' + + if 'mesh_fwding' not in self.mappings: + wpa_net_config += '\tmesh_fwding=1\n' + + wpa_net_config += '}\n\n' + return wpa_net_config + + +class WpaSupplicantConfig: + def __init__(self) -> None: + self.config_file = Path('/etc/wpa_supplicant/wpa_supplicant.conf') + self._wpa_networks: list[WpaSupplicantNetwork] = [] + + def load_config(self) -> None: + if not self.config_file.is_file(): + debug('wpa_supplicant.conf not found, creating') + self._create_config() + else: + debug('wpa_supplicant.conf found') + content = self.config_file.read_text() + + config_header = '' + if 'ctrl_interface' not in content: + config_header += 'ctrl_interface=/run/wpa_supplicant\n' + if 'update_config' not in content: + config_header += 'update_config=1\n\n' + + if config_header: + config = config_header + content + self.config_file.write_text(config) + + self._wpa_networks = self._parse_config() + + def _config_header(self) -> str: + return 'ctrl_interface=/run/wpa_supplicant\nupdate_config=1' + + def get_existing_network(self, ssid: str) -> WpaSupplicantNetwork | None: + ssid = f'"{ssid}"' + + for network in self._wpa_networks: + if network.mappings['ssid'] == ssid: + return network + + return None + + def set_network(self, network: WifiNetwork, psk: str) -> None: + debug('setting new wifi network') + + existing_network = self.get_existing_network(network.ssid) + + if not existing_network: + wpa_net_config = WpaSupplicantNetwork( + mappings={ + 'ssid': f'"{network.ssid}"', + 'psk': f'"{psk}"', + } + ) + self._wpa_networks.append(wpa_net_config) + else: + existing_network.mappings['psk'] = f'"{psk}"' + + def write_config(self) -> None: + debug('writing wpa_supplicant config') + + config = self._config_header() + config += '\n\n' + + for network in self._wpa_networks: + config += network.to_config_entry() + + self.config_file.write_text(config) + + def _create_config(self) -> None: + self.config_file.touch() + + header = self._config_header() + self.config_file.write_text(header) + + def _parse_config(self) -> list[WpaSupplicantNetwork]: + content = self.config_file.read_text() + + networks: list[WpaSupplicantNetwork] = [] + in_network_block = False + cur_net_data: dict[str, str] = {} + + for line in content.splitlines(): + line = line.strip() + + if not line or line.startswith('#'): + continue + + if line == 'network={': + in_network_block = True + cur_net_data = {} + continue + + if in_network_block and line == '}': + new_network = WpaSupplicantNetwork( + mappings=cur_net_data, + ) + + networks.append(new_network) + in_network_block = False + continue + + if in_network_block: + if '=' in line: + key, value = line.split('=', 1) + cur_net_data[key.strip()] = value.strip() + + return networks diff --git a/archinstall/tui/menu_item.py b/archinstall/tui/menu_item.py index 3b8366b1..741e447d 100644 --- a/archinstall/tui/menu_item.py +++ b/archinstall/tui/menu_item.py @@ -33,16 +33,16 @@ class MenuItem: return self.value @classmethod - def yes(cls) -> 'MenuItem': + def yes(cls, action: Callable[[Any], Any] | None = None) -> 'MenuItem': if cls._yes is None: - cls._yes = cls(tr('Yes'), value=True) + cls._yes = cls(tr('Yes'), value=True, key='yes', action=action) return cls._yes @classmethod - def no(cls) -> 'MenuItem': + def no(cls, action: Callable[[Any], Any] | None = None) -> 'MenuItem': if cls._no is None: - cls._no = cls(tr('No'), value=True) + cls._no = cls(tr('No'), value=False, key='no', action=action) return cls._no @@ -223,6 +223,10 @@ class MenuItemGroup: return tr(' (default)') return '' + def set_action_for_all(self, action: Callable[[Any], Any]) -> None: + for item in self.items: + item.action = action + @cached_property def items(self) -> list[MenuItem]: pattern = self._filter_pattern.lower() diff --git a/archinstall/tui/ui/__init__.py b/archinstall/tui/ui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/archinstall/tui/ui/components.py b/archinstall/tui/ui/components.py new file mode 100644 index 00000000..4a7c5d06 --- /dev/null +++ b/archinstall/tui/ui/components.py @@ -0,0 +1,509 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar, override + +from textual import work +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import Center, Horizontal, Vertical +from textual.events import Key +from textual.screen import Screen +from textual.widgets import Button, DataTable, Input, LoadingIndicator, Static + +from archinstall.lib.output import debug +from archinstall.lib.translationhandler import tr +from archinstall.tui.menu_item import MenuItem, MenuItemGroup +from archinstall.tui.ui.result import Result, ResultType + +ValueT = TypeVar('ValueT') + + +class BaseScreen(Screen[Result[ValueT]]): + BINDINGS = [ # noqa: RUF012 + Binding('escape', 'cancel_operation', 'Cancel', show=True), + Binding('ctrl+c', 'reset_operation', 'Reset', show=True), + ] + + def __init__(self, allow_skip: bool = False, allow_reset: bool = False): + super().__init__() + self._allow_skip = allow_skip + self._allow_reset = allow_reset + + def action_cancel_operation(self) -> None: + if self._allow_skip: + self.dismiss(Result(ResultType.Skip, None)) # type: ignore[unused-awaitable] + + def action_reset_operation(self) -> None: + if self._allow_reset: + self.dismiss(Result(ResultType.Reset, None)) # type: ignore[unused-awaitable] + + def _compose_header(self) -> ComposeResult: + """Compose the app header if global header text is available.""" + if tui.global_header: + yield Static(tui.global_header, classes='app-header') + + +class LoadingScreen(BaseScreen[None]): + CSS = """ + LoadingScreen { + align: center middle; + } + + .dialog { + align: center middle; + width: 100%; + border: none; + background: transparent; + } + + .header { + text-align: center; + margin-bottom: 1; + } + + LoadingIndicator { + align: center middle; + } + """ + + def __init__( + self, + timer: int, + header: str | None = None, + ): + super().__init__() + self._timer = timer + self._header = header + + async def run(self) -> Result[None]: + return await tui.show(self) + + @override + def compose(self) -> ComposeResult: + yield from self._compose_header() + + with Center(): + with Vertical(classes='dialog'): + if self._header: + yield Static(self._header, classes='header') + yield Center(LoadingIndicator()) # ensures indicator is centered too + + def on_mount(self) -> None: + self.set_timer(self._timer, self.action_pop_screen) + + def action_pop_screen(self) -> None: + self.dismiss() # type: ignore[unused-awaitable] + + +class ConfirmationScreen(BaseScreen[ValueT]): + BINDINGS = [ # noqa: RUF012 + Binding('l', 'focus_right', 'Focus right', show=True), + Binding('h', 'focus_left', 'Focus left', show=True), + Binding('right', 'focus_right', 'Focus right', show=True), + Binding('left', 'focus_left', 'Focus left', show=True), + ] + + CSS = """ + ConfirmationScreen { + align: center middle; + } + + .dialog-wrapper { + align: center middle; + height: 100%; + width: 100%; + } + + .dialog { + width: 80; + height: 10; + border: none; + background: transparent; + } + + .dialog-content { + padding: 1; + height: 100%; + } + + .message { + text-align: center; + margin-bottom: 1; + } + + .buttons { + align: center middle; + background: transparent; + } + + Button { + width: 4; + height: 3; + background: transparent; + margin: 0 1; + } + + Button.-active { + background: #1793D1; + color: white; + border: none; + text-style: none; + } + """ + + def __init__( + self, + group: MenuItemGroup, + header: str, + allow_skip: bool = False, + allow_reset: bool = False, + ): + super().__init__(allow_skip, allow_reset) + self._group = group + self._header = header + + async def run(self) -> Result[ValueT]: + return await tui.show(self) + + @override + def compose(self) -> ComposeResult: + yield from self._compose_header() + + with Center(classes='dialog-wrapper'): + with Vertical(classes='dialog'): + with Vertical(classes='dialog-content'): + yield Static(self._header, classes='message') + with Horizontal(classes='buttons'): + for item in self._group.items: + yield Button(item.text, id=item.key) + + def on_mount(self) -> None: + self.update_selection() + + def update_selection(self) -> None: + focused = self._group.focus_item + buttons = self.query(Button) + + if not focused: + return + + for button in buttons: + if button.id == focused.key: + button.add_class('-active') + button.focus() + else: + button.remove_class('-active') + + def action_focus_right(self) -> None: + self._group.focus_next() + self.update_selection() + + def action_focus_left(self) -> None: + self._group.focus_prev() + self.update_selection() + + def on_key(self, event: Key) -> None: + if event.key == 'enter': + item = self._group.focus_item + if not item: + return None + self.dismiss(Result(ResultType.Selection, item.value)) # type: ignore[unused-awaitable] + + +class NotifyScreen(ConfirmationScreen[ValueT]): + def __init__(self, header: str): + group = MenuItemGroup([MenuItem(tr('Ok'))]) + super().__init__(group, header) + + +class InputScreen(BaseScreen[str]): + CSS = """ + InputScreen { + } + + .dialog-wrapper { + align: center middle; + height: 100%; + width: 100%; + } + + .input-dialog { + width: 60; + height: 10; + border: none; + background: transparent; + } + + .input-content { + padding: 1; + height: 100%; + } + + .input-header { + text-align: center; + margin: 0 0; + color: white; + text-style: bold; + background: transparent; + } + + .input-prompt { + text-align: center; + margin: 0 0 1 0; + background: transparent; + } + + Input { + margin: 1 2; + border: solid $accent; + background: transparent; + height: 3; + } + + Input .input--cursor { + color: white; + } + + Input:focus { + border: solid $primary; + } + """ + + def __init__( + self, + header: str, + placeholder: str | None = None, + password: bool = False, + default_value: str | None = None, + allow_reset: bool = False, + allow_skip: bool = False, + ): + super().__init__(allow_skip, allow_reset) + self._header = header + self._placeholder = placeholder or '' + self._password = password + self._default_value = default_value or '' + self._allow_reset = allow_reset + self._allow_skip = allow_skip + + async def run(self) -> Result[str]: + return await tui.show(self) + + @override + def compose(self) -> ComposeResult: + yield from self._compose_header() + + with Center(classes='dialog-wrapper'): + with Vertical(classes='input-dialog'): + with Vertical(classes='input-content'): + yield Static(self._header, classes='input-header') + yield Input( + placeholder=self._placeholder, + password=self._password, + value=self._default_value, + id='main_input', + ) + + def on_mount(self) -> None: + input_field = self.query_one('#main_input', Input) + input_field.focus() + + def on_key(self, event: Key) -> None: + if event.key == 'enter': + input_field = self.query_one('#main_input', Input) + value = input_field.value + self.dismiss(Result(ResultType.Selection, value)) # type: ignore[unused-awaitable] + + +class TableSelectionScreen(BaseScreen[ValueT]): + BINDINGS = [ # noqa: RUF012 + Binding('j', 'cursor_down', 'Down', show=True), + Binding('k', 'cursor_up', 'Up', show=True), + ] + + CSS = """ + TableSelectionScreen { + align: center middle; + background: transparent; + } + + DataTable { + height: auto; + width: auto; + border: none; + background: transparent; + } + + DataTable .datatable--header { + background: transparent; + border: solid; + } + + .content-container { + width: auto; + min-height: 10; + min-width: 40; + align: center middle; + background: transparent; + } + + .header { + text-align: center; + margin-bottom: 1; + } + + LoadingIndicator { + height: auto; + background: transparent; + } + """ + + def __init__( + self, + header: str | None = None, + data: list[ValueT] | None = None, + data_callback: Callable[[], Awaitable[list[ValueT]]] | None = None, + allow_reset: bool = False, + allow_skip: bool = False, + loading_header: str | None = None, + ): + super().__init__(allow_skip, allow_reset) + self._header = header + self._data = data + self._data_callback = data_callback + self._loading_header = loading_header + + if self._data is None and self._data_callback is None: + raise ValueError('Either data or data_callback must be provided') + + async def run(self) -> Result[ValueT]: + return await tui.show(self) + + def action_cursor_down(self) -> None: + table = self.query_one(DataTable) + if table.cursor_row is not None: + next_row = min(table.cursor_row + 1, len(table.rows) - 1) + table.move_cursor(row=next_row, column=table.cursor_column or 0) + + def action_cursor_up(self) -> None: + table = self.query_one(DataTable) + if table.cursor_row is not None: + prev_row = max(table.cursor_row - 1, 0) + table.move_cursor(row=prev_row, column=table.cursor_column or 0) + + @override + def compose(self) -> ComposeResult: + yield from self._compose_header() + + with Center(): + with Vertical(classes='content-container'): + if self._header: + yield Static(self._header, classes='header', id='header') + + if self._loading_header: + yield Static(self._loading_header, classes='header', id='loading-header') + + yield LoadingIndicator(id='loader') + yield DataTable(id='data_table') + + def on_mount(self) -> None: + self._display_header(True) + data_table = self.query_one(DataTable) + data_table.cell_padding = 2 + + if self._data: + self._put_data_to_table(data_table, self._data) + else: + self._load_data(data_table) + + @work + async def _load_data(self, table: DataTable[ValueT]) -> None: + assert self._data_callback is not None + data = await self._data_callback() + self._put_data_to_table(table, data) + + def _display_header(self, is_loading: bool) -> None: + try: + loading_header = self.query_one('#loading-header', Static) + header = self.query_one('#header', Static) + loading_header.display = is_loading + header.display = not is_loading + except Exception: + pass + + def _put_data_to_table(self, table: DataTable[ValueT], data: list[ValueT]) -> None: + if not data: + self.dismiss(Result(ResultType.Selection, None)) # type: ignore[unused-awaitable] + return + + cols = list(data[0].table_data().keys()) # type: ignore[attr-defined] + table.add_columns(*cols) + + for d in data: + row_values = list(d.table_data().values()) # type: ignore[attr-defined] + table.add_row(*row_values, key=d) # type: ignore[arg-type] + + table.cursor_type = 'row' + table.display = True + + loader = self.query_one('#loader') + loader.display = False + self._display_header(False) + table.focus() + + def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None: + data: ValueT = event.row_key.value # type: ignore[assignment] + self.dismiss(Result(ResultType.Selection, data)) # type: ignore[unused-awaitable] + + +class TApp(App[Any]): + CSS = """ + .app-header { + dock: top; + height: auto; + width: 100%; + content-align: center middle; + background: $primary; + color: white; + text-style: bold; + } + """ + + def __init__(self) -> None: + super().__init__(ansi_color=True) + self._main = None + self._global_header: str | None = None + + @property + def global_header(self) -> str | None: + return self._global_header + + @global_header.setter + def global_header(self, value: str | None) -> None: + self._global_header = value + + def set_main(self, main: Any) -> None: + self._main = main + + def on_mount(self) -> None: + self._run_worker() + + @work + async def _run_worker(self) -> None: + try: + if self._main is not None: + await self._main.run() # type: ignore[unreachable] + except Exception as err: + debug(f'Error while running main app: {err}') + raise err from err + + @work + async def _show_async(self, screen: Screen[Result[ValueT]]) -> Result[ValueT]: + return await self.push_screen_wait(screen) + + async def show(self, screen: Screen[Result[ValueT]]) -> Result[ValueT]: + return await self._show_async(screen).wait() + + +tui = TApp() diff --git a/archinstall/tui/ui/result.py b/archinstall/tui/ui/result.py new file mode 100644 index 00000000..c4e92468 --- /dev/null +++ b/archinstall/tui/ui/result.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import cast + + +class ResultType(Enum): + Selection = auto() + Skip = auto() + Reset = auto() + + +@dataclass +class Result[ValueT]: + type_: ResultType + _data: ValueT | list[ValueT] | None + + def has_data(self) -> bool: + return self._data is not None + + def value(self) -> ValueT: + assert type(self._data) is not list and self._data is not None + return cast(ValueT, self._data) + + def values(self) -> list[ValueT]: + assert type(self._data) is list + return cast(list[ValueT], self._data) diff --git a/pyproject.toml b/pyproject.toml index f0a37a93..ac1a44f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "pyparted>=3.13.0", "pydantic==2.12.3", "cryptography>=45.0.7", + "textual>=5.3.0", + "pytest-mock>=3.15.1", ] [project.urls]