Wifi connection menu with textual (#3879)

* Wifi connector

* Update

---------

Co-authored-by: Daniel Girtler <dgirtler@atlassian.com>
This commit is contained in:
Daniel Girtler 2025-11-01 23:55:58 +11:00 committed by GitHub
parent 7af94c8fe5
commit 76ab9482e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1098 additions and 8 deletions

View File

@ -41,7 +41,9 @@ repos:
additional_dependencies:
- pydantic
- pytest
- pytest-mock
- cryptography
- textual
- repo: local
hooks:
- id: pylint

View File

@ -28,6 +28,7 @@ depends=(
'python-cryptography'
'python-pydantic'
'python-pyparted'
'python-textual'
'systemd'
'util-linux'
'xfsprogs'

View File

@ -8,7 +8,10 @@ import traceback
from archinstall.lib.args import arch_config_handler
from archinstall.lib.disk.utils import disk_layouts
from archinstall.lib.network.wifi_handler import wifi_handler
from archinstall.lib.networking import ping
from archinstall.lib.packages.packages import check_package_upgrade
from archinstall.tui.ui.components import tui as ttui
from .lib.hardware import SysInfo
from .lib.output import FormattedOutput, debug, error, info, log, warn
@ -36,6 +39,17 @@ def _log_sys_info() -> None:
debug(f'Disk states before installing:\n{disk_layouts()}')
def _check_online() -> None:
try:
ping('1.1.1.1')
except OSError as ex:
if 'Network is unreachable' in str(ex):
if not arch_config_handler.args.skip_wifi_check:
success = not wifi_handler.setup()
if not success:
exit(0)
def _fetch_arch_db() -> None:
info('Fetching Arch Linux package database...')
try:
@ -44,13 +58,14 @@ def _fetch_arch_db() -> None:
error('Failed to sync Arch Linux package database.')
if 'could not resolve host' in str(e).lower():
error('Most likely due to a missing network connection or DNS issue.')
error('Run archinstall --debug and check /var/log/archinstall/install.log for details.')
debug(f'Failed to sync Arch Linux package database: {e}')
exit(1)
def _check_new_version() -> None:
def check_version_upgrade() -> str | None:
info('Checking version...')
upgrade = None
@ -62,7 +77,7 @@ def _check_new_version() -> None:
text = tr('New version available') + f': {upgrade}'
info(text)
time.sleep(3)
return text
def main() -> int:
@ -81,11 +96,19 @@ def main() -> int:
_log_sys_info()
ttui.global_header = 'Archinstall'
if not arch_config_handler.args.offline:
_check_online()
_fetch_arch_db()
if not arch_config_handler.args.skip_version_check:
_check_new_version()
new_version = check_version_upgrade()
if new_version:
ttui.global_header = f'{ttui.global_header} {new_version}'
info(new_version)
time.sleep(3)
script = arch_config_handler.get_script()

View File

@ -49,6 +49,7 @@ class Arguments:
no_pkg_lookups: bool = False
plugin: str | None = None
skip_version_check: bool = False
skip_wifi_check: bool = False
advanced: bool = False
verbose: bool = False
@ -410,6 +411,12 @@ class ArchConfigHandler:
default=False,
help='Skip the version check when running archinstall',
)
parser.add_argument(
'--skip-wifi-check',
action='store_true',
default=False,
help='Skip wifi check when running archinstall',
)
parser.add_argument(
'--advanced',
action='store_true',

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, NotRequired, TypedDict
from typing import TYPE_CHECKING, NotRequired, TypedDict, override
from archinstall.lib.output import debug
from archinstall.lib.translationhandler import tr
from ..models.profile import ProfileConfiguration
@ -157,3 +159,101 @@ class NetworkConfiguration:
installation.enable_service('systemd-networkd')
installation.enable_service('systemd-resolved')
@dataclass
class WifiNetwork:
bssid: str
frequency: str
signal_level: str
flags: str
ssid: str
@override
def __hash__(self) -> int:
return hash((self.bssid, self.frequency, self.signal_level, self.flags, self.ssid))
def table_data(self) -> dict[str, str | int]:
"""Format WiFi data for table display"""
return {
'SSID': self.ssid,
'Signal': f'{self.signal_level} dBm',
'Frequency': f'{self.frequency} MHz',
'Security': self.flags,
'BSSID': self.bssid,
}
@staticmethod
def from_wpa(results: str) -> list[WifiNetwork]:
entries: list[WifiNetwork] = []
for line in results.splitlines():
line = line.strip()
if not line:
continue
parts = line.split()
if len(parts) != 5:
continue
wifi = WifiNetwork(bssid=parts[0], frequency=parts[1], signal_level=parts[2], flags=parts[3], ssid=parts[4])
entries.append(wifi)
return entries
@dataclass
class WifiConfiguredNetwork:
network_id: int
ssid: str
bssid: str
flags: list[str]
@classmethod
def from_wpa_cli_output(cls, list_networks: str) -> list[WifiConfiguredNetwork]:
"""
Example output from 'wpa_cli list_networks'
Selected interface 'wlan0'
network id / ssid / bssid / flags
0 WifiGuest any [CURRENT]
1 any [DISABLED]
2 any [DISABLED]
"""
lines = list_networks.strip().splitlines()
lines = lines[1:] # remove the header row from the wpa_cli output
networks = []
for line in lines:
line = line.strip()
parts = line.split('\t')
if len(parts) < 3:
continue
try:
# flags = cls._extract_flags(parts[3])
flags: list[str] = []
networks.append(
WifiConfiguredNetwork(
network_id=int(parts[0]),
ssid=parts[1],
bssid=parts[2],
flags=flags,
)
)
except (ValueError, IndexError):
debug('Parsing error for network output')
return networks
@classmethod
def _extract_flags(cls, flag_string: str) -> list[str]:
pattern = r'\[([^\]]+)\]'
extracted_values = re.findall(pattern, flag_string)
return extracted_values

View File

View File

@ -0,0 +1,280 @@
from asyncio import sleep
from dataclasses import dataclass
from pathlib import Path
from typing import Any, assert_never
from archinstall.lib.exceptions import SysCallError
from archinstall.lib.general import SysCommand
from archinstall.lib.models.network import WifiConfiguredNetwork, WifiNetwork
from archinstall.lib.network.wpa_supplicant import WpaSupplicantConfig
from archinstall.lib.output import debug
from archinstall.lib.translationhandler import tr
from archinstall.tui.menu_item import MenuItemGroup
from archinstall.tui.ui.components import ConfirmationScreen, InputScreen, LoadingScreen, NotifyScreen, TableSelectionScreen, tui
from archinstall.tui.ui.result import ResultType
@dataclass
class WpaCliResult:
success: bool
response: str | None = None
error: str | None = None
class WifiHandler:
def __init__(self) -> None:
tui.set_main(self)
self._wpa_config = WpaSupplicantConfig()
def setup(self) -> Any:
result = tui.run()
return result
async def run(self) -> None:
"""
This is the entry point that is called by components.TApp
"""
wifi_iface = self._find_wifi_interface()
if not wifi_iface:
debug('No wifi interface found')
tui.exit(False)
return None
prompt = tr('No network connection found') + '\n\n'
prompt += tr('Would you like to connect to a Wifi?') + '\n'
result = await ConfirmationScreen[bool](
MenuItemGroup.yes_no(),
header=prompt,
allow_skip=True,
allow_reset=True,
).run()
match result.type_:
case ResultType.Selection:
if result.value() is False:
tui.exit(False)
return None
case ResultType.Skip | ResultType.Reset:
tui.exit(False)
return None
case _:
assert_never(result)
setup_result = await self._setup_wifi(wifi_iface)
tui.exit(setup_result)
async def _enable_supplicant(self, wifi_iface: str) -> bool:
self._wpa_config.load_config()
result = self._wpa_cli('status') # if it it's running it will blow up
if result.success:
debug('wpa_supplicant already running')
return True
if result.error and 'failed to connect to non-global ctrl_ifname'.lower() not in result.error.lower():
debug('Unexpected wpa_cli failure')
return False
debug('wpa_supplicant not running, trying to enable')
try:
SysCommand(f'wpa_supplicant -B -i {wifi_iface} -c {self._wpa_config.config_file}')
result = self._wpa_cli('status') # if it it's running it will blow up
if result.success:
debug('successfully enabled wpa_supplicant')
return True
else:
debug(f'failed to enable wpa_supplicant: {result.error}')
return False
except SysCallError as err:
debug(f'failed to enable wpa_supplicant: {err}')
return False
def _find_wifi_interface(self) -> str | None:
net_path = Path('/sys/class/net')
for iface in net_path.iterdir():
maybe_wireless_path = net_path / iface / 'wireless'
if maybe_wireless_path.is_dir():
return iface.name
return None
async def _setup_wifi(self, wifi_iface: str) -> bool:
debug('Setting up wifi')
if not await self._enable_supplicant(wifi_iface):
debug('Failed to enable wpa_supplicant')
return False
if not wifi_iface:
debug('No wifi interface found')
await NotifyScreen(header=tr('No wifi interface found')).run()
return False
debug(f'Found wifi interface: {wifi_iface}')
async def get_wifi_networks() -> list[WifiNetwork]:
debug('Scanning Wifi networks')
result = self._wpa_cli('scan', wifi_iface)
if not result.success:
debug(f'Failed to scan wifi networks: {result.error}')
return []
await sleep(5)
return self._get_scan_results(wifi_iface)
result = await TableSelectionScreen[WifiNetwork](
header=tr('Select wifi network to connect to'),
loading_header=tr('Scanning wifi networks...'),
data_callback=get_wifi_networks,
allow_skip=True,
allow_reset=True,
).run()
match result.type_:
case ResultType.Selection:
if not result.has_data():
debug('No networks found')
await NotifyScreen(header=tr('No wifi networks found')).run()
tui.exit(False)
return False
network = result.value()
case ResultType.Skip | ResultType.Reset:
tui.exit(False)
return False
case _:
assert_never(result.type_)
existing_network = self._wpa_config.get_existing_network(network.ssid)
existing_psk = existing_network.psk if existing_network else None
psk = await self._prompt_psk(existing_psk)
if not psk:
debug('No password specified')
return False
self._wpa_config.set_network(network, psk)
self._wpa_config.write_config()
wpa_result = self._wpa_cli('reconfigure')
if not wpa_result.success:
debug(f'Failed to reconfigure wpa_supplicant: {wpa_result.error}')
await self._notify_failure()
return False
await LoadingScreen(3, 'Setting up wifi...').run()
network_id = self._find_network_id(network.ssid, wifi_iface)
if not network_id:
debug('Failed to find network id')
await self._notify_failure()
return False
wpa_result = self._wpa_cli(f'enable {network_id}', wifi_iface)
if not wpa_result.success:
debug(f'Failed to enable network: {wpa_result.error}')
await self._notify_failure()
return False
await LoadingScreen(5, 'Connecting wifi...').run()
return True
async def _notify_failure(self) -> None:
await NotifyScreen(header=tr('Failed setting up wifi')).run()
def _wpa_cli(self, command: str, iface: str | None = None) -> WpaCliResult:
cmd = 'wpa_cli'
if iface:
cmd += f' -i {iface}'
cmd += f' {command}'
try:
result = SysCommand(cmd).decode()
if 'FAIL' in result:
debug(f'wpa_cli returned FAIL: {result}')
return WpaCliResult(
success=False,
error=f'wpa_cli returned a failure: {result}',
)
return WpaCliResult(success=True, response=result)
except SysCallError as err:
debug(f'error running wpa_cli command: {err}')
return WpaCliResult(
success=False,
error=f'Error running wpa_cli command: {err}',
)
def _find_network_id(self, ssid: str, iface: str) -> int | None:
result = self._wpa_cli('list_networks', iface)
if not result.success:
debug(f'Failed to list networks: {result.error}')
return None
list_networks = result.response
if not list_networks:
debug('No networks found')
return None
existing_networks = WifiConfiguredNetwork.from_wpa_cli_output(list_networks)
for network in existing_networks:
if network.ssid == ssid:
return network.network_id
return None
async def _prompt_psk(self, existing: str | None = None) -> str | None:
result = await InputScreen(
header=tr('Enter wifi password'),
password=True,
allow_skip=True,
allow_reset=True,
default_value=existing,
).run()
if result.type_ != ResultType.Selection:
debug('No password provided, aborting connection')
return None
return result.value()
def _get_scan_results(self, iface: str) -> list[WifiNetwork]:
debug(f'Retrieving scan results: {iface}')
try:
result = self._wpa_cli('scan_results', iface)
if not result.success:
debug(f'Failed to retrieve scan results: {result.error}')
return []
if not result.response:
debug('No wifi networks found')
return []
networks = WifiNetwork.from_wpa(result.response)
return networks
except SysCallError as err:
debug('Unable to retrieve wifi results')
raise err
wifi_handler = WifiHandler()

View File

@ -0,0 +1,136 @@
from dataclasses import dataclass, field
from pathlib import Path
from archinstall.lib.models.network import WifiNetwork
from archinstall.lib.output import debug
@dataclass
class WpaSupplicantNetwork:
mappings: dict[str, str] = field(default_factory=dict)
@property
def psk(self) -> str:
return self.mappings['psk'].strip('"')
@property
def ssid(self) -> str:
return self.mappings['ssid'].strip('"')
def to_config_entry(self) -> str:
wpa_net_config = '\n'
wpa_net_config += 'network={\n'
for key, value in self.mappings.items():
wpa_net_config += f'\t{key}={value}\n'
if 'mesh_fwding' not in self.mappings:
wpa_net_config += '\tmesh_fwding=1\n'
wpa_net_config += '}\n\n'
return wpa_net_config
class WpaSupplicantConfig:
def __init__(self) -> None:
self.config_file = Path('/etc/wpa_supplicant/wpa_supplicant.conf')
self._wpa_networks: list[WpaSupplicantNetwork] = []
def load_config(self) -> None:
if not self.config_file.is_file():
debug('wpa_supplicant.conf not found, creating')
self._create_config()
else:
debug('wpa_supplicant.conf found')
content = self.config_file.read_text()
config_header = ''
if 'ctrl_interface' not in content:
config_header += 'ctrl_interface=/run/wpa_supplicant\n'
if 'update_config' not in content:
config_header += 'update_config=1\n\n'
if config_header:
config = config_header + content
self.config_file.write_text(config)
self._wpa_networks = self._parse_config()
def _config_header(self) -> str:
return 'ctrl_interface=/run/wpa_supplicant\nupdate_config=1'
def get_existing_network(self, ssid: str) -> WpaSupplicantNetwork | None:
ssid = f'"{ssid}"'
for network in self._wpa_networks:
if network.mappings['ssid'] == ssid:
return network
return None
def set_network(self, network: WifiNetwork, psk: str) -> None:
debug('setting new wifi network')
existing_network = self.get_existing_network(network.ssid)
if not existing_network:
wpa_net_config = WpaSupplicantNetwork(
mappings={
'ssid': f'"{network.ssid}"',
'psk': f'"{psk}"',
}
)
self._wpa_networks.append(wpa_net_config)
else:
existing_network.mappings['psk'] = f'"{psk}"'
def write_config(self) -> None:
debug('writing wpa_supplicant config')
config = self._config_header()
config += '\n\n'
for network in self._wpa_networks:
config += network.to_config_entry()
self.config_file.write_text(config)
def _create_config(self) -> None:
self.config_file.touch()
header = self._config_header()
self.config_file.write_text(header)
def _parse_config(self) -> list[WpaSupplicantNetwork]:
content = self.config_file.read_text()
networks: list[WpaSupplicantNetwork] = []
in_network_block = False
cur_net_data: dict[str, str] = {}
for line in content.splitlines():
line = line.strip()
if not line or line.startswith('#'):
continue
if line == 'network={':
in_network_block = True
cur_net_data = {}
continue
if in_network_block and line == '}':
new_network = WpaSupplicantNetwork(
mappings=cur_net_data,
)
networks.append(new_network)
in_network_block = False
continue
if in_network_block:
if '=' in line:
key, value = line.split('=', 1)
cur_net_data[key.strip()] = value.strip()
return networks

View File

@ -33,16 +33,16 @@ class MenuItem:
return self.value
@classmethod
def yes(cls) -> 'MenuItem':
def yes(cls, action: Callable[[Any], Any] | None = None) -> 'MenuItem':
if cls._yes is None:
cls._yes = cls(tr('Yes'), value=True)
cls._yes = cls(tr('Yes'), value=True, key='yes', action=action)
return cls._yes
@classmethod
def no(cls) -> 'MenuItem':
def no(cls, action: Callable[[Any], Any] | None = None) -> 'MenuItem':
if cls._no is None:
cls._no = cls(tr('No'), value=True)
cls._no = cls(tr('No'), value=False, key='no', action=action)
return cls._no
@ -223,6 +223,10 @@ class MenuItemGroup:
return tr(' (default)')
return ''
def set_action_for_all(self, action: Callable[[Any], Any]) -> None:
for item in self.items:
item.action = action
@cached_property
def items(self) -> list[MenuItem]:
pattern = self._filter_pattern.lower()

View File

View File

@ -0,0 +1,509 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, override
from textual import work
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Center, Horizontal, Vertical
from textual.events import Key
from textual.screen import Screen
from textual.widgets import Button, DataTable, Input, LoadingIndicator, Static
from archinstall.lib.output import debug
from archinstall.lib.translationhandler import tr
from archinstall.tui.menu_item import MenuItem, MenuItemGroup
from archinstall.tui.ui.result import Result, ResultType
ValueT = TypeVar('ValueT')
class BaseScreen(Screen[Result[ValueT]]):
BINDINGS = [ # noqa: RUF012
Binding('escape', 'cancel_operation', 'Cancel', show=True),
Binding('ctrl+c', 'reset_operation', 'Reset', show=True),
]
def __init__(self, allow_skip: bool = False, allow_reset: bool = False):
super().__init__()
self._allow_skip = allow_skip
self._allow_reset = allow_reset
def action_cancel_operation(self) -> None:
if self._allow_skip:
self.dismiss(Result(ResultType.Skip, None)) # type: ignore[unused-awaitable]
def action_reset_operation(self) -> None:
if self._allow_reset:
self.dismiss(Result(ResultType.Reset, None)) # type: ignore[unused-awaitable]
def _compose_header(self) -> ComposeResult:
"""Compose the app header if global header text is available."""
if tui.global_header:
yield Static(tui.global_header, classes='app-header')
class LoadingScreen(BaseScreen[None]):
CSS = """
LoadingScreen {
align: center middle;
}
.dialog {
align: center middle;
width: 100%;
border: none;
background: transparent;
}
.header {
text-align: center;
margin-bottom: 1;
}
LoadingIndicator {
align: center middle;
}
"""
def __init__(
self,
timer: int,
header: str | None = None,
):
super().__init__()
self._timer = timer
self._header = header
async def run(self) -> Result[None]:
return await tui.show(self)
@override
def compose(self) -> ComposeResult:
yield from self._compose_header()
with Center():
with Vertical(classes='dialog'):
if self._header:
yield Static(self._header, classes='header')
yield Center(LoadingIndicator()) # ensures indicator is centered too
def on_mount(self) -> None:
self.set_timer(self._timer, self.action_pop_screen)
def action_pop_screen(self) -> None:
self.dismiss() # type: ignore[unused-awaitable]
class ConfirmationScreen(BaseScreen[ValueT]):
BINDINGS = [ # noqa: RUF012
Binding('l', 'focus_right', 'Focus right', show=True),
Binding('h', 'focus_left', 'Focus left', show=True),
Binding('right', 'focus_right', 'Focus right', show=True),
Binding('left', 'focus_left', 'Focus left', show=True),
]
CSS = """
ConfirmationScreen {
align: center middle;
}
.dialog-wrapper {
align: center middle;
height: 100%;
width: 100%;
}
.dialog {
width: 80;
height: 10;
border: none;
background: transparent;
}
.dialog-content {
padding: 1;
height: 100%;
}
.message {
text-align: center;
margin-bottom: 1;
}
.buttons {
align: center middle;
background: transparent;
}
Button {
width: 4;
height: 3;
background: transparent;
margin: 0 1;
}
Button.-active {
background: #1793D1;
color: white;
border: none;
text-style: none;
}
"""
def __init__(
self,
group: MenuItemGroup,
header: str,
allow_skip: bool = False,
allow_reset: bool = False,
):
super().__init__(allow_skip, allow_reset)
self._group = group
self._header = header
async def run(self) -> Result[ValueT]:
return await tui.show(self)
@override
def compose(self) -> ComposeResult:
yield from self._compose_header()
with Center(classes='dialog-wrapper'):
with Vertical(classes='dialog'):
with Vertical(classes='dialog-content'):
yield Static(self._header, classes='message')
with Horizontal(classes='buttons'):
for item in self._group.items:
yield Button(item.text, id=item.key)
def on_mount(self) -> None:
self.update_selection()
def update_selection(self) -> None:
focused = self._group.focus_item
buttons = self.query(Button)
if not focused:
return
for button in buttons:
if button.id == focused.key:
button.add_class('-active')
button.focus()
else:
button.remove_class('-active')
def action_focus_right(self) -> None:
self._group.focus_next()
self.update_selection()
def action_focus_left(self) -> None:
self._group.focus_prev()
self.update_selection()
def on_key(self, event: Key) -> None:
if event.key == 'enter':
item = self._group.focus_item
if not item:
return None
self.dismiss(Result(ResultType.Selection, item.value)) # type: ignore[unused-awaitable]
class NotifyScreen(ConfirmationScreen[ValueT]):
def __init__(self, header: str):
group = MenuItemGroup([MenuItem(tr('Ok'))])
super().__init__(group, header)
class InputScreen(BaseScreen[str]):
CSS = """
InputScreen {
}
.dialog-wrapper {
align: center middle;
height: 100%;
width: 100%;
}
.input-dialog {
width: 60;
height: 10;
border: none;
background: transparent;
}
.input-content {
padding: 1;
height: 100%;
}
.input-header {
text-align: center;
margin: 0 0;
color: white;
text-style: bold;
background: transparent;
}
.input-prompt {
text-align: center;
margin: 0 0 1 0;
background: transparent;
}
Input {
margin: 1 2;
border: solid $accent;
background: transparent;
height: 3;
}
Input .input--cursor {
color: white;
}
Input:focus {
border: solid $primary;
}
"""
def __init__(
self,
header: str,
placeholder: str | None = None,
password: bool = False,
default_value: str | None = None,
allow_reset: bool = False,
allow_skip: bool = False,
):
super().__init__(allow_skip, allow_reset)
self._header = header
self._placeholder = placeholder or ''
self._password = password
self._default_value = default_value or ''
self._allow_reset = allow_reset
self._allow_skip = allow_skip
async def run(self) -> Result[str]:
return await tui.show(self)
@override
def compose(self) -> ComposeResult:
yield from self._compose_header()
with Center(classes='dialog-wrapper'):
with Vertical(classes='input-dialog'):
with Vertical(classes='input-content'):
yield Static(self._header, classes='input-header')
yield Input(
placeholder=self._placeholder,
password=self._password,
value=self._default_value,
id='main_input',
)
def on_mount(self) -> None:
input_field = self.query_one('#main_input', Input)
input_field.focus()
def on_key(self, event: Key) -> None:
if event.key == 'enter':
input_field = self.query_one('#main_input', Input)
value = input_field.value
self.dismiss(Result(ResultType.Selection, value)) # type: ignore[unused-awaitable]
class TableSelectionScreen(BaseScreen[ValueT]):
BINDINGS = [ # noqa: RUF012
Binding('j', 'cursor_down', 'Down', show=True),
Binding('k', 'cursor_up', 'Up', show=True),
]
CSS = """
TableSelectionScreen {
align: center middle;
background: transparent;
}
DataTable {
height: auto;
width: auto;
border: none;
background: transparent;
}
DataTable .datatable--header {
background: transparent;
border: solid;
}
.content-container {
width: auto;
min-height: 10;
min-width: 40;
align: center middle;
background: transparent;
}
.header {
text-align: center;
margin-bottom: 1;
}
LoadingIndicator {
height: auto;
background: transparent;
}
"""
def __init__(
self,
header: str | None = None,
data: list[ValueT] | None = None,
data_callback: Callable[[], Awaitable[list[ValueT]]] | None = None,
allow_reset: bool = False,
allow_skip: bool = False,
loading_header: str | None = None,
):
super().__init__(allow_skip, allow_reset)
self._header = header
self._data = data
self._data_callback = data_callback
self._loading_header = loading_header
if self._data is None and self._data_callback is None:
raise ValueError('Either data or data_callback must be provided')
async def run(self) -> Result[ValueT]:
return await tui.show(self)
def action_cursor_down(self) -> None:
table = self.query_one(DataTable)
if table.cursor_row is not None:
next_row = min(table.cursor_row + 1, len(table.rows) - 1)
table.move_cursor(row=next_row, column=table.cursor_column or 0)
def action_cursor_up(self) -> None:
table = self.query_one(DataTable)
if table.cursor_row is not None:
prev_row = max(table.cursor_row - 1, 0)
table.move_cursor(row=prev_row, column=table.cursor_column or 0)
@override
def compose(self) -> ComposeResult:
yield from self._compose_header()
with Center():
with Vertical(classes='content-container'):
if self._header:
yield Static(self._header, classes='header', id='header')
if self._loading_header:
yield Static(self._loading_header, classes='header', id='loading-header')
yield LoadingIndicator(id='loader')
yield DataTable(id='data_table')
def on_mount(self) -> None:
self._display_header(True)
data_table = self.query_one(DataTable)
data_table.cell_padding = 2
if self._data:
self._put_data_to_table(data_table, self._data)
else:
self._load_data(data_table)
@work
async def _load_data(self, table: DataTable[ValueT]) -> None:
assert self._data_callback is not None
data = await self._data_callback()
self._put_data_to_table(table, data)
def _display_header(self, is_loading: bool) -> None:
try:
loading_header = self.query_one('#loading-header', Static)
header = self.query_one('#header', Static)
loading_header.display = is_loading
header.display = not is_loading
except Exception:
pass
def _put_data_to_table(self, table: DataTable[ValueT], data: list[ValueT]) -> None:
if not data:
self.dismiss(Result(ResultType.Selection, None)) # type: ignore[unused-awaitable]
return
cols = list(data[0].table_data().keys()) # type: ignore[attr-defined]
table.add_columns(*cols)
for d in data:
row_values = list(d.table_data().values()) # type: ignore[attr-defined]
table.add_row(*row_values, key=d) # type: ignore[arg-type]
table.cursor_type = 'row'
table.display = True
loader = self.query_one('#loader')
loader.display = False
self._display_header(False)
table.focus()
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
data: ValueT = event.row_key.value # type: ignore[assignment]
self.dismiss(Result(ResultType.Selection, data)) # type: ignore[unused-awaitable]
class TApp(App[Any]):
CSS = """
.app-header {
dock: top;
height: auto;
width: 100%;
content-align: center middle;
background: $primary;
color: white;
text-style: bold;
}
"""
def __init__(self) -> None:
super().__init__(ansi_color=True)
self._main = None
self._global_header: str | None = None
@property
def global_header(self) -> str | None:
return self._global_header
@global_header.setter
def global_header(self, value: str | None) -> None:
self._global_header = value
def set_main(self, main: Any) -> None:
self._main = main
def on_mount(self) -> None:
self._run_worker()
@work
async def _run_worker(self) -> None:
try:
if self._main is not None:
await self._main.run() # type: ignore[unreachable]
except Exception as err:
debug(f'Error while running main app: {err}')
raise err from err
@work
async def _show_async(self, screen: Screen[Result[ValueT]]) -> Result[ValueT]:
return await self.push_screen_wait(screen)
async def show(self, screen: Screen[Result[ValueT]]) -> Result[ValueT]:
return await self._show_async(screen).wait()
tui = TApp()

View File

@ -0,0 +1,26 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import cast
class ResultType(Enum):
Selection = auto()
Skip = auto()
Reset = auto()
@dataclass
class Result[ValueT]:
type_: ResultType
_data: ValueT | list[ValueT] | None
def has_data(self) -> bool:
return self._data is not None
def value(self) -> ValueT:
assert type(self._data) is not list and self._data is not None
return cast(ValueT, self._data)
def values(self) -> list[ValueT]:
assert type(self._data) is list
return cast(list[ValueT], self._data)

View File

@ -21,6 +21,8 @@ dependencies = [
"pyparted>=3.13.0",
"pydantic==2.12.3",
"cryptography>=45.0.7",
"textual>=5.3.0",
"pytest-mock>=3.15.1",
]
[project.urls]