From 007f2ff7973f498569328bb66b4a9c02b6da478e Mon Sep 17 00:00:00 2001 From: correctmost <134317971+correctmost@users.noreply.github.com> Date: Sat, 30 Nov 2024 06:56:46 -0500 Subject: [PATCH] Use TypedDict to annotate disk-related serializations (#2935) --- archinstall/lib/disk/device_model.py | 135 +++++++++++++++++++++------ 1 file changed, 107 insertions(+), 28 deletions(-) diff --git a/archinstall/lib/disk/device_model.py b/archinstall/lib/disk/device_model.py index 86ae2ca5..a42834fa 100644 --- a/archinstall/lib/disk/device_model.py +++ b/archinstall/lib/disk/device_model.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NotRequired, TypedDict import parted from parted import Disk, Geometry, Partition @@ -39,6 +39,13 @@ class DiskLayoutType(Enum): return str(_('Pre-mounted configuration')) +class _DiskLayoutConfigurationSerialization(TypedDict): + config_type: str + device_modifications: NotRequired[list[_DeviceModificationSerialization]] + lvm_config: NotRequired[_LvmConfigurationSerialization] + mountpoint: NotRequired[str] + + @dataclass class DiskLayoutConfiguration: config_type: DiskLayoutType @@ -48,14 +55,14 @@ class DiskLayoutConfiguration: # used for pre-mounted config mountpoint: Path | None = None - def json(self) -> dict[str, Any]: + def json(self) -> _DiskLayoutConfigurationSerialization: if self.config_type == DiskLayoutType.Pre_mount: return { 'config_type': self.config_type.value, 'mountpoint': str(self.mountpoint) } else: - config: dict[str, Any] = { + config: _DiskLayoutConfigurationSerialization = { 'config_type': self.config_type.value, 'device_modifications': [mod.json() for mod in self.device_modifications], } @@ -66,7 +73,7 @@ class DiskLayoutConfiguration: return config @classmethod - def parse_arg(cls, disk_config: dict[str, Any]) -> DiskLayoutConfiguration | None: + def parse_arg(cls, disk_config: _DiskLayoutConfigurationSerialization) -> DiskLayoutConfiguration | None: from .device_handler import device_handler device_modifications: list[DeviceModification] = [] @@ -96,7 +103,7 @@ class DiskLayoutConfiguration: return config for entry in disk_config.get('device_modifications', []): - device_path = Path(entry.get('device', None)) if entry.get('device', None) else None + device_path = Path(entry['device']) if entry.get('device', None) else None if not device_path: continue @@ -190,6 +197,11 @@ class Unit(Enum): return [u for u in Unit if 'i' in u.name or u.name == 'B'] +class _SectorSizeSerialization(TypedDict): + value: int + unit: str + + @dataclass class SectorSize: value: int @@ -204,14 +216,14 @@ class SectorSize: def default() -> SectorSize: return SectorSize(512, Unit.B) - def json(self) -> dict[str, Any]: + def json(self) -> _SectorSizeSerialization: return { 'value': self.value, 'unit': self.unit.name, } @classmethod - def parse_args(cls, arg: dict[str, Any]) -> SectorSize: + def parse_args(cls, arg: _SectorSizeSerialization) -> SectorSize: return SectorSize( arg['value'], Unit[arg['unit']] @@ -224,6 +236,12 @@ class SectorSize: return int(self.value * self.unit.value) +class _SizeSerialization(TypedDict): + value: int + unit: str + sector_size: _SectorSizeSerialization + + @dataclass class Size: value: int @@ -234,15 +252,15 @@ class Size: if not isinstance(self.sector_size, SectorSize): raise ValueError('sector size must be of type SectorSize') - def json(self) -> dict[str, Any]: + def json(self) -> _SizeSerialization: return { 'value': self.value, 'unit': self.unit.name, - 'sector_size': self.sector_size.json() if self.sector_size else None + 'sector_size': self.sector_size.json() } @classmethod - def parse_args(cls, size_arg: dict[str, Any]) -> Size: + def parse_args(cls, size_arg: _SizeSerialization) -> Size: sector_size = size_arg['sector_size'] return Size( @@ -516,9 +534,14 @@ class _DeviceInfo: ) +class _SubvolumeModificationSerialization(TypedDict): + name: str + mountpoint: str + + @dataclass class SubvolumeModification: - name: Path + name: Path | str mountpoint: Path | None = None @classmethod @@ -526,7 +549,7 @@ class SubvolumeModification: return SubvolumeModification(info.name, mountpoint=info.mountpoint) @classmethod - def parse_args(cls, subvol_args: list[dict[str, Any]]) -> list[SubvolumeModification]: + def parse_args(cls, subvol_args: list[_SubvolumeModificationSerialization]) -> list[SubvolumeModification]: mods = [] for entry in subvol_args: if not entry.get('name', None) or not entry.get('mountpoint', None): @@ -555,10 +578,10 @@ class SubvolumeModification: return self.mountpoint == Path('/') return False - def json(self) -> dict[str, Any]: + def json(self) -> _SubvolumeModificationSerialization: return {'name': str(self.name), 'mountpoint': str(self.mountpoint)} - def table_data(self) -> dict[str, Any]: + def table_data(self) -> _SubvolumeModificationSerialization: return self.json() @@ -738,6 +761,20 @@ class ModificationStatus(Enum): Create = 'create' +class _PartitionModificationSerialization(TypedDict): + obj_id: str + status: str + type: str + start: _SizeSerialization + size: _SizeSerialization + fs_type: str | None + mountpoint: str | None + mount_options: list[str] + flags: list[str] + btrfs: list[_SubvolumeModificationSerialization] + dev_path: str | None + + @dataclass class PartitionModification: status: ModificationStatus @@ -891,7 +928,7 @@ class PartitionModification: else: self.set_flag(flag) - def json(self) -> dict[str, Any]: + def json(self) -> _PartitionModificationSerialization: """ Called for configuration settings """ @@ -947,13 +984,19 @@ class LvmLayoutType(Enum): raise ValueError(f'Unknown type: {self}') +class _LvmVolumeGroupSerialization(TypedDict): + name: str + lvm_pvs: list[str] + volumes: list[_LvmVolumeSerialization] + + @dataclass class LvmVolumeGroup: name: str pvs: list[PartitionModification] volumes: list[LvmVolume] = field(default_factory=list) - def json(self) -> dict[str, Any]: + def json(self) -> _LvmVolumeGroupSerialization: return { 'name': self.name, 'lvm_pvs': [p.obj_id for p in self.pvs], @@ -961,7 +1004,7 @@ class LvmVolumeGroup: } @staticmethod - def parse_arg(arg: dict[str, Any], disk_config: DiskLayoutConfiguration) -> LvmVolumeGroup: + def parse_arg(arg: _LvmVolumeGroupSerialization, disk_config: DiskLayoutConfiguration) -> LvmVolumeGroup: lvm_pvs = [] for mod in disk_config.device_modifications: for part in mod.partitions: @@ -985,6 +1028,17 @@ class LvmVolumeStatus(Enum): Create = 'create' +class _LvmVolumeSerialization(TypedDict): + obj_id: str + status: str + name: str + fs_type: str + length: _SizeSerialization + mountpoint: str | None + mount_options: list[str] + btrfs: list[_SubvolumeModificationSerialization] + + @dataclass class LvmVolume: status: LvmVolumeStatus @@ -1051,7 +1105,7 @@ class LvmVolume: raise ValueError('Mountpoint is not specified') @staticmethod - def parse_arg(arg: dict[str, Any]) -> LvmVolume: + def parse_arg(arg: _LvmVolumeSerialization) -> LvmVolume: volume = LvmVolume( status=LvmVolumeStatus(arg['status']), name=arg['name'], @@ -1066,7 +1120,7 @@ class LvmVolume: return volume - def json(self) -> dict[str, Any]: + def json(self) -> _LvmVolumeSerialization: return { 'obj_id': self.obj_id, 'status': self.status.value, @@ -1130,6 +1184,11 @@ class LvmPVInfo: vg_name: str +class _LvmConfigurationSerialization(TypedDict): + config_type: str + vol_groups: list[_LvmVolumeGroupSerialization] + + @dataclass class LvmConfiguration: config_type: LvmLayoutType @@ -1144,18 +1203,19 @@ class LvmConfiguration: raise ValueError('A PV cannot be used in multiple volume groups') pvs.append(pv) - def json(self) -> dict[str, Any]: + def json(self) -> _LvmConfigurationSerialization: return { 'config_type': self.config_type.value, 'vol_groups': [vol_gr.json() for vol_gr in self.vol_groups] } @staticmethod - def parse_arg(arg: dict[str, Any], disk_config: DiskLayoutConfiguration) -> LvmConfiguration: + def parse_arg(arg: _LvmConfigurationSerialization, disk_config: DiskLayoutConfiguration) -> LvmConfiguration: lvm_pvs = [] for mod in disk_config.device_modifications: for part in mod.partitions: - if part.obj_id in arg.get('lvm_pvs', []): + # FIXME: 'lvm_pvs' does not seem like it can ever exist in the 'arg' serialization + if part.obj_id in arg.get('lvm_pvs', []): # type: ignore[operator] lvm_pvs.append(part) return LvmConfiguration( @@ -1196,6 +1256,12 @@ class LvmConfiguration: # if vg.contains_lv(lv): +class _DeviceModificationSerialization(TypedDict): + device: str + wipe: bool + partitions: list[_PartitionModificationSerialization] + + @dataclass class DeviceModification: device: BDevice @@ -1235,7 +1301,7 @@ class DeviceModification: filtered = filter(lambda x: x.is_root(), self.partitions) return next(filtered, None) - def json(self) -> dict[str, Any]: + def json(self) -> _DeviceModificationSerialization: """ Called when generating configuration files """ @@ -1273,6 +1339,13 @@ class EncryptionType(Enum): return type_to_text[type_] +class _DiskEncryptionSerialization(TypedDict): + encryption_type: str + partitions: list[str] + lvm_volumes: list[str] + hsm_device: NotRequired[_Fido2DeviceSerialization] + + @dataclass class DiskEncryption: encryption_type: EncryptionType = EncryptionType.NoEncryption @@ -1295,8 +1368,8 @@ class DiskEncryption: return dev in self.lvm_volumes and dev.mountpoint != Path('/') return False - def json(self) -> dict[str, Any]: - obj: dict[str, Any] = { + def json(self) -> _DiskEncryptionSerialization: + obj: _DiskEncryptionSerialization = { 'encryption_type': self.encryption_type.value, 'partitions': [p.obj_id for p in self.partitions], 'lvm_volumes': [vol.obj_id for vol in self.lvm_volumes] @@ -1325,7 +1398,7 @@ class DiskEncryption: def parse_arg( cls, disk_config: DiskLayoutConfiguration, - disk_encryption: dict[str, Any], + disk_encryption: _DiskEncryptionSerialization, password: str = '' ) -> 'DiskEncryption | None': if not cls.validate_enc(disk_config): @@ -1359,13 +1432,19 @@ class DiskEncryption: return enc +class _Fido2DeviceSerialization(TypedDict): + path: str + manufacturer: str + product: str + + @dataclass class Fido2Device: path: Path manufacturer: str product: str - def json(self) -> dict[str, str]: + def json(self) -> _Fido2DeviceSerialization: return { 'path': str(self.path), 'manufacturer': self.manufacturer, @@ -1380,7 +1459,7 @@ class Fido2Device: } @classmethod - def parse_arg(cls, arg: dict[str, str]) -> 'Fido2Device': + def parse_arg(cls, arg: _Fido2DeviceSerialization) -> 'Fido2Device': return Fido2Device( Path(arg['path']), arg['manufacturer'],