Parse lsblk data with pydantic (#2775)

This commit is contained in:
codefiles 2024-11-17 15:42:09 -05:00 committed by GitHub
parent 74fd463873
commit 8fc3dc4358
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 124 deletions

View File

@ -32,7 +32,6 @@ from .device_model import (
DiskEncryption, DiskEncryption,
Fido2Device, Fido2Device,
LsblkInfo, LsblkInfo,
CleanType,
get_lsblk_info, get_lsblk_info,
get_all_lsblk_info, get_all_lsblk_info,
get_lsblk_by_mountpoint, get_lsblk_by_mountpoint,

View File

@ -219,13 +219,20 @@ class DeviceHandler:
debug(f'Failed to read btrfs subvolume information: {err}') debug(f'Failed to read btrfs subvolume information: {err}')
return subvol_infos return subvol_infos
# It is assumed that lsblk will contain the fields as
# "mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...]
# "fsroots": ["/@log", "/@home", "/@"...]
# we'll thereby map the fsroot, which are the mounted filesystem roots
# to the corresponding mountpoints
btrfs_subvol_info = dict(zip(lsblk_info.fsroots, lsblk_info.mountpoints))
try: try:
# ID 256 gen 16 top level 5 path @ # ID 256 gen 16 top level 5 path @
for line in result.splitlines(): for line in result.splitlines():
# expected output format: # expected output format:
# ID 257 gen 8 top level 5 path @home # ID 257 gen 8 top level 5 path @home
name = Path(line.split(' ')[-1]) name = Path(line.split(' ')[-1])
sub_vol_mountpoint = lsblk_info.btrfs_subvol_info.get(name, None) sub_vol_mountpoint = btrfs_subvol_info.get(name, None)
subvol_infos.append(_BtrfsSubvolumeInfo(name, sub_vol_mountpoint)) subvol_infos.append(_BtrfsSubvolumeInfo(name, sub_vol_mountpoint))
except json.decoder.JSONDecodeError as err: except json.decoder.JSONDecodeError as err:
error(f"Could not decode lsblk JSON: {result}") error(f"Could not decode lsblk JSON: {result}")

View File

@ -1,18 +1,17 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import json import json
import math import math
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from enum import auto
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, TYPE_CHECKING, Any from typing import Optional, List, Dict, TYPE_CHECKING, Any
from typing import Union from typing import Union
import parted import parted
from parted import Disk, Geometry, Partition from parted import Disk, Geometry, Partition
from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator
from ..exceptions import DiskError, SysCallError from ..exceptions import DiskError, SysCallError
from ..general import SysCommand from ..general import SysCommand
@ -1311,129 +1310,56 @@ class Fido2Device:
) )
@dataclass class LsblkInfo(BaseModel):
class LsblkInfo: name: str
name: str = '' path: Path
path: Path = Path() pkname: str | None
pkname: str = '' log_sec: int = Field(alias='log-sec')
size: Size = field(default_factory=lambda: Size(0, Unit.B, SectorSize.default())) size: Size
log_sec: int = 0 pttype: str | None
pttype: str = '' ptuuid: str | None
ptuuid: str = '' rota: bool
rota: bool = False tran: str | None
tran: Optional[str] = None partn: int | None
partn: Optional[int] = None partuuid: str | None
partuuid: Optional[str] = None parttype: str | None
parttype: Optional[str] = None uuid: str | None
uuid: Optional[str] = None fstype: str | None
fstype: Optional[str] = None fsver: str | None
fsver: Optional[str] = None fsavail: int | None
fsavail: Optional[str] = None fsuse_percentage: str | None = Field(alias='fsuse%')
fsuse_percentage: Optional[str] = None type: str
type: Optional[str] = None mountpoint: Path | None
mountpoint: Optional[Path] = None mountpoints: list[Path]
mountpoints: List[Path] = field(default_factory=list) fsroots: list[Path]
fsroots: List[Path] = field(default_factory=list) children: list[LsblkInfo] = Field(default_factory=list)
children: List[LsblkInfo] = field(default_factory=list)
def json(self) -> Dict[str, Any]: @field_validator('size', mode='before')
return { @classmethod
'name': self.name, def convert_size(cls, v: int, info: ValidationInfo) -> Size:
'path': str(self.path), sector_size = SectorSize(info.data['log_sec'], Unit.B)
'pkname': self.pkname, return Size(v, Unit.B, sector_size)
'size': self.size.format_size(Unit.MiB),
'log_sec': self.log_sec,
'pttype': self.pttype,
'ptuuid': self.ptuuid,
'rota': self.rota,
'tran': self.tran,
'partn': self.partn,
'partuuid': self.partuuid,
'parttype': self.parttype,
'uuid': self.uuid,
'fstype': self.fstype,
'fsver': self.fsver,
'fsavail': self.fsavail,
'fsuse_percentage': self.fsuse_percentage,
'type': self.type,
'mountpoint': str(self.mountpoint) if self.mountpoint else None,
'mountpoints': [str(m) for m in self.mountpoints],
'fsroots': [str(r) for r in self.fsroots],
'children': [c.json() for c in self.children]
}
@property @field_validator('mountpoints', 'fsroots', mode='before')
def btrfs_subvol_info(self) -> Dict[Path, Path]: @classmethod
""" def remove_none(cls, v: list[Path | None]) -> list[Path]:
It is assumed that lsblk will contain the fields as return [item for item in v if item is not None]
"mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...] @field_serializer('size', when_used='json')
"fsroots": ["/@log", "/@home", "/@"...] def serialize_size(self, size: Size) -> str:
return size.format_size(Unit.MiB)
we'll thereby map the fsroot, which are the mounted filesystem roots
to the corresponding mountpoints
"""
return dict(zip(self.fsroots, self.mountpoints))
@classmethod @classmethod
def exclude(cls) -> List[str]: def fields(cls) -> list[str]:
return ['children'] return [
field.alias or name
@classmethod for name, field in cls.model_fields.items()
def fields(cls) -> List[str]: if name != 'children'
return [f.name for f in dataclasses.fields(LsblkInfo) if f.name not in cls.exclude()] ]
@classmethod
def from_json(cls, blockdevice: Dict[str, Any]) -> LsblkInfo:
lsblk_info = cls()
for f in cls.fields():
lsblk_field = _clean_field(f, CleanType.Blockdevice)
data_field = _clean_field(f, CleanType.Dataclass)
val: Any = None
if isinstance(getattr(lsblk_info, data_field), Path):
val = Path(blockdevice[lsblk_field])
elif isinstance(getattr(lsblk_info, data_field), Size):
sector_size = SectorSize(blockdevice['log-sec'], Unit.B)
val = Size(blockdevice[lsblk_field], Unit.B, sector_size)
else:
val = blockdevice[lsblk_field]
setattr(lsblk_info, data_field, val)
lsblk_info.children = [LsblkInfo.from_json(child) for child in blockdevice.get('children', [])]
lsblk_info.mountpoint = Path(lsblk_info.mountpoint) if lsblk_info.mountpoint else None
# sometimes lsblk returns 'mountpoints': [null]
lsblk_info.mountpoints = [Path(mnt) for mnt in lsblk_info.mountpoints if mnt]
fs_roots = []
for r in lsblk_info.fsroots:
if r:
path = Path(r)
# store the fsroot entries without the leading /
fs_roots.append(path.relative_to(path.anchor))
lsblk_info.fsroots = fs_roots
return lsblk_info
class CleanType(Enum): class LsblkOutput(BaseModel):
Blockdevice = auto() blockdevices: list[LsblkInfo]
Dataclass = auto()
Lsblk = auto()
def _clean_field(name: str, clean_type: CleanType) -> str:
match clean_type:
case CleanType.Blockdevice:
return name.replace('_percentage', '%').replace('_', '-')
case CleanType.Dataclass:
return name.lower().replace('-', '_').replace('%', '_percentage')
case CleanType.Lsblk:
return name.replace('_percentage', '%').replace('_', '-')
def _fetch_lsblk_info( def _fetch_lsblk_info(
@ -1441,8 +1367,7 @@ def _fetch_lsblk_info(
reverse: bool = False, reverse: bool = False,
full_dev_path: bool = False full_dev_path: bool = False
) -> List[LsblkInfo]: ) -> List[LsblkInfo]:
fields = [_clean_field(f, CleanType.Lsblk) for f in LsblkInfo.fields()] cmd = ['lsblk', '--json', '--bytes', '--output', ','.join(LsblkInfo.fields())]
cmd = ['lsblk', '--json', '--bytes', '--output', ','.join(fields)]
if reverse: if reverse:
cmd.append('--inverse') cmd.append('--inverse')
@ -1472,8 +1397,7 @@ def _fetch_lsblk_info(
error(f"Could not decode lsblk JSON:\n{worker.output().decode().rstrip()}") error(f"Could not decode lsblk JSON:\n{worker.output().decode().rstrip()}")
raise err raise err
blockdevices = data['blockdevices'] return LsblkOutput(**data).blockdevices
return [LsblkInfo.from_json(device) for device in blockdevices]
def get_lsblk_info( def get_lsblk_info(