Fixing some mypy complaints (#780)

* Fixed some mypy issues regarding SysCommand* and logging
* Fixed imports and undefined variable
This commit is contained in:
Anton Hvornum 2021-12-02 20:20:31 +00:00 committed by GitHub
parent 908c7b8cc0
commit b1b820f4cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 149 deletions

View File

@ -1,3 +1,5 @@
from typing import Optional
class RequirementError(BaseException): class RequirementError(BaseException):
pass pass
@ -15,7 +17,7 @@ class ProfileError(BaseException):
class SysCallError(BaseException): class SysCallError(BaseException):
def __init__(self, message, exit_code): def __init__(self, message :str, exit_code :Optional[int]) -> None:
super(SysCallError, self).__init__(message) super(SysCallError, self).__init__(message)
self.message = message self.message = message
self.exit_code = exit_code self.exit_code = exit_code

View File

@ -9,10 +9,11 @@ import string
import sys import sys
import time import time
from datetime import datetime, date from datetime import datetime, date
from typing import Union from typing import Callable, Optional, Dict, Any, List, Union, Iterator
try:
if sys.platform == 'linux':
from select import epoll, EPOLLIN, EPOLLHUP from select import epoll, EPOLLIN, EPOLLHUP
except: else:
import select import select
EPOLLIN = 0 EPOLLIN = 0
EPOLLHUP = 0 EPOLLHUP = 0
@ -22,20 +23,20 @@ except:
Create a epoll() implementation that simulates the epoll() behavior. Create a epoll() implementation that simulates the epoll() behavior.
This so that the rest of the code doesn't need to worry weither we're using select() or epoll(). This so that the rest of the code doesn't need to worry weither we're using select() or epoll().
""" """
def __init__(self): def __init__(self) -> None:
self.sockets = {} self.sockets: Dict[str, Any] = {}
self.monitoring = {} self.monitoring: Dict[int, Any] = {}
def unregister(self, fileno, *args, **kwargs): def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None:
try: try:
del(self.monitoring[fileno]) del(self.monitoring[fileno])
except: except:
pass pass
def register(self, fileno, *args, **kwargs): def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None:
self.monitoring[fileno] = True self.monitoring[fileno] = True
def poll(self, timeout=0.05, *args, **kwargs): def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]:
try: try:
return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]] return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]]
except OSError: except OSError:
@ -66,13 +67,13 @@ def multisplit(s, splitters):
s = ns s = ns
return s return s
def locate_binary(name): def locate_binary(name :str) -> str:
for PATH in os.environ['PATH'].split(':'): for PATH in os.environ['PATH'].split(':'):
for root, folders, files in os.walk(PATH): for root, folders, files in os.walk(PATH):
for file in files: for file in files:
if file == name: if file == name:
return os.path.join(root, file) return os.path.join(root, file)
break # Don't recurse break # Don't recurse
raise RequirementError(f"Binary {name} does not exist.") raise RequirementError(f"Binary {name} does not exist.")
@ -157,7 +158,14 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder):
return super(UNSAFE_JSON, self).encode(self._encode(obj)) return super(UNSAFE_JSON, self).encode(self._encode(obj))
class SysCommandWorker: class SysCommandWorker:
def __init__(self, cmd, callbacks=None, peak_output=False, environment_vars=None, logfile=None, working_directory='./'): def __init__(self,
cmd :Union[str, List[str]],
callbacks :Optional[Dict[str, Any]] = None,
peak_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
logfile :Optional[None] = None,
working_directory :Optional[str] = './'):
if not callbacks: if not callbacks:
callbacks = {} callbacks = {}
if not environment_vars: if not environment_vars:
@ -166,6 +174,7 @@ class SysCommandWorker:
if type(cmd) is str: if type(cmd) is str:
cmd = shlex.split(cmd) cmd = shlex.split(cmd)
cmd = list(cmd) # This is to please mypy
if cmd[0][0] != '/' and cmd[0][:2] != './': if cmd[0][0] != '/' and cmd[0][:2] != './':
# "which" doesn't work as it's a builtin to bash. # "which" doesn't work as it's a builtin to bash.
# It used to work, but for whatever reason it doesn't anymore. # It used to work, but for whatever reason it doesn't anymore.
@ -179,15 +188,15 @@ class SysCommandWorker:
self.logfile = logfile self.logfile = logfile
self.working_directory = working_directory self.working_directory = working_directory
self.exit_code = None self.exit_code :Optional[int] = None
self._trace_log = b'' self._trace_log = b''
self._trace_log_pos = 0 self._trace_log_pos = 0
self.poll_object = epoll() self.poll_object = epoll()
self.child_fd = None self.child_fd :Optional[int] = None
self.started = None self.started :Optional[float] = None
self.ended = None self.ended :Optional[float] = None
def __contains__(self, key: bytes): def __contains__(self, key: bytes) -> bool:
""" """
Contains will also move the current buffert position forward. Contains will also move the current buffert position forward.
This is to avoid re-checking the same data when looking for output. This is to avoid re-checking the same data when looking for output.
@ -199,21 +208,21 @@ class SysCommandWorker:
return contains return contains
def __iter__(self, *args, **kwargs): def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]:
for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'): for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'):
if line: if line:
yield line + b'\n' yield line + b'\n'
self._trace_log_pos = self._trace_log.rfind(b'\n') self._trace_log_pos = self._trace_log.rfind(b'\n')
def __repr__(self): def __repr__(self) -> str:
self.make_sure_we_are_executing() self.make_sure_we_are_executing()
return str(self._trace_log) return str(self._trace_log)
def __enter__(self): def __enter__(self) -> 'SysCommandWorker':
return self return self
def __exit__(self, *args): def __exit__(self, *args :str) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync. # b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
@ -233,9 +242,9 @@ class SysCommandWorker:
log(args[1], level=logging.ERROR, fg='red') log(args[1], level=logging.ERROR, fg='red')
if self.exit_code != 0: if self.exit_code != 0:
raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}") raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}", self.exit_code)
def is_alive(self): def is_alive(self) -> bool:
self.poll() self.poll()
if self.started and self.ended is None: if self.started and self.ended is None:
@ -243,22 +252,26 @@ class SysCommandWorker:
return False return False
def write(self, data: bytes, line_ending=True): def write(self, data: bytes, line_ending :bool = True) -> int:
assert type(data) == bytes # TODO: Maybe we can support str as well and encode it assert type(data) == bytes # TODO: Maybe we can support str as well and encode it
self.make_sure_we_are_executing() self.make_sure_we_are_executing()
os.write(self.child_fd, data + (b'\n' if line_ending else b'')) if self.child_fd:
return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
def make_sure_we_are_executing(self): return 0
def make_sure_we_are_executing(self) -> bool:
if not self.started: if not self.started:
return self.execute() return self.execute()
return True
def tell(self) -> int: def tell(self) -> int:
self.make_sure_we_are_executing() self.make_sure_we_are_executing()
return self._trace_log_pos return self._trace_log_pos
def seek(self, pos): def seek(self, pos :int) -> None:
self.make_sure_we_are_executing() self.make_sure_we_are_executing()
# Safety check to ensure 0 < pos < len(tracelog) # Safety check to ensure 0 < pos < len(tracelog)
self._trace_log_pos = min(max(0, pos), len(self._trace_log)) self._trace_log_pos = min(max(0, pos), len(self._trace_log))
@ -271,39 +284,41 @@ class SysCommandWorker:
except UnicodeDecodeError: except UnicodeDecodeError:
return False return False
sys.stdout.write(output) sys.stdout.write(str(output))
sys.stdout.flush() sys.stdout.flush()
return True return True
def poll(self): def poll(self) -> None:
self.make_sure_we_are_executing() self.make_sure_we_are_executing()
got_output = False if self.child_fd:
for fileno, event in self.poll_object.poll(0.1): got_output = False
try: for fileno, event in self.poll_object.poll(0.1):
output = os.read(self.child_fd, 8192)
got_output = True
self.peak(output)
self._trace_log += output
except OSError:
self.ended = time.time()
break
if self.ended or (got_output is False and pid_exists(self.pid) is False):
self.ended = time.time()
try:
self.exit_code = os.waitpid(self.pid, 0)[1]
except ChildProcessError:
try: try:
self.exit_code = os.waitpid(self.child_fd, 0)[1] output = os.read(self.child_fd, 8192)
got_output = True
self.peak(output)
self._trace_log += output
except OSError:
self.ended = time.time()
break
if self.ended or (got_output is False and pid_exists(self.pid) is False):
self.ended = time.time()
try:
self.exit_code = os.waitpid(self.pid, 0)[1]
except ChildProcessError: except ChildProcessError:
self.exit_code = 1 try:
self.exit_code = os.waitpid(self.child_fd, 0)[1]
except ChildProcessError:
self.exit_code = 1
def execute(self) -> bool: def execute(self) -> bool:
import pty import pty
if (old_dir := os.getcwd()) != self.working_directory: if (old_dir := os.getcwd()) != self.working_directory:
os.chdir(self.working_directory) os.chdir(str(self.working_directory))
# Note: If for any reason, we get a Python exception between here # Note: If for any reason, we get a Python exception between here
# and until os.close(), the traceback will get locked inside # and until os.close(), the traceback will get locked inside
@ -320,7 +335,7 @@ class SysCommandWorker:
except PermissionError: except PermissionError:
pass pass
os.execve(self.cmd[0], self.cmd, {**os.environ, **self.environment_vars}) os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
if storage['arguments'].get('debug'): if storage['arguments'].get('debug'):
log(f"Executing: {self.cmd}", level=logging.DEBUG) log(f"Executing: {self.cmd}", level=logging.DEBUG)
@ -334,15 +349,23 @@ class SysCommandWorker:
return True return True
def decode(self, encoding='UTF-8'): def decode(self, encoding :str = 'UTF-8') -> str:
return self._trace_log.decode(encoding) return self._trace_log.decode(encoding)
class SysCommand: class SysCommand:
def __init__(self, cmd, callback=None, start_callback=None, peak_output=False, environment_vars=None, working_directory='./'): def __init__(self,
cmd :Union[str, List[str]],
callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None,
start_callback :Optional[Callable[[Any], Any]] = None,
peak_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
working_directory :Optional[str] = './'):
_callbacks = {} _callbacks = {}
if callback: if callbacks:
_callbacks['on_end'] = callback for hook, func in callbacks.items():
_callbacks[hook] = func
if start_callback: if start_callback:
_callbacks['on_start'] = start_callback _callbacks['on_start'] = start_callback
@ -352,26 +375,28 @@ class SysCommand:
self.environment_vars = environment_vars self.environment_vars = environment_vars
self.working_directory = working_directory self.working_directory = working_directory
self.session = None self.session :Optional[SysCommandWorker] = None
self.create_session() self.create_session()
def __enter__(self): def __enter__(self) -> Optional[SysCommandWorker]:
return self.session return self.session
def __exit__(self, *args, **kwargs): def __exit__(self, *args :str, **kwargs :Dict[str, Any]) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync. # b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
if len(args) >= 2 and args[1]: if len(args) >= 2 and args[1]:
log(args[1], level=logging.ERROR, fg='red') log(args[1], level=logging.ERROR, fg='red')
def __iter__(self, *args, **kwargs): def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]:
if self.session:
for line in self.session:
yield line
for line in self.session: def __getitem__(self, key :slice) -> Optional[bytes]:
yield line if not self.session:
raise KeyError(f"SysCommand() does not have an active session.")
def __getitem__(self, key): elif type(key) is slice:
if type(key) is slice:
start = key.start if key.start else 0 start = key.start if key.start else 0
end = key.stop if key.stop else len(self.session._trace_log) end = key.stop if key.stop else len(self.session._trace_log)
@ -379,10 +404,12 @@ class SysCommand:
else: else:
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.") raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
def __repr__(self, *args, **kwargs): def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str:
return self.session._trace_log.decode('UTF-8') if self.session:
return self.session._trace_log.decode('UTF-8')
return ''
def __json__(self): def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]:
return { return {
'cmd': self.cmd, 'cmd': self.cmd,
'callbacks': self._callbacks, 'callbacks': self._callbacks,
@ -391,7 +418,7 @@ class SysCommand:
'session': True if self.session else False 'session': True if self.session else False
} }
def create_session(self): def create_session(self) -> bool:
if self.session: if self.session:
return True return True
@ -406,16 +433,23 @@ class SysCommand:
return True return True
def decode(self, fmt='UTF-8'): def decode(self, fmt :str = 'UTF-8') -> Optional[str]:
return self.session._trace_log.decode(fmt) if self.session:
return self.session._trace_log.decode(fmt)
return None
@property @property
def exit_code(self): def exit_code(self) -> Optional[int]:
return self.session.exit_code if self.session:
return self.session.exit_code
else:
return None
@property @property
def trace_log(self): def trace_log(self) -> Optional[bytes]:
return self.session._trace_log if self.session:
return self.session._trace_log
return None
def prerequisite_check(): def prerequisite_check():
@ -428,7 +462,8 @@ def prerequisite_check():
def reboot(): def reboot():
SysCommand("/usr/bin/reboot") SysCommand("/usr/bin/reboot")
def pid_exists(pid: int):
def pid_exists(pid: int) -> bool:
try: try:
return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip()) return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
except subprocess.CalledProcessError: except subprocess.CalledProcessError:

