Parse lsblk data with pydantic (#2775)

This commit is contained in:
codefiles 2024-11-17 15:42:09 -05:00 committed by GitHub
parent 74fd463873
commit 8fc3dc4358
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 124 deletions

View File

@ -32,7 +32,6 @@ from .device_model import (
DiskEncryption,
Fido2Device,
LsblkInfo,
CleanType,
get_lsblk_info,
get_all_lsblk_info,
get_lsblk_by_mountpoint,

View File

@ -219,13 +219,20 @@ class DeviceHandler:
debug(f'Failed to read btrfs subvolume information: {err}')
return subvol_infos
# It is assumed that lsblk will contain the fields as
# "mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...]
# "fsroots": ["/@log", "/@home", "/@"...]
# we'll thereby map the fsroot, which are the mounted filesystem roots
# to the corresponding mountpoints
btrfs_subvol_info = dict(zip(lsblk_info.fsroots, lsblk_info.mountpoints))
try:
# ID 256 gen 16 top level 5 path @
for line in result.splitlines():
# expected output format:
# ID 257 gen 8 top level 5 path @home
name = Path(line.split(' ')[-1])
sub_vol_mountpoint = lsblk_info.btrfs_subvol_info.get(name, None)
sub_vol_mountpoint = btrfs_subvol_info.get(name, None)
subvol_infos.append(_BtrfsSubvolumeInfo(name, sub_vol_mountpoint))
except json.decoder.JSONDecodeError as err:
error(f"Could not decode lsblk JSON: {result}")

View File

@ -1,18 +1,17 @@
from __future__ import annotations
import dataclasses
import json
import math
import uuid
from dataclasses import dataclass, field
from enum import Enum
from enum import auto
from pathlib import Path
from typing import Optional, List, Dict, TYPE_CHECKING, Any
from typing import Union
import parted
from parted import Disk, Geometry, Partition
from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator
from ..exceptions import DiskError, SysCallError
from ..general import SysCommand
@ -1311,129 +1310,56 @@ class Fido2Device:
)
@dataclass
class LsblkInfo:
name: str = ''
path: Path = Path()
pkname: str = ''
size: Size = field(default_factory=lambda: Size(0, Unit.B, SectorSize.default()))
log_sec: int = 0
pttype: str = ''
ptuuid: str = ''
rota: bool = False
tran: Optional[str] = None
partn: Optional[int] = None
partuuid: Optional[str] = None
parttype: Optional[str] = None
uuid: Optional[str] = None
fstype: Optional[str] = None
fsver: Optional[str] = None
fsavail: Optional[str] = None
fsuse_percentage: Optional[str] = None
type: Optional[str] = None
mountpoint: Optional[Path] = None
mountpoints: List[Path] = field(default_factory=list)
fsroots: List[Path] = field(default_factory=list)
children: List[LsblkInfo] = field(default_factory=list)
class LsblkInfo(BaseModel):
name: str
path: Path
pkname: str | None
log_sec: int = Field(alias='log-sec')
size: Size
pttype: str | None
ptuuid: str | None
rota: bool
tran: str | None
partn: int | None
partuuid: str | None
parttype: str | None
uuid: str | None
fstype: str | None
fsver: str | None
fsavail: int | None
fsuse_percentage: str | None = Field(alias='fsuse%')
type: str
mountpoint: Path | None
mountpoints: list[Path]
fsroots: list[Path]
children: list[LsblkInfo] = Field(default_factory=list)
def json(self) -> Dict[str, Any]:
return {
'name': self.name,
'path': str(self.path),
'pkname': self.pkname,
'size': self.size.format_size(Unit.MiB),
'log_sec': self.log_sec,
'pttype': self.pttype,
'ptuuid': self.ptuuid,
'rota': self.rota,
'tran': self.tran,
'partn': self.partn,
'partuuid': self.partuuid,
'parttype': self.parttype,
'uuid': self.uuid,
'fstype': self.fstype,
'fsver': self.fsver,
'fsavail': self.fsavail,
'fsuse_percentage': self.fsuse_percentage,
'type': self.type,
'mountpoint': str(self.mountpoint) if self.mountpoint else None,
'mountpoints': [str(m) for m in self.mountpoints],
'fsroots': [str(r) for r in self.fsroots],
'children': [c.json() for c in self.children]
}
@field_validator('size', mode='before')
@classmethod
def convert_size(cls, v: int, info: ValidationInfo) -> Size:
sector_size = SectorSize(info.data['log_sec'], Unit.B)
return Size(v, Unit.B, sector_size)
@property
def btrfs_subvol_info(self) -> Dict[Path, Path]:
"""
It is assumed that lsblk will contain the fields as
@field_validator('mountpoints', 'fsroots', mode='before')
@classmethod
def remove_none(cls, v: list[Path | None]) -> list[Path]:
return [item for item in v if item is not None]
"mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...]
"fsroots": ["/@log", "/@home", "/@"...]
we'll thereby map the fsroot, which are the mounted filesystem roots
to the corresponding mountpoints
"""
return dict(zip(self.fsroots, self.mountpoints))
@field_serializer('size', when_used='json')
def serialize_size(self, size: Size) -> str:
return size.format_size(Unit.MiB)
@classmethod
def exclude(cls) -> List[str]:
return ['children']
@classmethod
def fields(cls) -> List[str]:
return [f.name for f in dataclasses.fields(LsblkInfo) if f.name not in cls.exclude()]
@classmethod
def from_json(cls, blockdevice: Dict[str, Any]) -> LsblkInfo:
lsblk_info = cls()
for f in cls.fields():
lsblk_field = _clean_field(f, CleanType.Blockdevice)
data_field = _clean_field(f, CleanType.Dataclass)
val: Any = None
if isinstance(getattr(lsblk_info, data_field), Path):
val = Path(blockdevice[lsblk_field])
elif isinstance(getattr(lsblk_info, data_field), Size):
sector_size = SectorSize(blockdevice['log-sec'], Unit.B)
val = Size(blockdevice[lsblk_field], Unit.B, sector_size)
else:
val = blockdevice[lsblk_field]
setattr(lsblk_info, data_field, val)
lsblk_info.children = [LsblkInfo.from_json(child) for child in blockdevice.get('children', [])]
lsblk_info.mountpoint = Path(lsblk_info.mountpoint) if lsblk_info.mountpoint else None
# sometimes lsblk returns 'mountpoints': [null]
lsblk_info.mountpoints = [Path(mnt) for mnt in lsblk_info.mountpoints if mnt]
fs_roots = []
for r in lsblk_info.fsroots:
if r:
path = Path(r)
# store the fsroot entries without the leading /
fs_roots.append(path.relative_to(path.anchor))
lsblk_info.fsroots = fs_roots
return lsblk_info
def fields(cls) -> list[str]:
return [
field.alias or name
for name, field in cls.model_fields.items()
if name != 'children'
]
class CleanType(Enum):
Blockdevice = auto()
Dataclass = auto()
Lsblk = auto()
def _clean_field(name: str, clean_type: CleanType) -> str:
match clean_type:
case CleanType.Blockdevice:
return name.replace('_percentage', '%').replace('_', '-')
case CleanType.Dataclass:
return name.lower().replace('-', '_').replace('%', '_percentage')
case CleanType.Lsblk:
return name.replace('_percentage', '%').replace('_', '-')
class LsblkOutput(BaseModel):
blockdevices: list[LsblkInfo]
def _fetch_lsblk_info(
@ -1441,8 +1367,7 @@ def _fetch_lsblk_info(
reverse: bool = False,
full_dev_path: bool = False
) -> List[LsblkInfo]:
fields = [_clean_field(f, CleanType.Lsblk) for f in LsblkInfo.fields()]
cmd = ['lsblk', '--json', '--bytes', '--output', ','.join(fields)]
cmd = ['lsblk', '--json', '--bytes', '--output', ','.join(LsblkInfo.fields())]
if reverse:
cmd.append('--inverse')
@ -1472,8 +1397,7 @@ def _fetch_lsblk_info(
error(f"Could not decode lsblk JSON:\n{worker.output().decode().rstrip()}")
raise err
blockdevices = data['blockdevices']
return [LsblkInfo.from_json(device) for device in blockdevices]
return LsblkOutput(**data).blockdevices
def get_lsblk_info(