Use TypedDict to annotate disk-related serializations (#2935)

This commit is contained in:
correctmost 2024-11-30 06:56:46 -05:00 committed by GitHub
parent 11f8490b59
commit 007f2ff797
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 107 additions and 28 deletions

View File

@ -5,7 +5,7 @@ import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
import parted
from parted import Disk, Geometry, Partition
@ -39,6 +39,13 @@ class DiskLayoutType(Enum):
return str(_('Pre-mounted configuration'))
class _DiskLayoutConfigurationSerialization(TypedDict):
config_type: str
device_modifications: NotRequired[list[_DeviceModificationSerialization]]
lvm_config: NotRequired[_LvmConfigurationSerialization]
mountpoint: NotRequired[str]
@dataclass
class DiskLayoutConfiguration:
config_type: DiskLayoutType
@ -48,14 +55,14 @@ class DiskLayoutConfiguration:
# used for pre-mounted config
mountpoint: Path | None = None
def json(self) -> dict[str, Any]:
def json(self) -> _DiskLayoutConfigurationSerialization:
if self.config_type == DiskLayoutType.Pre_mount:
return {
'config_type': self.config_type.value,
'mountpoint': str(self.mountpoint)
}
else:
config: dict[str, Any] = {
config: _DiskLayoutConfigurationSerialization = {
'config_type': self.config_type.value,
'device_modifications': [mod.json() for mod in self.device_modifications],
}
@ -66,7 +73,7 @@ class DiskLayoutConfiguration:
return config
@classmethod
def parse_arg(cls, disk_config: dict[str, Any]) -> DiskLayoutConfiguration | None:
def parse_arg(cls, disk_config: _DiskLayoutConfigurationSerialization) -> DiskLayoutConfiguration | None:
from .device_handler import device_handler
device_modifications: list[DeviceModification] = []
@ -96,7 +103,7 @@ class DiskLayoutConfiguration:
return config
for entry in disk_config.get('device_modifications', []):
device_path = Path(entry.get('device', None)) if entry.get('device', None) else None
device_path = Path(entry['device']) if entry.get('device', None) else None
if not device_path:
continue
@ -190,6 +197,11 @@ class Unit(Enum):
return [u for u in Unit if 'i' in u.name or u.name == 'B']
class _SectorSizeSerialization(TypedDict):
value: int
unit: str
@dataclass
class SectorSize:
value: int
@ -204,14 +216,14 @@ class SectorSize:
def default() -> SectorSize:
return SectorSize(512, Unit.B)
def json(self) -> dict[str, Any]:
def json(self) -> _SectorSizeSerialization:
return {
'value': self.value,
'unit': self.unit.name,
}
@classmethod
def parse_args(cls, arg: dict[str, Any]) -> SectorSize:
def parse_args(cls, arg: _SectorSizeSerialization) -> SectorSize:
return SectorSize(
arg['value'],
Unit[arg['unit']]
@ -224,6 +236,12 @@ class SectorSize:
return int(self.value * self.unit.value)
class _SizeSerialization(TypedDict):
value: int
unit: str
sector_size: _SectorSizeSerialization
@dataclass
class Size:
value: int
@ -234,15 +252,15 @@ class Size:
if not isinstance(self.sector_size, SectorSize):
raise ValueError('sector size must be of type SectorSize')
def json(self) -> dict[str, Any]:
def json(self) -> _SizeSerialization:
return {
'value': self.value,
'unit': self.unit.name,
'sector_size': self.sector_size.json() if self.sector_size else None
'sector_size': self.sector_size.json()
}
@classmethod
def parse_args(cls, size_arg: dict[str, Any]) -> Size:
def parse_args(cls, size_arg: _SizeSerialization) -> Size:
sector_size = size_arg['sector_size']
return Size(
@ -516,9 +534,14 @@ class _DeviceInfo:
)
class _SubvolumeModificationSerialization(TypedDict):
name: str
mountpoint: str
@dataclass
class SubvolumeModification:
name: Path
name: Path | str
mountpoint: Path | None = None
@classmethod
@ -526,7 +549,7 @@ class SubvolumeModification:
return SubvolumeModification(info.name, mountpoint=info.mountpoint)
@classmethod
def parse_args(cls, subvol_args: list[dict[str, Any]]) -> list[SubvolumeModification]:
def parse_args(cls, subvol_args: list[_SubvolumeModificationSerialization]) -> list[SubvolumeModification]:
mods = []
for entry in subvol_args:
if not entry.get('name', None) or not entry.get('mountpoint', None):
@ -555,10 +578,10 @@ class SubvolumeModification:
return self.mountpoint == Path('/')
return False
def json(self) -> dict[str, Any]:
def json(self) -> _SubvolumeModificationSerialization:
return {'name': str(self.name), 'mountpoint': str(self.mountpoint)}
def table_data(self) -> dict[str, Any]:
def table_data(self) -> _SubvolumeModificationSerialization:
return self.json()
@ -738,6 +761,20 @@ class ModificationStatus(Enum):
Create = 'create'
class _PartitionModificationSerialization(TypedDict):
obj_id: str
status: str
type: str
start: _SizeSerialization
size: _SizeSerialization
fs_type: str | None
mountpoint: str | None
mount_options: list[str]
flags: list[str]
btrfs: list[_SubvolumeModificationSerialization]
dev_path: str | None
@dataclass
class PartitionModification:
status: ModificationStatus
@ -891,7 +928,7 @@ class PartitionModification:
else:
self.set_flag(flag)
def json(self) -> dict[str, Any]:
def json(self) -> _PartitionModificationSerialization:
"""
Called for configuration settings
"""
@ -947,13 +984,19 @@ class LvmLayoutType(Enum):
raise ValueError(f'Unknown type: {self}')
class _LvmVolumeGroupSerialization(TypedDict):
name: str
lvm_pvs: list[str]
volumes: list[_LvmVolumeSerialization]
@dataclass
class LvmVolumeGroup:
name: str
pvs: list[PartitionModification]
volumes: list[LvmVolume] = field(default_factory=list)
def json(self) -> dict[str, Any]:
def json(self) -> _LvmVolumeGroupSerialization:
return {
'name': self.name,
'lvm_pvs': [p.obj_id for p in self.pvs],
@ -961,7 +1004,7 @@ class LvmVolumeGroup:
}
@staticmethod
def parse_arg(arg: dict[str, Any], disk_config: DiskLayoutConfiguration) -> LvmVolumeGroup:
def parse_arg(arg: _LvmVolumeGroupSerialization, disk_config: DiskLayoutConfiguration) -> LvmVolumeGroup:
lvm_pvs = []
for mod in disk_config.device_modifications:
for part in mod.partitions:
@ -985,6 +1028,17 @@ class LvmVolumeStatus(Enum):
Create = 'create'
class _LvmVolumeSerialization(TypedDict):
obj_id: str
status: str
name: str
fs_type: str
length: _SizeSerialization
mountpoint: str | None
mount_options: list[str]
btrfs: list[_SubvolumeModificationSerialization]
@dataclass
class LvmVolume:
status: LvmVolumeStatus
@ -1051,7 +1105,7 @@ class LvmVolume:
raise ValueError('Mountpoint is not specified')
@staticmethod
def parse_arg(arg: dict[str, Any]) -> LvmVolume:
def parse_arg(arg: _LvmVolumeSerialization) -> LvmVolume:
volume = LvmVolume(
status=LvmVolumeStatus(arg['status']),
name=arg['name'],
@ -1066,7 +1120,7 @@ class LvmVolume:
return volume
def json(self) -> dict[str, Any]:
def json(self) -> _LvmVolumeSerialization:
return {
'obj_id': self.obj_id,
'status': self.status.value,
@ -1130,6 +1184,11 @@ class LvmPVInfo:
vg_name: str
class _LvmConfigurationSerialization(TypedDict):
config_type: str
vol_groups: list[_LvmVolumeGroupSerialization]
@dataclass
class LvmConfiguration:
config_type: LvmLayoutType
@ -1144,18 +1203,19 @@ class LvmConfiguration:
raise ValueError('A PV cannot be used in multiple volume groups')
pvs.append(pv)
def json(self) -> dict[str, Any]:
def json(self) -> _LvmConfigurationSerialization:
return {
'config_type': self.config_type.value,
'vol_groups': [vol_gr.json() for vol_gr in self.vol_groups]
}
@staticmethod
def parse_arg(arg: dict[str, Any], disk_config: DiskLayoutConfiguration) -> LvmConfiguration:
def parse_arg(arg: _LvmConfigurationSerialization, disk_config: DiskLayoutConfiguration) -> LvmConfiguration:
lvm_pvs = []
for mod in disk_config.device_modifications:
for part in mod.partitions:
if part.obj_id in arg.get('lvm_pvs', []):
# FIXME: 'lvm_pvs' does not seem like it can ever exist in the 'arg' serialization
if part.obj_id in arg.get('lvm_pvs', []): # type: ignore[operator]
lvm_pvs.append(part)
return LvmConfiguration(
@ -1196,6 +1256,12 @@ class LvmConfiguration:
# if vg.contains_lv(lv):
class _DeviceModificationSerialization(TypedDict):
device: str
wipe: bool
partitions: list[_PartitionModificationSerialization]
@dataclass
class DeviceModification:
device: BDevice
@ -1235,7 +1301,7 @@ class DeviceModification:
filtered = filter(lambda x: x.is_root(), self.partitions)
return next(filtered, None)
def json(self) -> dict[str, Any]:
def json(self) -> _DeviceModificationSerialization:
"""
Called when generating configuration files
"""
@ -1273,6 +1339,13 @@ class EncryptionType(Enum):
return type_to_text[type_]
class _DiskEncryptionSerialization(TypedDict):
encryption_type: str
partitions: list[str]
lvm_volumes: list[str]
hsm_device: NotRequired[_Fido2DeviceSerialization]
@dataclass
class DiskEncryption:
encryption_type: EncryptionType = EncryptionType.NoEncryption
@ -1295,8 +1368,8 @@ class DiskEncryption:
return dev in self.lvm_volumes and dev.mountpoint != Path('/')
return False
def json(self) -> dict[str, Any]:
obj: dict[str, Any] = {
def json(self) -> _DiskEncryptionSerialization:
obj: _DiskEncryptionSerialization = {
'encryption_type': self.encryption_type.value,
'partitions': [p.obj_id for p in self.partitions],
'lvm_volumes': [vol.obj_id for vol in self.lvm_volumes]
@ -1325,7 +1398,7 @@ class DiskEncryption:
def parse_arg(
cls,
disk_config: DiskLayoutConfiguration,
disk_encryption: dict[str, Any],
disk_encryption: _DiskEncryptionSerialization,
password: str = ''
) -> 'DiskEncryption | None':
if not cls.validate_enc(disk_config):
@ -1359,13 +1432,19 @@ class DiskEncryption:
return enc
class _Fido2DeviceSerialization(TypedDict):
path: str
manufacturer: str
product: str
@dataclass
class Fido2Device:
path: Path
manufacturer: str
product: str
def json(self) -> dict[str, str]:
def json(self) -> _Fido2DeviceSerialization:
return {
'path': str(self.path),
'manufacturer': self.manufacturer,
@ -1380,7 +1459,7 @@ class Fido2Device:
}
@classmethod
def parse_arg(cls, arg: dict[str, str]) -> 'Fido2Device':
def parse_arg(cls, arg: _Fido2DeviceSerialization) -> 'Fido2Device':
return Fido2Device(
Path(arg['path']),
arg['manufacturer'],