View File

@ -1,51 +1,19 @@
import abc
import logging import logging
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Union
from .storage import storage from .storage import storage
# TODO: use logging's built in levels instead. class Journald:
# Although logging is threaded and I wish to avoid that.
# It's more Pythonistic or w/e you want to call it.
class LogLevels:
Critical = 0b001
Error = 0b010
Warning = 0b011
Info = 0b101
Debug = 0b111
class Journald(dict):
@staticmethod @staticmethod
@abc.abstractmethod def log(message :str, level :int = logging.DEBUG) -> None:
def log(message, level=logging.DEBUG):
try: try:
import systemd.journal # type: ignore import systemd.journal # type: ignore
except ModuleNotFoundError: except ModuleNotFoundError:
return False return None
# For backwards compatibility, convert old style log-levels
# to logging levels (and warn about deprecated usage)
# There's some code re-usage here but that should be fine.
# TODO: Remove these in a few versions:
if level == LogLevels.Critical:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.CRITICAL
elif level == LogLevels.Error:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.ERROR
elif level == LogLevels.Warning:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.WARNING
elif level == LogLevels.Info:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.INFO
elif level == LogLevels.Debug:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.DEBUG
log_adapter = logging.getLogger('archinstall') log_adapter = logging.getLogger('archinstall')
log_fmt = logging.Formatter("[%(levelname)s]: %(message)s") log_fmt = logging.Formatter("[%(levelname)s]: %(message)s")
@ -65,7 +33,7 @@ class SessionLogging:
# Found first reference here: https://stackoverflow.com/questions/7445658/how-to-detect-if-the-console-does-support-ansi-escape-codes-in-python # Found first reference here: https://stackoverflow.com/questions/7445658/how-to-detect-if-the-console-does-support-ansi-escape-codes-in-python
# And re-used this: https://github.com/django/django/blob/master/django/core/management/color.py#L12 # And re-used this: https://github.com/django/django/blob/master/django/core/management/color.py#L12
def supports_color(): def supports_color() -> bool:
""" """
Return True if the running system's terminal supports color, Return True if the running system's terminal supports color,
and False otherwise. and False otherwise.
@ -79,7 +47,7 @@ def supports_color():
# Heavily influenced by: https://github.com/django/django/blob/ae8338daf34fd746771e0678081999b656177bae/django/utils/termcolors.py#L13 # Heavily influenced by: https://github.com/django/django/blob/ae8338daf34fd746771e0678081999b656177bae/django/utils/termcolors.py#L13
# Color options here: https://askubuntu.com/questions/528928/how-to-do-underline-bold-italic-strikethrough-color-background-and-size-i # Color options here: https://askubuntu.com/questions/528928/how-to-do-underline-bold-italic-strikethrough-color-background-and-size-i
def stylize_output(text: str, *opts, **kwargs): def stylize_output(text: str, *opts :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> str:
opt_dict = {'bold': '1', 'italic': '3', 'underscore': '4', 'blink': '5', 'reverse': '7', 'conceal': '8'} opt_dict = {'bold': '1', 'italic': '3', 'underscore': '4', 'blink': '5', 'reverse': '7', 'conceal': '8'}
color_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white') color_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
foreground = {color_names[x]: '3%s' % x for x in range(8)} foreground = {color_names[x]: '3%s' % x for x in range(8)}
@ -91,9 +59,9 @@ def stylize_output(text: str, *opts, **kwargs):
return '\x1b[%sm' % reset return '\x1b[%sm' % reset
for k, v in kwargs.items(): for k, v in kwargs.items():
if k == 'fg': if k == 'fg':
code_list.append(foreground[v]) code_list.append(foreground[str(v)])
elif k == 'bg': elif k == 'bg':
code_list.append(background[v]) code_list.append(background[str(v)])
for o in opts: for o in opts:
if o in opt_dict: if o in opt_dict:
code_list.append(opt_dict[o]) code_list.append(opt_dict[o])
@ -102,7 +70,7 @@ def stylize_output(text: str, *opts, **kwargs):
return '%s%s' % (('\x1b[%sm' % ';'.join(code_list)), text or '') return '%s%s' % (('\x1b[%sm' % ';'.join(code_list)), text or '')
def log(*args, **kwargs): def log(*args :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> None:
string = orig_string = ' '.join([str(x) for x in args]) string = orig_string = ' '.join([str(x) for x in args])
# Attempt to colorize the output if supported # Attempt to colorize the output if supported
@ -132,42 +100,10 @@ def log(*args, **kwargs):
with open(absolute_logfile, 'a') as log_file: with open(absolute_logfile, 'a') as log_file:
log_file.write(f"{orig_string}\n") log_file.write(f"{orig_string}\n")
# If we assigned a level, try to log it to systemd's journald. Journald.log(string, level=int(str(kwargs.get('level', logging.INFO))))
# Unless the level is higher than we've decided to output interactively.
# (Remember, log files still get *ALL* the output despite level restrictions)
if 'level' in kwargs:
# For backwards compatibility, convert old style log-levels
# to logging levels (and warn about deprecated usage)
# There's some code re-usage here but that should be fine.
# TODO: Remove these in a few versions:
if kwargs['level'] == LogLevels.Critical:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.CRITICAL
elif kwargs['level'] == LogLevels.Error:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.ERROR
elif kwargs['level'] == LogLevels.Warning:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.WARNING
elif kwargs['level'] == LogLevels.Info:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.INFO
elif kwargs['level'] == LogLevels.Debug:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.DEBUG
if kwargs['level'] < storage.get('LOG_LEVEL', logging.INFO) and 'force' not in kwargs:
# Level on log message was Debug, but output level is set to Info.
# In that case, we'll drop it.
return None
try:
Journald.log(string, level=kwargs.get('level', logging.INFO))
except ModuleNotFoundError:
pass # Ignore writing to journald
# Finally, print the log unless we skipped it based on level. # Finally, print the log unless we skipped it based on level.
# We use sys.stdout.write()+flush() instead of print() to try and # We use sys.stdout.write()+flush() instead of print() to try and
# fix issue #94 # fix issue #94
sys.stdout.write(f"{string}\n") sys.stdout.write(f"{string}\n")
sys.stdout.flush() sys.stdout.flush()