Remove remaining Optional and Union usage from the codebase (#2868)

This commit is contained in:
correctmost 2024-11-18 04:59:08 -05:00 committed by GitHub
parent 97d6d84c3c
commit 80b4dab092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 63 additions and 65 deletions

View File

@ -7,7 +7,7 @@ import time
import uuid import uuid
from collections.abc import Iterable from collections.abc import Iterable
from pathlib import Path from pathlib import Path
from typing import Any, Optional, TYPE_CHECKING, Literal from typing import Any, TYPE_CHECKING, Literal
from parted import ( from parted import (
Disk, Geometry, FileSystem, Disk, Geometry, FileSystem,
@ -138,8 +138,8 @@ class DeviceHandler:
def _determine_fs_type( def _determine_fs_type(
self, self,
partition: Partition, partition: Partition,
lsblk_info: Optional[LsblkInfo] = None lsblk_info: LsblkInfo | None = None
) -> Optional[FilesystemType]: ) -> FilesystemType | None:
try: try:
if partition.fileSystem: if partition.fileSystem:
return FilesystemType(partition.fileSystem.type) return FilesystemType(partition.fileSystem.type)
@ -151,17 +151,17 @@ class DeviceHandler:
return None return None
def get_device(self, path: Path) -> Optional[BDevice]: def get_device(self, path: Path) -> BDevice | None:
return self._devices.get(path, None) return self._devices.get(path, None)
def get_device_by_partition_path(self, partition_path: Path) -> Optional[BDevice]: def get_device_by_partition_path(self, partition_path: Path) -> BDevice | None:
partition = self.find_partition(partition_path) partition = self.find_partition(partition_path)
if partition: if partition:
device: Device = partition.disk.device device: Device = partition.disk.device
return self.get_device(Path(device.path)) return self.get_device(Path(device.path))
return None return None
def find_partition(self, path: Path) -> Optional[_PartitionInfo]: def find_partition(self, path: Path) -> _PartitionInfo | None:
for device in self._devices.values(): for device in self._devices.values():
part = next(filter(lambda x: str(x.path) == str(path), device.partition_infos), None) part = next(filter(lambda x: str(x.path) == str(path), device.partition_infos), None)
if part is not None: if part is not None:
@ -172,7 +172,7 @@ class DeviceHandler:
lsblk = get_lsblk_info(dev_path) lsblk = get_lsblk_info(dev_path)
return Path(f'/dev/{lsblk.pkname}') return Path(f'/dev/{lsblk.pkname}')
def get_unique_path_for_device(self, dev_path: Path) -> Optional[Path]: def get_unique_path_for_device(self, dev_path: Path) -> Path | None:
paths = Path('/dev/disk/by-id').glob('*') paths = Path('/dev/disk/by-id').glob('*')
linked_targets = {p.resolve(): p for p in paths} linked_targets = {p.resolve(): p for p in paths}
linked_wwn_targets = { linked_wwn_targets = {
@ -188,14 +188,14 @@ class DeviceHandler:
return None return None
def get_uuid_for_path(self, path: Path) -> Optional[str]: def get_uuid_for_path(self, path: Path) -> str | None:
partition = self.find_partition(path) partition = self.find_partition(path)
return partition.partuuid if partition else None return partition.partuuid if partition else None
def get_btrfs_info( def get_btrfs_info(
self, self,
dev_path: Path, dev_path: Path,
lsblk_info: Optional[LsblkInfo] = None lsblk_info: LsblkInfo | None = None
) -> list[_BtrfsSubvolumeInfo]: ) -> list[_BtrfsSubvolumeInfo]:
if not lsblk_info: if not lsblk_info:
lsblk_info = get_lsblk_info(dev_path) lsblk_info = get_lsblk_info(dev_path)
@ -286,7 +286,7 @@ class DeviceHandler:
def encrypt( def encrypt(
self, self,
dev_path: Path, dev_path: Path,
mapper_name: Optional[str], mapper_name: str | None,
enc_password: str, enc_password: str,
lock_after_create: bool = True lock_after_create: bool = True
) -> Luks2: ) -> Luks2:
@ -312,7 +312,7 @@ class DeviceHandler:
def format_encrypted( def format_encrypted(
self, self,
dev_path: Path, dev_path: Path,
mapper_name: Optional[str], mapper_name: str | None,
fs_type: FilesystemType, fs_type: FilesystemType,
enc_conf: DiskEncryption enc_conf: DiskEncryption
) -> None: ) -> None:
@ -339,7 +339,7 @@ class DeviceHandler:
self, self,
cmd: str, cmd: str,
info_type: Literal['lv', 'vg', 'pvseg'] info_type: Literal['lv', 'vg', 'pvseg']
) -> Optional[Any]: ) -> Any | None:
raw_info = SysCommand(cmd).decode().split('\n') raw_info = SysCommand(cmd).decode().split('\n')
# for whatever reason the output sometimes contains # for whatever reason the output sometimes contains
@ -377,14 +377,14 @@ class DeviceHandler:
return None return None
def _lvm_info_with_retry(self, cmd: str, info_type: Literal['lv', 'vg', 'pvseg']) -> Optional[Any]: def _lvm_info_with_retry(self, cmd: str, info_type: Literal['lv', 'vg', 'pvseg']) -> Any | None:
while True: while True:
try: try:
return self._lvm_info(cmd, info_type) return self._lvm_info(cmd, info_type)
except ValueError: except ValueError:
time.sleep(3) time.sleep(3)
def lvm_vol_info(self, lv_name: str) -> Optional[LvmVolumeInfo]: def lvm_vol_info(self, lv_name: str) -> LvmVolumeInfo | None:
cmd = ( cmd = (
'lvs --reportformat json ' 'lvs --reportformat json '
'--unit B ' '--unit B '
@ -393,7 +393,7 @@ class DeviceHandler:
return self._lvm_info_with_retry(cmd, 'lv') return self._lvm_info_with_retry(cmd, 'lv')
def lvm_group_info(self, vg_name: str) -> Optional[LvmGroupInfo]: def lvm_group_info(self, vg_name: str) -> LvmGroupInfo | None:
cmd = ( cmd = (
'vgs --reportformat json ' 'vgs --reportformat json '
'--unit B ' '--unit B '
@ -403,7 +403,7 @@ class DeviceHandler:
return self._lvm_info_with_retry(cmd, 'vg') return self._lvm_info_with_retry(cmd, 'vg')
def lvm_pvseg_info(self, vg_name: str, lv_name: str) -> Optional[LvmPVInfo]: def lvm_pvseg_info(self, vg_name: str, lv_name: str) -> LvmPVInfo | None:
cmd = ( cmd = (
'pvs ' 'pvs '
'--segments -o+lv_name,vg_name ' '--segments -o+lv_name,vg_name '
@ -457,7 +457,7 @@ class DeviceHandler:
worker.poll() worker.poll()
worker.write(b'y\n', line_ending=False) worker.write(b'y\n', line_ending=False)
def lvm_vol_create(self, vg_name: str, volume: LvmVolume, offset: Optional[Size] = None) -> None: def lvm_vol_create(self, vg_name: str, volume: LvmVolume, offset: Size | None = None) -> None:
if offset is not None: if offset is not None:
length = volume.length - offset length = volume.length - offset
else: else:
@ -593,7 +593,7 @@ class DeviceHandler:
def create_btrfs_volumes( def create_btrfs_volumes(
self, self,
part_mod: PartitionModification, part_mod: PartitionModification,
enc_conf: Optional['DiskEncryption'] = None enc_conf: 'DiskEncryption | None' = None
) -> None: ) -> None:
info(f'Creating subvolumes: {part_mod.safe_dev_path}') info(f'Creating subvolumes: {part_mod.safe_dev_path}')
@ -663,7 +663,7 @@ class DeviceHandler:
def partition( def partition(
self, self,
modification: DeviceModification, modification: DeviceModification,
partition_table: Optional[PartitionTable] = None partition_table: PartitionTable | None = None
) -> None: ) -> None:
""" """
Create a partition table on the block device and create all partitions. Create a partition table on the block device and create all partitions.
@ -701,7 +701,7 @@ class DeviceHandler:
self, self,
dev_path: Path, dev_path: Path,
target_mountpoint: Path, target_mountpoint: Path,
mount_fs: Optional[str] = None, mount_fs: str | None = None,
create_target_mountpoint: bool = True, create_target_mountpoint: bool = True,
options: list[str] = [] options: list[str] = []
) -> None: ) -> None:
@ -777,7 +777,7 @@ class DeviceHandler:
return device_mods return device_mods
def partprobe(self, path: Optional[Path] = None) -> None: def partprobe(self, path: Path | None = None) -> None:
if path is not None: if path is not None:
command = f'partprobe {path}' command = f'partprobe {path}'
else: else:

View File

@ -6,8 +6,7 @@ import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from typing import Union
import parted import parted
from parted import Disk, Geometry, Partition from parted import Disk, Geometry, Partition
@ -41,10 +40,10 @@ class DiskLayoutType(Enum):
class DiskLayoutConfiguration: class DiskLayoutConfiguration:
config_type: DiskLayoutType config_type: DiskLayoutType
device_modifications: list[DeviceModification] = field(default_factory=list) device_modifications: list[DeviceModification] = field(default_factory=list)
lvm_config: Optional[LvmConfiguration] = None lvm_config: LvmConfiguration | None = None
# used for pre-mounted config # used for pre-mounted config
mountpoint: Optional[Path] = None mountpoint: Path | None = None
def json(self) -> dict[str, Any]: def json(self) -> dict[str, Any]:
if self.config_type == DiskLayoutType.Pre_mount: if self.config_type == DiskLayoutType.Pre_mount:
@ -64,7 +63,7 @@ class DiskLayoutConfiguration:
return config return config
@classmethod @classmethod
def parse_arg(cls, disk_config: dict[str, Any]) -> Optional[DiskLayoutConfiguration]: def parse_arg(cls, disk_config: dict[str, Any]) -> DiskLayoutConfiguration | None:
from .device_handler import device_handler from .device_handler import device_handler
device_modifications: list[DeviceModification] = [] device_modifications: list[DeviceModification] = []
@ -238,7 +237,7 @@ class Size:
def convert( def convert(
self, self,
target_unit: Unit, target_unit: Unit,
sector_size: Optional[SectorSize] = None sector_size: SectorSize | None = None
) -> Size: ) -> Size:
if target_unit == Unit.sectors and sector_size is None: if target_unit == Unit.sectors and sector_size is None:
raise ValueError('If target has unit sector, a sector size must be provided') raise ValueError('If target has unit sector, a sector size must be provided')
@ -266,7 +265,7 @@ class Size:
def format_size( def format_size(
self, self,
target_unit: Unit, target_unit: Unit,
sector_size: Optional[SectorSize] = None, sector_size: SectorSize | None = None,
include_unit: bool = True include_unit: bool = True
) -> str: ) -> str:
target_size = self.convert(target_unit, sector_size) target_size = self.convert(target_unit, sector_size)
@ -333,7 +332,7 @@ class BtrfsMountOption(Enum):
@dataclass @dataclass
class _BtrfsSubvolumeInfo: class _BtrfsSubvolumeInfo:
name: Path name: Path
mountpoint: Optional[Path] mountpoint: Path | None
@dataclass @dataclass
@ -341,14 +340,14 @@ class _PartitionInfo:
partition: Partition partition: Partition
name: str name: str
type: PartitionType type: PartitionType
fs_type: Optional[FilesystemType] fs_type: FilesystemType | None
path: Path path: Path
start: Size start: Size
length: Size length: Size
flags: list[PartitionFlag] flags: list[PartitionFlag]
partn: Optional[int] partn: int | None
partuuid: Optional[str] partuuid: str | None
uuid: Optional[str] uuid: str | None
disk: Disk disk: Disk
mountpoints: list[Path] mountpoints: list[Path]
btrfs_subvol_infos: list[_BtrfsSubvolumeInfo] = field(default_factory=list) btrfs_subvol_infos: list[_BtrfsSubvolumeInfo] = field(default_factory=list)
@ -381,10 +380,10 @@ class _PartitionInfo:
def from_partition( def from_partition(
cls, cls,
partition: Partition, partition: Partition,
fs_type: Optional[FilesystemType], fs_type: FilesystemType | None,
partn: Optional[int], partn: int | None,
partuuid: Optional[str], partuuid: str | None,
uuid: Optional[str], uuid: str | None,
mountpoints: list[Path], mountpoints: list[Path],
btrfs_subvol_infos: list[_BtrfsSubvolumeInfo] = [] btrfs_subvol_infos: list[_BtrfsSubvolumeInfo] = []
) -> _PartitionInfo: ) -> _PartitionInfo:
@ -473,7 +472,7 @@ class _DeviceInfo:
@dataclass @dataclass
class SubvolumeModification: class SubvolumeModification:
name: Path name: Path
mountpoint: Optional[Path] = None mountpoint: Path | None = None
@classmethod @classmethod
def from_existing_subvol_info(cls, info: _BtrfsSubvolumeInfo) -> SubvolumeModification: def from_existing_subvol_info(cls, info: _BtrfsSubvolumeInfo) -> SubvolumeModification:
@ -572,7 +571,7 @@ class PartitionType(Enum):
debug(f'Partition code not supported: {code}') debug(f'Partition code not supported: {code}')
return PartitionType._Unknown return PartitionType._Unknown
def get_partition_code(self) -> Optional[int]: def get_partition_code(self) -> int | None:
if self == PartitionType.Primary: if self == PartitionType.Primary:
return parted.PARTITION_NORMAL return parted.PARTITION_NORMAL
elif self == PartitionType.Boot: elif self == PartitionType.Boot:
@ -623,7 +622,7 @@ class FilesystemType(Enum):
return self.value return self.value
@property @property
def installation_pkg(self) -> Optional[str]: def installation_pkg(self) -> str | None:
match self: match self:
case FilesystemType.Btrfs: case FilesystemType.Btrfs:
return 'btrfs-progs' return 'btrfs-progs'
@ -635,7 +634,7 @@ class FilesystemType(Enum):
return None return None
@property @property
def installation_module(self) -> Optional[str]: def installation_module(self) -> str | None:
match self: match self:
case FilesystemType.Btrfs: case FilesystemType.Btrfs:
return 'btrfs' return 'btrfs'
@ -643,7 +642,7 @@ class FilesystemType(Enum):
return None return None
@property @property
def installation_binary(self) -> Optional[str]: def installation_binary(self) -> str | None:
match self: match self:
case FilesystemType.Btrfs: case FilesystemType.Btrfs:
return '/usr/bin/btrfs' return '/usr/bin/btrfs'
@ -651,7 +650,7 @@ class FilesystemType(Enum):
return None return None
@property @property
def installation_hooks(self) -> Optional[str]: def installation_hooks(self) -> str | None:
match self: match self:
case FilesystemType.Btrfs: case FilesystemType.Btrfs:
return 'btrfs' return 'btrfs'
@ -672,17 +671,17 @@ class PartitionModification:
type: PartitionType type: PartitionType
start: Size start: Size
length: Size length: Size
fs_type: Optional[FilesystemType] = None fs_type: FilesystemType | None = None
mountpoint: Optional[Path] = None mountpoint: Path | None = None
mount_options: list[str] = field(default_factory=list) mount_options: list[str] = field(default_factory=list)
flags: list[PartitionFlag] = field(default_factory=list) flags: list[PartitionFlag] = field(default_factory=list)
btrfs_subvols: list[SubvolumeModification] = field(default_factory=list) btrfs_subvols: list[SubvolumeModification] = field(default_factory=list)
# only set if the device was created or exists # only set if the device was created or exists
dev_path: Optional[Path] = None dev_path: Path | None = None
partn: Optional[int] = None partn: int | None = None
partuuid: Optional[str] = None partuuid: str | None = None
uuid: Optional[str] = None uuid: str | None = None
_efi_indicator_flags = (PartitionFlag.Boot, PartitionFlag.ESP) _efi_indicator_flags = (PartitionFlag.Boot, PartitionFlag.ESP)
_boot_indicator_flags = (PartitionFlag.Boot, PartitionFlag.XBOOTLDR) _boot_indicator_flags = (PartitionFlag.Boot, PartitionFlag.XBOOTLDR)
@ -798,7 +797,7 @@ class PartitionModification:
return self.status in [ModificationStatus.Create, ModificationStatus.Modify] return self.status in [ModificationStatus.Create, ModificationStatus.Modify]
@property @property
def mapper_name(self) -> Optional[str]: def mapper_name(self) -> str | None:
if self.dev_path: if self.dev_path:
return f'{storage.get("ENC_IDENTIFIER", "ai")}{self.dev_path.name}' return f'{storage.get("ENC_IDENTIFIER", "ai")}{self.dev_path.name}'
return None return None
@ -913,14 +912,14 @@ class LvmVolume:
name: str name: str
fs_type: FilesystemType fs_type: FilesystemType
length: Size length: Size
mountpoint: Optional[Path] mountpoint: Path | None
mount_options: list[str] = field(default_factory=list) mount_options: list[str] = field(default_factory=list)
btrfs_subvols: list[SubvolumeModification] = field(default_factory=list) btrfs_subvols: list[SubvolumeModification] = field(default_factory=list)
# volume group name # volume group name
vg_name: Optional[str] = None vg_name: str | None = None
# mapper device path /dev/<vg>/<vol> # mapper device path /dev/<vg>/<vol>
dev_path: Optional[Path] = None dev_path: Path | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
# needed to use the object as a dictionary key due to hash func # needed to use the object as a dictionary key due to hash func
@ -937,7 +936,7 @@ class LvmVolume:
return '' return ''
@property @property
def mapper_name(self) -> Optional[str]: def mapper_name(self) -> str | None:
if self.dev_path: if self.dev_path:
return f'{storage.get("ENC_IDENTIFIER", "ai")}{self.safe_dev_path.name}' return f'{storage.get("ENC_IDENTIFIER", "ai")}{self.safe_dev_path.name}'
return None return None
@ -1100,7 +1099,7 @@ class LvmConfiguration:
return volumes return volumes
def get_root_volume(self) -> Optional[LvmVolume]: def get_root_volume(self) -> LvmVolume | None:
for vg in self.vol_groups: for vg in self.vol_groups:
filtered = next(filter(lambda x: x.is_root(), vg.volumes), None) filtered = next(filter(lambda x: x.is_root(), vg.volumes), None)
if filtered: if filtered:
@ -1131,14 +1130,14 @@ class DeviceModification:
def add_partition(self, partition: PartitionModification) -> None: def add_partition(self, partition: PartitionModification) -> None:
self.partitions.append(partition) self.partitions.append(partition)
def get_efi_partition(self) -> Optional[PartitionModification]: def get_efi_partition(self) -> PartitionModification | None:
""" """
Similar to get_boot_partition() but excludes XBOOTLDR partitions from it's candidates. Similar to get_boot_partition() but excludes XBOOTLDR partitions from it's candidates.
""" """
filtered = filter(lambda x: x.is_efi() and x.mountpoint, self.partitions) filtered = filter(lambda x: x.is_efi() and x.mountpoint, self.partitions)
return next(filtered, None) return next(filtered, None)
def get_boot_partition(self) -> Optional[PartitionModification]: def get_boot_partition(self) -> PartitionModification | None:
""" """
Returns the first partition marked as XBOOTLDR (PARTTYPE id of bc13c2ff-...) or Boot and has a mountpoint. Returns the first partition marked as XBOOTLDR (PARTTYPE id of bc13c2ff-...) or Boot and has a mountpoint.
Only returns XBOOTLDR if separate EFI is detected using self.get_efi_partition() Only returns XBOOTLDR if separate EFI is detected using self.get_efi_partition()
@ -1153,7 +1152,7 @@ class DeviceModification:
filtered = filter(lambda x: x.is_boot() and x.mountpoint, self.partitions) filtered = filter(lambda x: x.is_boot() and x.mountpoint, self.partitions)
return next(filtered, None) return next(filtered, None)
def get_root_partition(self) -> Optional[PartitionModification]: def get_root_partition(self) -> PartitionModification | None:
filtered = filter(lambda x: x.is_root(), self.partitions) filtered = filter(lambda x: x.is_root(), self.partitions)
return next(filtered, None) return next(filtered, None)
@ -1201,7 +1200,7 @@ class DiskEncryption:
encryption_password: str = '' encryption_password: str = ''
partitions: list[PartitionModification] = field(default_factory=list) partitions: list[PartitionModification] = field(default_factory=list)
lvm_volumes: list[LvmVolume] = field(default_factory=list) lvm_volumes: list[LvmVolume] = field(default_factory=list)
hsm_device: Optional[Fido2Device] = None hsm_device: Fido2Device | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.encryption_type in [EncryptionType.Luks, EncryptionType.LvmOnLuks] and not self.partitions: if self.encryption_type in [EncryptionType.Luks, EncryptionType.LvmOnLuks] and not self.partitions:
@ -1249,7 +1248,7 @@ class DiskEncryption:
disk_config: DiskLayoutConfiguration, disk_config: DiskLayoutConfiguration,
disk_encryption: dict[str, Any], disk_encryption: dict[str, Any],
password: str = '' password: str = ''
) -> Optional['DiskEncryption']: ) -> 'DiskEncryption | None':
if not cls.validate_enc(disk_config): if not cls.validate_enc(disk_config):
return None return None
@ -1363,7 +1362,7 @@ class LsblkOutput(BaseModel):
def _fetch_lsblk_info( def _fetch_lsblk_info(
dev_path: Optional[Union[Path, str]] = None, dev_path: Path | str | None = None,
reverse: bool = False, reverse: bool = False,
full_dev_path: bool = False full_dev_path: bool = False
) -> list[LsblkInfo]: ) -> list[LsblkInfo]:
@ -1401,7 +1400,7 @@ def _fetch_lsblk_info(
def get_lsblk_info( def get_lsblk_info(
dev_path: Union[Path, str], dev_path: Path | str,
reverse: bool = False, reverse: bool = False,
full_dev_path: bool = False full_dev_path: bool = False
) -> LsblkInfo: ) -> LsblkInfo:
@ -1416,9 +1415,9 @@ def get_all_lsblk_info() -> list[LsblkInfo]:
def find_lsblk_info( def find_lsblk_info(
dev_path: Union[Path, str], dev_path: Path | str,
info: list[LsblkInfo] info: list[LsblkInfo]
) -> Optional[LsblkInfo]: ) -> LsblkInfo | None:
if isinstance(dev_path, str): if isinstance(dev_path, str):
dev_path = Path(dev_path) dev_path = Path(dev_path)

View File

@ -194,7 +194,6 @@ select = [
ignore = [ ignore = [
"E722", # bare-except "E722", # bare-except
"PLW2901", # redefined-loop-name "PLW2901", # redefined-loop-name
"UP007", # non-pep604-annotation
"UP027", # unpacked-list-comprehension "UP027", # unpacked-list-comprehension
"UP028", # yield-in-for-loop "UP028", # yield-in-for-loop
"UP031", # printf-string-formatting "UP031", # printf-string-formatting