Use TypedDict to annotate disk-related serializations (#2935)

This commit is contained in:
correctmost 2024-11-30 06:56:46 -05:00 committed by GitHub
parent 11f8490b59
commit 007f2ff797
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 107 additions and 28 deletions

View File

@ -5,7 +5,7 @@ import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
import parted import parted
from parted import Disk, Geometry, Partition from parted import Disk, Geometry, Partition
@ -39,6 +39,13 @@ class DiskLayoutType(Enum):
return str(_('Pre-mounted configuration')) return str(_('Pre-mounted configuration'))
class _DiskLayoutConfigurationSerialization(TypedDict):
config_type: str
device_modifications: NotRequired[list[_DeviceModificationSerialization]]
lvm_config: NotRequired[_LvmConfigurationSerialization]
mountpoint: NotRequired[str]
@dataclass @dataclass
class DiskLayoutConfiguration: class DiskLayoutConfiguration:
config_type: DiskLayoutType config_type: DiskLayoutType
@ -48,14 +55,14 @@ class DiskLayoutConfiguration:
# used for pre-mounted config # used for pre-mounted config
mountpoint: Path | None = None mountpoint: Path | None = None
def json(self) -> dict[str, Any]: def json(self) -> _DiskLayoutConfigurationSerialization:
if self.config_type == DiskLayoutType.Pre_mount: if self.config_type == DiskLayoutType.Pre_mount:
return { return {
'config_type': self.config_type.value, 'config_type': self.config_type.value,
'mountpoint': str(self.mountpoint) 'mountpoint': str(self.mountpoint)
} }
else: else:
config: dict[str, Any] = { config: _DiskLayoutConfigurationSerialization = {
'config_type': self.config_type.value, 'config_type': self.config_type.value,
'device_modifications': [mod.json() for mod in self.device_modifications], 'device_modifications': [mod.json() for mod in self.device_modifications],
} }
@ -66,7 +73,7 @@ class DiskLayoutConfiguration:
return config return config
@classmethod @classmethod
def parse_arg(cls, disk_config: dict[str, Any]) -> DiskLayoutConfiguration | None: def parse_arg(cls, disk_config: _DiskLayoutConfigurationSerialization) -> DiskLayoutConfiguration | None:
from .device_handler import device_handler from .device_handler import device_handler
device_modifications: list[DeviceModification] = [] device_modifications: list[DeviceModification] = []
@ -96,7 +103,7 @@ class DiskLayoutConfiguration:
return config return config
for entry in disk_config.get('device_modifications', []): for entry in disk_config.get('device_modifications', []):
device_path = Path(entry.get('device', None)) if entry.get('device', None) else None device_path = Path(entry['device']) if entry.get('device', None) else None
if not device_path: if not device_path:
continue continue
@ -190,6 +197,11 @@ class Unit(Enum):
return [u for u in Unit if 'i' in u.name or u.name == 'B'] return [u for u in Unit if 'i' in u.name or u.name == 'B']
class _SectorSizeSerialization(TypedDict):
value: int
unit: str
@dataclass @dataclass
class SectorSize: class SectorSize:
value: int value: int
@ -204,14 +216,14 @@ class SectorSize:
def default() -> SectorSize: def default() -> SectorSize:
return SectorSize(512, Unit.B) return SectorSize(512, Unit.B)
def json(self) -> dict[str, Any]: def json(self) -> _SectorSizeSerialization:
return { return {
'value': self.value, 'value': self.value,
'unit': self.unit.name, 'unit': self.unit.name,
} }
@classmethod @classmethod
def parse_args(cls, arg: dict[str, Any]) -> SectorSize: def parse_args(cls, arg: _SectorSizeSerialization) -> SectorSize:
return SectorSize( return SectorSize(
arg['value'], arg['value'],
Unit[arg['unit']] Unit[arg['unit']]
@ -224,6 +236,12 @@ class SectorSize:
return int(self.value * self.unit.value) return int(self.value * self.unit.value)
class _SizeSerialization(TypedDict):
value: int
unit: str
sector_size: _SectorSizeSerialization
@dataclass @dataclass
class Size: class Size:
value: int value: int
@ -234,15 +252,15 @@ class Size:
if not isinstance(self.sector_size, SectorSize): if not isinstance(self.sector_size, SectorSize):
raise ValueError('sector size must be of type SectorSize') raise ValueError('sector size must be of type SectorSize')
def json(self) -> dict[str, Any]: def json(self) -> _SizeSerialization:
return { return {
'value': self.value, 'value': self.value,
'unit': self.unit.name, 'unit': self.unit.name,
'sector_size': self.sector_size.json() if self.sector_size else None 'sector_size': self.sector_size.json()
} }
@classmethod @classmethod
def parse_args(cls, size_arg: dict[str, Any]) -> Size: def parse_args(cls, size_arg: _SizeSerialization) -> Size:
sector_size = size_arg['sector_size'] sector_size = size_arg['sector_size']
return Size( return Size(
@ -516,9 +534,14 @@ class _DeviceInfo:
) )
class _SubvolumeModificationSerialization(TypedDict):
name: str
mountpoint: str
@dataclass @dataclass
class SubvolumeModification: class SubvolumeModification:
name: Path name: Path | str
mountpoint: Path | None = None mountpoint: Path | None = None
@classmethod @classmethod
@ -526,7 +549,7 @@ class SubvolumeModification:
return SubvolumeModification(info.name, mountpoint=info.mountpoint) return SubvolumeModification(info.name, mountpoint=info.mountpoint)
@classmethod @classmethod
def parse_args(cls, subvol_args: list[dict[str, Any]]) -> list[SubvolumeModification]: def parse_args(cls, subvol_args: list[_SubvolumeModificationSerialization]) -> list[SubvolumeModification]:
mods = [] mods = []
for entry in subvol_args: for entry in subvol_args:
if not entry.get('name', None) or not entry.get('mountpoint', None): if not entry.get('name', None) or not entry.get('mountpoint', None):
@ -555,10 +578,10 @@ class SubvolumeModification:
return self.mountpoint == Path('/') return self.mountpoint == Path('/')
return False return False
def json(self) -> dict[str, Any]: def json(self) -> _SubvolumeModificationSerialization:
return {'name': str(self.name), 'mountpoint': str(self.mountpoint)} return {'name': str(self.name), 'mountpoint': str(self.mountpoint)}
def table_data(self) -> dict[str, Any]: def table_data(self) -> _SubvolumeModificationSerialization:
return self.json() return self.json()
@ -738,6 +761,20 @@ class ModificationStatus(Enum):
Create = 'create' Create = 'create'
class _PartitionModificationSerialization(TypedDict):
obj_id: str
status: str
type: str
start: _SizeSerialization
size: _SizeSerialization
fs_type: str | None
mountpoint: str | None
mount_options: list[str]
flags: list[str]
btrfs: list[_SubvolumeModificationSerialization]
dev_path: str | None
@dataclass @dataclass
class PartitionModification: class PartitionModification:
status: ModificationStatus status: ModificationStatus
@ -891,7 +928,7 @@ class PartitionModification:
else: else:
self.set_flag(flag) self.set_flag(flag)
def json(self) -> dict[str, Any]: def json(self) -> _PartitionModificationSerialization:
""" """
Called for configuration settings Called for configuration settings
""" """
@ -947,13 +984,19 @@ class LvmLayoutType(Enum):
raise ValueError(f'Unknown type: {self}') raise ValueError(f'Unknown type: {self}')
class _LvmVolumeGroupSerialization(TypedDict):
name: str
lvm_pvs: list[str]
volumes: list[_LvmVolumeSerialization]
@dataclass @dataclass
class LvmVolumeGroup: class LvmVolumeGroup:
name: str name: str
pvs: list[PartitionModification] pvs: list[PartitionModification]
volumes: list[LvmVolume] = field(default_factory=list) volumes: list[LvmVolume] = field(default_factory=list)
def json(self) -> dict[str, Any]: def json(self) -> _LvmVolumeGroupSerialization:
return { return {
'name': self.name, 'name': self.name,
'lvm_pvs': [p.obj_id for p in self.pvs], 'lvm_pvs': [p.obj_id for p in self.pvs],
@ -961,7 +1004,7 @@ class LvmVolumeGroup:
} }
@staticmethod @staticmethod
def parse_arg(arg: dict[str, Any], disk_config: DiskLayoutConfiguration) -> LvmVolumeGroup: def parse_arg(arg: _LvmVolumeGroupSerialization, disk_config: DiskLayoutConfiguration) -> LvmVolumeGroup:
lvm_pvs = [] lvm_pvs = []
for mod in disk_config.device_modifications: for mod in disk_config.device_modifications:
for part in mod.partitions: for part in mod.partitions:
@ -985,6 +1028,17 @@ class LvmVolumeStatus(Enum):
Create = 'create' Create = 'create'
class _LvmVolumeSerialization(TypedDict):
obj_id: str
status: str
name: str
fs_type: str
length: _SizeSerialization
mountpoint: str | None
mount_options: list[str]
btrfs: list[_SubvolumeModificationSerialization]
@dataclass @dataclass
class LvmVolume: class LvmVolume:
status: LvmVolumeStatus status: LvmVolumeStatus
@ -1051,7 +1105,7 @@ class LvmVolume:
raise ValueError('Mountpoint is not specified') raise ValueError('Mountpoint is not specified')
@staticmethod @staticmethod
def parse_arg(arg: dict[str, Any]) -> LvmVolume: def parse_arg(arg: _LvmVolumeSerialization) -> LvmVolume:
volume = LvmVolume( volume = LvmVolume(
status=LvmVolumeStatus(arg['status']), status=LvmVolumeStatus(arg['status']),
name=arg['name'], name=arg['name'],
@ -1066,7 +1120,7 @@ class LvmVolume:
return volume return volume
def json(self) -> dict[str, Any]: def json(self) -> _LvmVolumeSerialization:
return { return {
'obj_id': self.obj_id, 'obj_id': self.obj_id,
'status': self.status.value, 'status': self.status.value,
@ -1130,6 +1184,11 @@ class LvmPVInfo:
vg_name: str vg_name: str
class _LvmConfigurationSerialization(TypedDict):
config_type: str
vol_groups: list[_LvmVolumeGroupSerialization]
@dataclass @dataclass
class LvmConfiguration: class LvmConfiguration:
config_type: LvmLayoutType config_type: LvmLayoutType
@ -1144,18 +1203,19 @@ class LvmConfiguration:
raise ValueError('A PV cannot be used in multiple volume groups') raise ValueError('A PV cannot be used in multiple volume groups')
pvs.append(pv) pvs.append(pv)
def json(self) -> dict[str, Any]: def json(self) -> _LvmConfigurationSerialization:
return { return {
'config_type': self.config_type.value, 'config_type': self.config_type.value,
'vol_groups': [vol_gr.json() for vol_gr in self.vol_groups] 'vol_groups': [vol_gr.json() for vol_gr in self.vol_groups]
} }
@staticmethod @staticmethod
def parse_arg(arg: dict[str, Any], disk_config: DiskLayoutConfiguration) -> LvmConfiguration: def parse_arg(arg: _LvmConfigurationSerialization, disk_config: DiskLayoutConfiguration) -> LvmConfiguration:
lvm_pvs = [] lvm_pvs = []
for mod in disk_config.device_modifications: for mod in disk_config.device_modifications:
for part in mod.partitions: for part in mod.partitions:
if part.obj_id in arg.get('lvm_pvs', []): # FIXME: 'lvm_pvs' does not seem like it can ever exist in the 'arg' serialization
if part.obj_id in arg.get('lvm_pvs', []): # type: ignore[operator]
lvm_pvs.append(part) lvm_pvs.append(part)
return LvmConfiguration( return LvmConfiguration(
@ -1196,6 +1256,12 @@ class LvmConfiguration:
# if vg.contains_lv(lv): # if vg.contains_lv(lv):
class _DeviceModificationSerialization(TypedDict):
device: str
wipe: bool
partitions: list[_PartitionModificationSerialization]
@dataclass @dataclass
class DeviceModification: class DeviceModification:
device: BDevice device: BDevice
@ -1235,7 +1301,7 @@ class DeviceModification:
filtered = filter(lambda x: x.is_root(), self.partitions) filtered = filter(lambda x: x.is_root(), self.partitions)
return next(filtered, None) return next(filtered, None)
def json(self) -> dict[str, Any]: def json(self) -> _DeviceModificationSerialization:
""" """
Called when generating configuration files Called when generating configuration files
""" """
@ -1273,6 +1339,13 @@ class EncryptionType(Enum):
return type_to_text[type_] return type_to_text[type_]
class _DiskEncryptionSerialization(TypedDict):
encryption_type: str
partitions: list[str]
lvm_volumes: list[str]
hsm_device: NotRequired[_Fido2DeviceSerialization]
@dataclass @dataclass
class DiskEncryption: class DiskEncryption:
encryption_type: EncryptionType = EncryptionType.NoEncryption encryption_type: EncryptionType = EncryptionType.NoEncryption
@ -1295,8 +1368,8 @@ class DiskEncryption:
return dev in self.lvm_volumes and dev.mountpoint != Path('/') return dev in self.lvm_volumes and dev.mountpoint != Path('/')
return False return False
def json(self) -> dict[str, Any]: def json(self) -> _DiskEncryptionSerialization:
obj: dict[str, Any] = { obj: _DiskEncryptionSerialization = {
'encryption_type': self.encryption_type.value, 'encryption_type': self.encryption_type.value,
'partitions': [p.obj_id for p in self.partitions], 'partitions': [p.obj_id for p in self.partitions],
'lvm_volumes': [vol.obj_id for vol in self.lvm_volumes] 'lvm_volumes': [vol.obj_id for vol in self.lvm_volumes]
@ -1325,7 +1398,7 @@ class DiskEncryption:
def parse_arg( def parse_arg(
cls, cls,
disk_config: DiskLayoutConfiguration, disk_config: DiskLayoutConfiguration,
disk_encryption: dict[str, Any], disk_encryption: _DiskEncryptionSerialization,
password: str = '' password: str = ''
) -> 'DiskEncryption | None': ) -> 'DiskEncryption | None':
if not cls.validate_enc(disk_config): if not cls.validate_enc(disk_config):
@ -1359,13 +1432,19 @@ class DiskEncryption:
return enc return enc
class _Fido2DeviceSerialization(TypedDict):
path: str
manufacturer: str
product: str
@dataclass @dataclass
class Fido2Device: class Fido2Device:
path: Path path: Path
manufacturer: str manufacturer: str
product: str product: str
def json(self) -> dict[str, str]: def json(self) -> _Fido2DeviceSerialization:
return { return {
'path': str(self.path), 'path': str(self.path),
'manufacturer': self.manufacturer, 'manufacturer': self.manufacturer,
@ -1380,7 +1459,7 @@ class Fido2Device:
} }
@classmethod @classmethod
def parse_arg(cls, arg: dict[str, str]) -> 'Fido2Device': def parse_arg(cls, arg: _Fido2DeviceSerialization) -> 'Fido2Device':
return Fido2Device( return Fido2Device(
Path(arg['path']), Path(arg['path']),
arg['manufacturer'], arg['manufacturer'],