Refactor syscommand (#4176)
* Refactor SysCommand * Refactor syscommand
This commit is contained in:
parent
316251f6e0
commit
bde544caca
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 577 KiB |
|
|
@ -4,7 +4,7 @@ import getpass
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from archinstall.lib.general import SysCommandWorker
|
||||
from archinstall.lib.command import SysCommandWorker
|
||||
from archinstall.lib.models.authentication import AuthenticationConfiguration, U2FLoginConfiguration, U2FLoginMethod
|
||||
from archinstall.lib.models.users import User
|
||||
from archinstall.lib.output import debug, info
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@ from collections.abc import Iterator
|
|||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, ClassVar, Self
|
||||
|
||||
from .command import SysCommand, SysCommandWorker
|
||||
from .exceptions import SysCallError
|
||||
from .general import SysCommand, SysCommandWorker, locate_binary
|
||||
from .general import locate_binary
|
||||
from .output import error
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,377 @@
|
|||
import os
|
||||
import shlex
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from select import EPOLLHUP, EPOLLIN, epoll
|
||||
from types import TracebackType
|
||||
from typing import Any, Self, override
|
||||
|
||||
from archinstall.lib.exceptions import SysCallError
|
||||
from archinstall.lib.general import clear_vt100_escape_codes, locate_binary
|
||||
from archinstall.lib.output import debug, error, logger
|
||||
|
||||
|
||||
class SysCommandWorker:
|
||||
def __init__(
|
||||
self,
|
||||
cmd: str | list[str],
|
||||
peek_output: bool | None = False,
|
||||
environment_vars: dict[str, str] | None = None,
|
||||
working_directory: str = './',
|
||||
remove_vt100_escape_codes_from_lines: bool = True,
|
||||
):
|
||||
if isinstance(cmd, str):
|
||||
cmd = shlex.split(cmd)
|
||||
|
||||
if cmd and not cmd[0].startswith(('/', './')): # Path() does not work well
|
||||
cmd[0] = locate_binary(cmd[0])
|
||||
|
||||
self.cmd = cmd
|
||||
self.peek_output = peek_output
|
||||
# define the standard locale for command outputs. For now the C ascii one. Can be overridden
|
||||
self.environment_vars = {'LC_ALL': 'C'}
|
||||
if environment_vars:
|
||||
self.environment_vars.update(environment_vars)
|
||||
|
||||
self.working_directory = working_directory
|
||||
|
||||
self.exit_code: int | None = None
|
||||
self._trace_log = b''
|
||||
self._trace_log_pos = 0
|
||||
self.poll_object = epoll()
|
||||
self.child_fd: int | None = None
|
||||
self.started: float | None = None
|
||||
self.ended: float | None = None
|
||||
self.remove_vt100_escape_codes_from_lines: bool = remove_vt100_escape_codes_from_lines
|
||||
|
||||
def __contains__(self, key: bytes) -> bool:
|
||||
"""
|
||||
Contains will also move the current buffert position forward.
|
||||
This is to avoid re-checking the same data when looking for output.
|
||||
"""
|
||||
assert isinstance(key, bytes)
|
||||
|
||||
index = self._trace_log.find(key, self._trace_log_pos)
|
||||
if index >= 0:
|
||||
self._trace_log_pos += index + len(key)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def __iter__(self, *args: str, **kwargs: dict[str, Any]) -> Iterator[bytes]:
|
||||
last_line = self._trace_log.rfind(b'\n')
|
||||
lines = filter(None, self._trace_log[self._trace_log_pos : last_line].splitlines())
|
||||
for line in lines:
|
||||
if self.remove_vt100_escape_codes_from_lines:
|
||||
line = clear_vt100_escape_codes(line)
|
||||
|
||||
yield line + b'\n'
|
||||
|
||||
self._trace_log_pos = last_line
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
self.make_sure_we_are_executing()
|
||||
return str(self._trace_log)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
try:
|
||||
return self._trace_log.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
return str(self._trace_log)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None) -> None:
|
||||
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
|
||||
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
|
||||
|
||||
if self.child_fd:
|
||||
try:
|
||||
os.close(self.child_fd)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.peek_output:
|
||||
# To make sure any peaked output didn't leave us hanging
|
||||
# on the same line we were on.
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
if exc_type is not None:
|
||||
debug(str(exc_value))
|
||||
|
||||
if self.exit_code != 0:
|
||||
raise SysCallError(
|
||||
f'{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self)[-500:]}',
|
||||
self.exit_code,
|
||||
worker_log=self._trace_log,
|
||||
)
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
self.poll()
|
||||
|
||||
if self.started and self.ended is None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def write(self, data: bytes, line_ending: bool = True) -> int:
|
||||
assert isinstance(data, bytes) # TODO: Maybe we can support str as well and encode it
|
||||
|
||||
self.make_sure_we_are_executing()
|
||||
|
||||
if self.child_fd:
|
||||
return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
|
||||
|
||||
return 0
|
||||
|
||||
def make_sure_we_are_executing(self) -> bool:
|
||||
if not self.started:
|
||||
return self.execute()
|
||||
return True
|
||||
|
||||
def tell(self) -> int:
|
||||
self.make_sure_we_are_executing()
|
||||
return self._trace_log_pos
|
||||
|
||||
def seek(self, pos: int) -> None:
|
||||
self.make_sure_we_are_executing()
|
||||
# Safety check to ensure 0 < pos < len(tracelog)
|
||||
self._trace_log_pos = min(max(0, pos), len(self._trace_log))
|
||||
|
||||
def peak(self, output: str | bytes) -> bool:
|
||||
if self.peek_output:
|
||||
if isinstance(output, bytes):
|
||||
try:
|
||||
output = output.decode('UTF-8')
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
|
||||
_cmd_output(output)
|
||||
|
||||
sys.stdout.write(output)
|
||||
sys.stdout.flush()
|
||||
|
||||
return True
|
||||
|
||||
def poll(self) -> None:
|
||||
self.make_sure_we_are_executing()
|
||||
|
||||
if self.child_fd:
|
||||
got_output = False
|
||||
for _fileno, _event in self.poll_object.poll(0.1):
|
||||
try:
|
||||
output = os.read(self.child_fd, 8192)
|
||||
got_output = True
|
||||
self.peak(output)
|
||||
self._trace_log += output
|
||||
except OSError:
|
||||
self.ended = time.time()
|
||||
break
|
||||
|
||||
if self.ended or (not got_output and not _pid_exists(self.pid)):
|
||||
self.ended = time.time()
|
||||
try:
|
||||
wait_status = os.waitpid(self.pid, 0)[1]
|
||||
self.exit_code = os.waitstatus_to_exitcode(wait_status)
|
||||
except ChildProcessError:
|
||||
try:
|
||||
wait_status = os.waitpid(self.child_fd, 0)[1]
|
||||
self.exit_code = os.waitstatus_to_exitcode(wait_status)
|
||||
except ChildProcessError:
|
||||
self.exit_code = 1
|
||||
|
||||
def execute(self) -> bool:
|
||||
import pty
|
||||
|
||||
if (old_dir := os.getcwd()) != self.working_directory:
|
||||
os.chdir(str(self.working_directory))
|
||||
|
||||
# Note: If for any reason, we get a Python exception between here
|
||||
# and until os.close(), the traceback will get locked inside
|
||||
# stdout of the child_fd object. `os.read(self.child_fd, 8192)` is the
|
||||
# only way to get the traceback without losing it.
|
||||
|
||||
self.pid, self.child_fd = pty.fork()
|
||||
|
||||
# https://stackoverflow.com/questions/4022600/python-pty-fork-how-does-it-work
|
||||
if not self.pid:
|
||||
_cmd_history(self.cmd)
|
||||
|
||||
try:
|
||||
os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
|
||||
except FileNotFoundError:
|
||||
error(f'{self.cmd[0]} does not exist.')
|
||||
self.exit_code = 1
|
||||
return False
|
||||
else:
|
||||
# Only parent process moves back to the original working directory
|
||||
os.chdir(old_dir)
|
||||
|
||||
self.started = time.time()
|
||||
self.poll_object.register(self.child_fd, EPOLLIN | EPOLLHUP)
|
||||
|
||||
return True
|
||||
|
||||
def decode(self, encoding: str = 'UTF-8') -> str:
|
||||
return self._trace_log.decode(encoding)
|
||||
|
||||
|
||||
class SysCommand:
|
||||
def __init__(
|
||||
self,
|
||||
cmd: str | list[str],
|
||||
peek_output: bool | None = False,
|
||||
environment_vars: dict[str, str] | None = None,
|
||||
working_directory: str = './',
|
||||
remove_vt100_escape_codes_from_lines: bool = True,
|
||||
):
|
||||
self.cmd = cmd
|
||||
self.peek_output = peek_output
|
||||
self.environment_vars = environment_vars
|
||||
self.working_directory = working_directory
|
||||
self.remove_vt100_escape_codes_from_lines = remove_vt100_escape_codes_from_lines
|
||||
|
||||
self.session: SysCommandWorker | None = None
|
||||
self.create_session()
|
||||
|
||||
def __enter__(self) -> SysCommandWorker | None:
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None) -> None:
|
||||
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
|
||||
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
|
||||
|
||||
if exc_type is not None:
|
||||
error(str(exc_value))
|
||||
|
||||
def __iter__(self, *args: list[Any], **kwargs: dict[str, Any]) -> Iterator[bytes]:
|
||||
if self.session:
|
||||
yield from self.session
|
||||
|
||||
def __getitem__(self, key: slice) -> bytes:
|
||||
if not self.session:
|
||||
raise KeyError('SysCommand() does not have an active session.')
|
||||
elif type(key) is slice:
|
||||
start = key.start or 0
|
||||
end = key.stop or len(self.session._trace_log)
|
||||
|
||||
return self.session._trace_log[start:end]
|
||||
else:
|
||||
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
|
||||
|
||||
@override
|
||||
def __repr__(self, *args: list[Any], **kwargs: dict[str, Any]) -> str:
|
||||
return self.decode('UTF-8', errors='backslashreplace') or ''
|
||||
|
||||
def create_session(self) -> bool:
|
||||
"""
|
||||
Initiates a :ref:`SysCommandWorker` session in this class ``.session``.
|
||||
It then proceeds to poll the process until it ends, after which it also
|
||||
clears any printed output if ``.peek_output=True``.
|
||||
"""
|
||||
if self.session:
|
||||
return True
|
||||
|
||||
with SysCommandWorker(
|
||||
self.cmd,
|
||||
peek_output=self.peek_output,
|
||||
environment_vars=self.environment_vars,
|
||||
remove_vt100_escape_codes_from_lines=self.remove_vt100_escape_codes_from_lines,
|
||||
working_directory=self.working_directory,
|
||||
) as session:
|
||||
self.session = session
|
||||
|
||||
while not self.session.ended:
|
||||
self.session.poll()
|
||||
|
||||
if self.peek_output:
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
return True
|
||||
|
||||
def decode(self, encoding: str = 'utf-8', errors: str = 'backslashreplace', strip: bool = True) -> str:
|
||||
if not self.session:
|
||||
raise ValueError('No session available to decode')
|
||||
|
||||
val = self.session._trace_log.decode(encoding, errors=errors)
|
||||
|
||||
if strip:
|
||||
return val.strip()
|
||||
return val
|
||||
|
||||
def output(self, remove_cr: bool = True) -> bytes:
|
||||
if not self.session:
|
||||
raise ValueError('No session available')
|
||||
|
||||
if remove_cr:
|
||||
return self.session._trace_log.replace(b'\r\n', b'\n')
|
||||
|
||||
return self.session._trace_log
|
||||
|
||||
@property
|
||||
def exit_code(self) -> int | None:
|
||||
if self.session:
|
||||
return self.session.exit_code
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def trace_log(self) -> bytes | None:
|
||||
if self.session:
|
||||
return self.session._trace_log
|
||||
return None
|
||||
|
||||
|
||||
def run(
|
||||
cmd: list[str],
|
||||
input_data: bytes | None = None,
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
_cmd_history(cmd)
|
||||
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
input=input_data,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def _pid_exists(pid: int) -> bool:
|
||||
try:
|
||||
return any(subprocess.check_output(['ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def _cmd_history(cmd: list[str]) -> None:
|
||||
content = f'{time.time()} {cmd}\n'
|
||||
_append_log('cmd_history.txt', content)
|
||||
|
||||
|
||||
def _cmd_output(output: str) -> None:
|
||||
_append_log('cmd_output.txt', output)
|
||||
|
||||
|
||||
def _append_log(file: str, content: str) -> None:
|
||||
path = logger.directory / file
|
||||
|
||||
change_perm = not path.exists()
|
||||
|
||||
try:
|
||||
with path.open('a') as f:
|
||||
f.write(content)
|
||||
|
||||
if change_perm:
|
||||
path.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP)
|
||||
except (PermissionError, FileNotFoundError):
|
||||
# If the file does not exist, ignore the error
|
||||
pass
|
||||
|
|
@ -8,8 +8,8 @@ from typing import Literal, overload
|
|||
|
||||
from parted import Device, Disk, DiskException, FileSystem, Geometry, IOException, Partition, PartitionException, freshDisk, getAllDevices, getDevice, newDisk
|
||||
|
||||
from ..command import SysCommand, SysCommandWorker
|
||||
from ..exceptions import DiskError, SysCallError, UnknownFilesystemFormat
|
||||
from ..general import SysCommand, SysCommandWorker
|
||||
from ..luks import Luks2
|
||||
from ..models.device import (
|
||||
DEFAULT_ITER_TIME,
|
||||
|
|
@ -850,8 +850,8 @@ class DeviceHandler:
|
|||
def _wipe(self, dev_path: Path) -> None:
|
||||
"""
|
||||
Wipe a device (partition or otherwise) of meta-data, be it file system, LVM, etc.
|
||||
@param dev_path: Device path of the partition to be wiped.
|
||||
@type dev_path: str
|
||||
@param dev_path: Device path of the partition to be wiped.
|
||||
@type dev_path: str
|
||||
"""
|
||||
with open(dev_path, 'wb') as p:
|
||||
p.write(bytearray(1024))
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ from typing import ClassVar
|
|||
|
||||
from archinstall.lib.models.device import Fido2Device
|
||||
|
||||
from ..command import SysCommand, SysCommandWorker
|
||||
from ..exceptions import SysCallError
|
||||
from ..general import SysCommand, SysCommandWorker, clear_vt100_escape_codes_from_str
|
||||
from ..general import clear_vt100_escape_codes_from_str
|
||||
from ..models.users import Password
|
||||
from ..output import error, info
|
||||
|
||||
|
|
@ -61,8 +62,8 @@ class Fido2:
|
|||
|
||||
Output example:
|
||||
|
||||
PATH MANUFACTURER PRODUCT
|
||||
/dev/hidraw1 Yubico YubiKey OTP+FIDO+CCID
|
||||
PATH MANUFACTURER PRODUCT
|
||||
/dev/hidraw1 Yubico YubiKey OTP+FIDO+CCID
|
||||
"""
|
||||
|
||||
# to prevent continuous reloading which will slow
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from pathlib import Path
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from archinstall.lib.command import SysCommand
|
||||
from archinstall.lib.exceptions import DiskError, SysCallError
|
||||
from archinstall.lib.general import SysCommand
|
||||
from archinstall.lib.models.device import LsblkInfo
|
||||
from archinstall.lib.output import debug, warn
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,14 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import shlex
|
||||
import stat
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from select import EPOLLHUP, EPOLLIN, epoll
|
||||
from shutil import which
|
||||
from types import TracebackType
|
||||
from typing import Any, Self, override
|
||||
from typing import Any, override
|
||||
|
||||
from .exceptions import RequirementError, SysCallError
|
||||
from .output import debug, error, logger
|
||||
from archinstall.lib.exceptions import RequirementError
|
||||
|
||||
# https://stackoverflow.com/a/43627833/929999
|
||||
_VT100_ESCAPE_REGEX = r'\x1B\[[?0-9;]*[a-zA-Z]'
|
||||
|
|
@ -105,366 +95,3 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder):
|
|||
@override
|
||||
def encode(self, o: Any) -> str:
|
||||
return super().encode(jsonify(o, safe=False))
|
||||
|
||||
|
||||
class SysCommandWorker:
|
||||
def __init__(
|
||||
self,
|
||||
cmd: str | list[str],
|
||||
peek_output: bool | None = False,
|
||||
environment_vars: dict[str, str] | None = None,
|
||||
working_directory: str = './',
|
||||
remove_vt100_escape_codes_from_lines: bool = True,
|
||||
):
|
||||
if isinstance(cmd, str):
|
||||
cmd = shlex.split(cmd)
|
||||
|
||||
if cmd and not cmd[0].startswith(('/', './')): # Path() does not work well
|
||||
cmd[0] = locate_binary(cmd[0])
|
||||
|
||||
self.cmd = cmd
|
||||
self.peek_output = peek_output
|
||||
# define the standard locale for command outputs. For now the C ascii one. Can be overridden
|
||||
self.environment_vars = {'LC_ALL': 'C'}
|
||||
if environment_vars:
|
||||
self.environment_vars.update(environment_vars)
|
||||
|
||||
self.working_directory = working_directory
|
||||
|
||||
self.exit_code: int | None = None
|
||||
self._trace_log = b''
|
||||
self._trace_log_pos = 0
|
||||
self.poll_object = epoll()
|
||||
self.child_fd: int | None = None
|
||||
self.started: float | None = None
|
||||
self.ended: float | None = None
|
||||
self.remove_vt100_escape_codes_from_lines: bool = remove_vt100_escape_codes_from_lines
|
||||
|
||||
def __contains__(self, key: bytes) -> bool:
|
||||
"""
|
||||
Contains will also move the current buffert position forward.
|
||||
This is to avoid re-checking the same data when looking for output.
|
||||
"""
|
||||
assert isinstance(key, bytes)
|
||||
|
||||
index = self._trace_log.find(key, self._trace_log_pos)
|
||||
if index >= 0:
|
||||
self._trace_log_pos += index + len(key)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def __iter__(self, *args: str, **kwargs: dict[str, Any]) -> Iterator[bytes]:
|
||||
last_line = self._trace_log.rfind(b'\n')
|
||||
lines = filter(None, self._trace_log[self._trace_log_pos : last_line].splitlines())
|
||||
for line in lines:
|
||||
if self.remove_vt100_escape_codes_from_lines:
|
||||
line = clear_vt100_escape_codes(line)
|
||||
|
||||
yield line + b'\n'
|
||||
|
||||
self._trace_log_pos = last_line
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
self.make_sure_we_are_executing()
|
||||
return str(self._trace_log)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
try:
|
||||
return self._trace_log.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
return str(self._trace_log)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None) -> None:
|
||||
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
|
||||
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
|
||||
|
||||
if self.child_fd:
|
||||
try:
|
||||
os.close(self.child_fd)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self.peek_output:
|
||||
# To make sure any peaked output didn't leave us hanging
|
||||
# on the same line we were on.
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
if exc_type is not None:
|
||||
debug(str(exc_value))
|
||||
|
||||
if self.exit_code != 0:
|
||||
raise SysCallError(
|
||||
f'{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self)[-500:]}',
|
||||
self.exit_code,
|
||||
worker_log=self._trace_log,
|
||||
)
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
self.poll()
|
||||
|
||||
if self.started and self.ended is None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def write(self, data: bytes, line_ending: bool = True) -> int:
|
||||
assert isinstance(data, bytes) # TODO: Maybe we can support str as well and encode it
|
||||
|
||||
self.make_sure_we_are_executing()
|
||||
|
||||
if self.child_fd:
|
||||
return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
|
||||
|
||||
return 0
|
||||
|
||||
def make_sure_we_are_executing(self) -> bool:
|
||||
if not self.started:
|
||||
return self.execute()
|
||||
return True
|
||||
|
||||
def tell(self) -> int:
|
||||
self.make_sure_we_are_executing()
|
||||
return self._trace_log_pos
|
||||
|
||||
def seek(self, pos: int) -> None:
|
||||
self.make_sure_we_are_executing()
|
||||
# Safety check to ensure 0 < pos < len(tracelog)
|
||||
self._trace_log_pos = min(max(0, pos), len(self._trace_log))
|
||||
|
||||
def peak(self, output: str | bytes) -> bool:
|
||||
if self.peek_output:
|
||||
if isinstance(output, bytes):
|
||||
try:
|
||||
output = output.decode('UTF-8')
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
|
||||
_cmd_output(output)
|
||||
|
||||
sys.stdout.write(output)
|
||||
sys.stdout.flush()
|
||||
|
||||
return True
|
||||
|
||||
def poll(self) -> None:
|
||||
self.make_sure_we_are_executing()
|
||||
|
||||
if self.child_fd:
|
||||
got_output = False
|
||||
for _fileno, _event in self.poll_object.poll(0.1):
|
||||
try:
|
||||
output = os.read(self.child_fd, 8192)
|
||||
got_output = True
|
||||
self.peak(output)
|
||||
self._trace_log += output
|
||||
except OSError:
|
||||
self.ended = time.time()
|
||||
break
|
||||
|
||||
if self.ended or (not got_output and not _pid_exists(self.pid)):
|
||||
self.ended = time.time()
|
||||
try:
|
||||
wait_status = os.waitpid(self.pid, 0)[1]
|
||||
self.exit_code = os.waitstatus_to_exitcode(wait_status)
|
||||
except ChildProcessError:
|
||||
try:
|
||||
wait_status = os.waitpid(self.child_fd, 0)[1]
|
||||
self.exit_code = os.waitstatus_to_exitcode(wait_status)
|
||||
except ChildProcessError:
|
||||
self.exit_code = 1
|
||||
|
||||
def execute(self) -> bool:
|
||||
import pty
|
||||
|
||||
if (old_dir := os.getcwd()) != self.working_directory:
|
||||
os.chdir(str(self.working_directory))
|
||||
|
||||
# Note: If for any reason, we get a Python exception between here
|
||||
# and until os.close(), the traceback will get locked inside
|
||||
# stdout of the child_fd object. `os.read(self.child_fd, 8192)` is the
|
||||
# only way to get the traceback without losing it.
|
||||
|
||||
self.pid, self.child_fd = pty.fork()
|
||||
|
||||
# https://stackoverflow.com/questions/4022600/python-pty-fork-how-does-it-work
|
||||
if not self.pid:
|
||||
_cmd_history(self.cmd)
|
||||
|
||||
try:
|
||||
os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
|
||||
except FileNotFoundError:
|
||||
error(f'{self.cmd[0]} does not exist.')
|
||||
self.exit_code = 1
|
||||
return False
|
||||
else:
|
||||
# Only parent process moves back to the original working directory
|
||||
os.chdir(old_dir)
|
||||
|
||||
self.started = time.time()
|
||||
self.poll_object.register(self.child_fd, EPOLLIN | EPOLLHUP)
|
||||
|
||||
return True
|
||||
|
||||
def decode(self, encoding: str = 'UTF-8') -> str:
|
||||
return self._trace_log.decode(encoding)
|
||||
|
||||
|
||||
class SysCommand:
|
||||
def __init__(
|
||||
self,
|
||||
cmd: str | list[str],
|
||||
peek_output: bool | None = False,
|
||||
environment_vars: dict[str, str] | None = None,
|
||||
working_directory: str = './',
|
||||
remove_vt100_escape_codes_from_lines: bool = True,
|
||||
):
|
||||
self.cmd = cmd
|
||||
self.peek_output = peek_output
|
||||
self.environment_vars = environment_vars
|
||||
self.working_directory = working_directory
|
||||
self.remove_vt100_escape_codes_from_lines = remove_vt100_escape_codes_from_lines
|
||||
|
||||
self.session: SysCommandWorker | None = None
|
||||
self.create_session()
|
||||
|
||||
def __enter__(self) -> SysCommandWorker | None:
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None) -> None:
|
||||
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
|
||||
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
|
||||
|
||||
if exc_type is not None:
|
||||
error(str(exc_value))
|
||||
|
||||
def __iter__(self, *args: list[Any], **kwargs: dict[str, Any]) -> Iterator[bytes]:
|
||||
if self.session:
|
||||
yield from self.session
|
||||
|
||||
def __getitem__(self, key: slice) -> bytes:
|
||||
if not self.session:
|
||||
raise KeyError('SysCommand() does not have an active session.')
|
||||
elif type(key) is slice:
|
||||
start = key.start or 0
|
||||
end = key.stop or len(self.session._trace_log)
|
||||
|
||||
return self.session._trace_log[start:end]
|
||||
else:
|
||||
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
|
||||
|
||||
@override
|
||||
def __repr__(self, *args: list[Any], **kwargs: dict[str, Any]) -> str:
|
||||
return self.decode('UTF-8', errors='backslashreplace') or ''
|
||||
|
||||
def create_session(self) -> bool:
|
||||
"""
|
||||
Initiates a :ref:`SysCommandWorker` session in this class ``.session``.
|
||||
It then proceeds to poll the process until it ends, after which it also
|
||||
clears any printed output if ``.peek_output=True``.
|
||||
"""
|
||||
if self.session:
|
||||
return True
|
||||
|
||||
with SysCommandWorker(
|
||||
self.cmd,
|
||||
peek_output=self.peek_output,
|
||||
environment_vars=self.environment_vars,
|
||||
remove_vt100_escape_codes_from_lines=self.remove_vt100_escape_codes_from_lines,
|
||||
working_directory=self.working_directory,
|
||||
) as session:
|
||||
self.session = session
|
||||
|
||||
while not self.session.ended:
|
||||
self.session.poll()
|
||||
|
||||
if self.peek_output:
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
return True
|
||||
|
||||
def decode(self, encoding: str = 'utf-8', errors: str = 'backslashreplace', strip: bool = True) -> str:
|
||||
if not self.session:
|
||||
raise ValueError('No session available to decode')
|
||||
|
||||
val = self.session._trace_log.decode(encoding, errors=errors)
|
||||
|
||||
if strip:
|
||||
return val.strip()
|
||||
return val
|
||||
|
||||
def output(self, remove_cr: bool = True) -> bytes:
|
||||
if not self.session:
|
||||
raise ValueError('No session available')
|
||||
|
||||
if remove_cr:
|
||||
return self.session._trace_log.replace(b'\r\n', b'\n')
|
||||
|
||||
return self.session._trace_log
|
||||
|
||||
@property
|
||||
def exit_code(self) -> int | None:
|
||||
if self.session:
|
||||
return self.session.exit_code
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def trace_log(self) -> bytes | None:
|
||||
if self.session:
|
||||
return self.session._trace_log
|
||||
return None
|
||||
|
||||
|
||||
def _append_log(file: str, content: str) -> None:
|
||||
path = logger.directory / file
|
||||
|
||||
change_perm = not path.exists()
|
||||
|
||||
try:
|
||||
with path.open('a') as f:
|
||||
f.write(content)
|
||||
|
||||
if change_perm:
|
||||
path.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP)
|
||||
except (PermissionError, FileNotFoundError):
|
||||
# If the file does not exist, ignore the error
|
||||
pass
|
||||
|
||||
|
||||
def _cmd_history(cmd: list[str]) -> None:
|
||||
content = f'{time.time()} {cmd}\n'
|
||||
_append_log('cmd_history.txt', content)
|
||||
|
||||
|
||||
def _cmd_output(output: str) -> None:
|
||||
_append_log('cmd_output.txt', output)
|
||||
|
||||
|
||||
def run(
|
||||
cmd: list[str],
|
||||
input_data: bytes | None = None,
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
_cmd_history(cmd)
|
||||
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
input=input_data,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def _pid_exists(pid: int) -> bool:
|
||||
try:
|
||||
return any(subprocess.check_output(['ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from functools import cached_property
|
|||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
from .command import SysCommand
|
||||
from .exceptions import SysCallError
|
||||
from .general import SysCommand
|
||||
from .networking import enrich_iface_types, list_interfaces
|
||||
from .output import debug
|
||||
from .translationhandler import tr
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ from archinstall.lib.translationhandler import tr
|
|||
|
||||
from .args import arch_config_handler
|
||||
from .boot import Boot
|
||||
from .command import SysCommand, run
|
||||
from .exceptions import DiskError, HardwareIncompatibilityError, RequirementError, ServiceException, SysCallError
|
||||
from .general import SysCommand, run
|
||||
from .hardware import SysInfo
|
||||
from .locale.utils import verify_keyboard_layout, verify_x11_keyboard_layout
|
||||
from .luks import Luks2
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from ..command import SysCommand
|
||||
from ..exceptions import ServiceException, SysCallError
|
||||
from ..general import SysCommand, running_from_host
|
||||
from ..general import running_from_host
|
||||
from ..output import error
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,9 @@ from types import TracebackType
|
|||
from archinstall.lib.disk.utils import get_lsblk_info, umount
|
||||
from archinstall.lib.models.device import DEFAULT_ITER_TIME
|
||||
|
||||
from .command import SysCommand, SysCommandWorker, run
|
||||
from .exceptions import DiskError, SysCallError
|
||||
from .general import SysCommand, SysCommandWorker, generate_password, run
|
||||
from .general import generate_password
|
||||
from .models.users import Password
|
||||
from .output import debug, info
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import assert_never
|
||||
|
||||
from archinstall.lib.command import SysCommand
|
||||
from archinstall.lib.exceptions import SysCallError
|
||||
from archinstall.lib.general import SysCommand
|
||||
from archinstall.lib.models.network import WifiConfiguredNetwork, WifiNetwork
|
||||
from archinstall.lib.network.wpa_supplicant import WpaSupplicantConfig
|
||||
from archinstall.lib.output import debug
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from pathlib import Path
|
|||
|
||||
from archinstall.lib.translationhandler import tr
|
||||
|
||||
from ..command import SysCommand
|
||||
from ..exceptions import RequirementError
|
||||
from ..general import SysCommand
|
||||
from ..output import error, info, warn
|
||||
from ..plugins import plugins
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue