Fixing some mypy complaints (#780)

* Fixed some mypy issues regarding SysCommand* and logging
* Fixed imports and undefined variable
This commit is contained in:
Anton Hvornum 2021-12-02 20:20:31 +00:00 committed by GitHub
parent 908c7b8cc0
commit b1b820f4cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 149 deletions

View File

@ -1,3 +1,5 @@
from typing import Optional
class RequirementError(BaseException):
pass
@ -15,7 +17,7 @@ class ProfileError(BaseException):
class SysCallError(BaseException):
def __init__(self, message, exit_code):
def __init__(self, message :str, exit_code :Optional[int]) -> None:
super(SysCallError, self).__init__(message)
self.message = message
self.exit_code = exit_code

View File

@ -9,10 +9,11 @@ import string
import sys
import time
from datetime import datetime, date
from typing import Union
try:
from typing import Callable, Optional, Dict, Any, List, Union, Iterator
if sys.platform == 'linux':
from select import epoll, EPOLLIN, EPOLLHUP
except:
else:
import select
EPOLLIN = 0
EPOLLHUP = 0
@ -22,20 +23,20 @@ except:
Create a epoll() implementation that simulates the epoll() behavior.
This so that the rest of the code doesn't need to worry weither we're using select() or epoll().
"""
def __init__(self):
self.sockets = {}
self.monitoring = {}
def __init__(self) -> None:
self.sockets: Dict[str, Any] = {}
self.monitoring: Dict[int, Any] = {}
def unregister(self, fileno, *args, **kwargs):
def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None:
try:
del(self.monitoring[fileno])
except:
pass
def register(self, fileno, *args, **kwargs):
def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None:
self.monitoring[fileno] = True
def poll(self, timeout=0.05, *args, **kwargs):
def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]:
try:
return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]]
except OSError:
@ -66,13 +67,13 @@ def multisplit(s, splitters):
s = ns
return s
def locate_binary(name):
def locate_binary(name :str) -> str:
for PATH in os.environ['PATH'].split(':'):
for root, folders, files in os.walk(PATH):
for file in files:
if file == name:
return os.path.join(root, file)
break # Don't recurse
break # Don't recurse
raise RequirementError(f"Binary {name} does not exist.")
@ -157,7 +158,14 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder):
return super(UNSAFE_JSON, self).encode(self._encode(obj))
class SysCommandWorker:
def __init__(self, cmd, callbacks=None, peak_output=False, environment_vars=None, logfile=None, working_directory='./'):
def __init__(self,
cmd :Union[str, List[str]],
callbacks :Optional[Dict[str, Any]] = None,
peak_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
logfile :Optional[None] = None,
working_directory :Optional[str] = './'):
if not callbacks:
callbacks = {}
if not environment_vars:
@ -166,6 +174,7 @@ class SysCommandWorker:
if type(cmd) is str:
cmd = shlex.split(cmd)
cmd = list(cmd) # This is to please mypy
if cmd[0][0] != '/' and cmd[0][:2] != './':
# "which" doesn't work as it's a builtin to bash.
# It used to work, but for whatever reason it doesn't anymore.
@ -179,15 +188,15 @@ class SysCommandWorker:
self.logfile = logfile
self.working_directory = working_directory
self.exit_code = None
self.exit_code :Optional[int] = None
self._trace_log = b''
self._trace_log_pos = 0
self.poll_object = epoll()
self.child_fd = None
self.started = None
self.ended = None
self.child_fd :Optional[int] = None
self.started :Optional[float] = None
self.ended :Optional[float] = None
def __contains__(self, key: bytes):
def __contains__(self, key: bytes) -> bool:
"""
Contains will also move the current buffert position forward.
This is to avoid re-checking the same data when looking for output.
@ -199,21 +208,21 @@ class SysCommandWorker:
return contains
def __iter__(self, *args, **kwargs):
def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]:
for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'):
if line:
yield line + b'\n'
self._trace_log_pos = self._trace_log.rfind(b'\n')
def __repr__(self):
def __repr__(self) -> str:
self.make_sure_we_are_executing()
return str(self._trace_log)
def __enter__(self):
def __enter__(self) -> 'SysCommandWorker':
return self
def __exit__(self, *args):
def __exit__(self, *args :str) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
@ -233,9 +242,9 @@ class SysCommandWorker:
log(args[1], level=logging.ERROR, fg='red')
if self.exit_code != 0:
raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}")
raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}", self.exit_code)
def is_alive(self):
def is_alive(self) -> bool:
self.poll()
if self.started and self.ended is None:
@ -243,22 +252,26 @@ class SysCommandWorker:
return False
def write(self, data: bytes, line_ending=True):
def write(self, data: bytes, line_ending :bool = True) -> int:
assert type(data) == bytes # TODO: Maybe we can support str as well and encode it
self.make_sure_we_are_executing()
os.write(self.child_fd, data + (b'\n' if line_ending else b''))
if self.child_fd:
return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
def make_sure_we_are_executing(self):
return 0
def make_sure_we_are_executing(self) -> bool:
if not self.started:
return self.execute()
return True
def tell(self) -> int:
self.make_sure_we_are_executing()
return self._trace_log_pos
def seek(self, pos):
def seek(self, pos :int) -> None:
self.make_sure_we_are_executing()
# Safety check to ensure 0 < pos < len(tracelog)
self._trace_log_pos = min(max(0, pos), len(self._trace_log))
@ -271,39 +284,41 @@ class SysCommandWorker:
except UnicodeDecodeError:
return False
sys.stdout.write(output)
sys.stdout.write(str(output))
sys.stdout.flush()
return True
def poll(self):
def poll(self) -> None:
self.make_sure_we_are_executing()
got_output = False
for fileno, event in self.poll_object.poll(0.1):
try:
output = os.read(self.child_fd, 8192)
got_output = True
self.peak(output)
self._trace_log += output
except OSError:
self.ended = time.time()
break
if self.ended or (got_output is False and pid_exists(self.pid) is False):
self.ended = time.time()
try:
self.exit_code = os.waitpid(self.pid, 0)[1]
except ChildProcessError:
if self.child_fd:
got_output = False
for fileno, event in self.poll_object.poll(0.1):
try:
self.exit_code = os.waitpid(self.child_fd, 0)[1]
output = os.read(self.child_fd, 8192)
got_output = True
self.peak(output)
self._trace_log += output
except OSError:
self.ended = time.time()
break
if self.ended or (got_output is False and pid_exists(self.pid) is False):
self.ended = time.time()
try:
self.exit_code = os.waitpid(self.pid, 0)[1]
except ChildProcessError:
self.exit_code = 1
try:
self.exit_code = os.waitpid(self.child_fd, 0)[1]
except ChildProcessError:
self.exit_code = 1
def execute(self) -> bool:
import pty
if (old_dir := os.getcwd()) != self.working_directory:
os.chdir(self.working_directory)
os.chdir(str(self.working_directory))
# Note: If for any reason, we get a Python exception between here
# and until os.close(), the traceback will get locked inside
@ -320,7 +335,7 @@ class SysCommandWorker:
except PermissionError:
pass
os.execve(self.cmd[0], self.cmd, {**os.environ, **self.environment_vars})
os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
if storage['arguments'].get('debug'):
log(f"Executing: {self.cmd}", level=logging.DEBUG)
@ -334,15 +349,23 @@ class SysCommandWorker:
return True
def decode(self, encoding='UTF-8'):
def decode(self, encoding :str = 'UTF-8') -> str:
return self._trace_log.decode(encoding)
class SysCommand:
def __init__(self, cmd, callback=None, start_callback=None, peak_output=False, environment_vars=None, working_directory='./'):
def __init__(self,
cmd :Union[str, List[str]],
callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None,
start_callback :Optional[Callable[[Any], Any]] = None,
peak_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
working_directory :Optional[str] = './'):
_callbacks = {}
if callback:
_callbacks['on_end'] = callback
if callbacks:
for hook, func in callbacks.items():
_callbacks[hook] = func
if start_callback:
_callbacks['on_start'] = start_callback
@ -352,26 +375,28 @@ class SysCommand:
self.environment_vars = environment_vars
self.working_directory = working_directory
self.session = None
self.session :Optional[SysCommandWorker] = None
self.create_session()
def __enter__(self):
def __enter__(self) -> Optional[SysCommandWorker]:
return self.session
def __exit__(self, *args, **kwargs):
def __exit__(self, *args :str, **kwargs :Dict[str, Any]) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
if len(args) >= 2 and args[1]:
log(args[1], level=logging.ERROR, fg='red')
def __iter__(self, *args, **kwargs):
def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]:
if self.session:
for line in self.session:
yield line
for line in self.session:
yield line
def __getitem__(self, key):
if type(key) is slice:
def __getitem__(self, key :slice) -> Optional[bytes]:
if not self.session:
raise KeyError(f"SysCommand() does not have an active session.")
elif type(key) is slice:
start = key.start if key.start else 0
end = key.stop if key.stop else len(self.session._trace_log)
@ -379,10 +404,12 @@ class SysCommand:
else:
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
def __repr__(self, *args, **kwargs):
return self.session._trace_log.decode('UTF-8')
def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str:
if self.session:
return self.session._trace_log.decode('UTF-8')
return ''
def __json__(self):
def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]:
return {
'cmd': self.cmd,
'callbacks': self._callbacks,
@ -391,7 +418,7 @@ class SysCommand:
'session': True if self.session else False
}
def create_session(self):
def create_session(self) -> bool:
if self.session:
return True
@ -406,16 +433,23 @@ class SysCommand:
return True
def decode(self, fmt='UTF-8'):
return self.session._trace_log.decode(fmt)
def decode(self, fmt :str = 'UTF-8') -> Optional[str]:
if self.session:
return self.session._trace_log.decode(fmt)
return None
@property
def exit_code(self):
return self.session.exit_code
def exit_code(self) -> Optional[int]:
if self.session:
return self.session.exit_code
else:
return None
@property
def trace_log(self):
return self.session._trace_log
def trace_log(self) -> Optional[bytes]:
if self.session:
return self.session._trace_log
return None
def prerequisite_check():
@ -428,7 +462,8 @@ def prerequisite_check():
def reboot():
SysCommand("/usr/bin/reboot")
def pid_exists(pid: int):
def pid_exists(pid: int) -> bool:
try:
return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
except subprocess.CalledProcessError:

View File

@ -1,51 +1,19 @@
import abc
import logging
import os
import sys
from pathlib import Path
from typing import Dict, Union
from .storage import storage
# TODO: use logging's built in levels instead.
# Although logging is threaded and I wish to avoid that.
# It's more Pythonistic or w/e you want to call it.
class LogLevels:
Critical = 0b001
Error = 0b010
Warning = 0b011
Info = 0b101
Debug = 0b111
class Journald(dict):
class Journald:
@staticmethod
@abc.abstractmethod
def log(message, level=logging.DEBUG):
def log(message :str, level :int = logging.DEBUG) -> None:
try:
import systemd.journal # type: ignore
except ModuleNotFoundError:
return False
# For backwards compatibility, convert old style log-levels
# to logging levels (and warn about deprecated usage)
# There's some code re-usage here but that should be fine.
# TODO: Remove these in a few versions:
if level == LogLevels.Critical:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.CRITICAL
elif level == LogLevels.Error:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.ERROR
elif level == LogLevels.Warning:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.WARNING
elif level == LogLevels.Info:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.INFO
elif level == LogLevels.Debug:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
level = logging.DEBUG
return None
log_adapter = logging.getLogger('archinstall')
log_fmt = logging.Formatter("[%(levelname)s]: %(message)s")
@ -65,7 +33,7 @@ class SessionLogging:
# Found first reference here: https://stackoverflow.com/questions/7445658/how-to-detect-if-the-console-does-support-ansi-escape-codes-in-python
# And re-used this: https://github.com/django/django/blob/master/django/core/management/color.py#L12
def supports_color():
def supports_color() -> bool:
"""
Return True if the running system's terminal supports color,
and False otherwise.
@ -79,7 +47,7 @@ def supports_color():
# Heavily influenced by: https://github.com/django/django/blob/ae8338daf34fd746771e0678081999b656177bae/django/utils/termcolors.py#L13
# Color options here: https://askubuntu.com/questions/528928/how-to-do-underline-bold-italic-strikethrough-color-background-and-size-i
def stylize_output(text: str, *opts, **kwargs):
def stylize_output(text: str, *opts :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> str:
opt_dict = {'bold': '1', 'italic': '3', 'underscore': '4', 'blink': '5', 'reverse': '7', 'conceal': '8'}
color_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
foreground = {color_names[x]: '3%s' % x for x in range(8)}
@ -91,9 +59,9 @@ def stylize_output(text: str, *opts, **kwargs):
return '\x1b[%sm' % reset
for k, v in kwargs.items():
if k == 'fg':
code_list.append(foreground[v])
code_list.append(foreground[str(v)])
elif k == 'bg':
code_list.append(background[v])
code_list.append(background[str(v)])
for o in opts:
if o in opt_dict:
code_list.append(opt_dict[o])
@ -102,7 +70,7 @@ def stylize_output(text: str, *opts, **kwargs):
return '%s%s' % (('\x1b[%sm' % ';'.join(code_list)), text or '')
def log(*args, **kwargs):
def log(*args :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> None:
string = orig_string = ' '.join([str(x) for x in args])
# Attempt to colorize the output if supported
@ -132,42 +100,10 @@ def log(*args, **kwargs):
with open(absolute_logfile, 'a') as log_file:
log_file.write(f"{orig_string}\n")
# If we assigned a level, try to log it to systemd's journald.
# Unless the level is higher than we've decided to output interactively.
# (Remember, log files still get *ALL* the output despite level restrictions)
if 'level' in kwargs:
# For backwards compatibility, convert old style log-levels
# to logging levels (and warn about deprecated usage)
# There's some code re-usage here but that should be fine.
# TODO: Remove these in a few versions:
if kwargs['level'] == LogLevels.Critical:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.CRITICAL
elif kwargs['level'] == LogLevels.Error:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.ERROR
elif kwargs['level'] == LogLevels.Warning:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.WARNING
elif kwargs['level'] == LogLevels.Info:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.INFO
elif kwargs['level'] == LogLevels.Debug:
log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
kwargs['level'] = logging.DEBUG
if kwargs['level'] < storage.get('LOG_LEVEL', logging.INFO) and 'force' not in kwargs:
# Level on log message was Debug, but output level is set to Info.
# In that case, we'll drop it.
return None
try:
Journald.log(string, level=kwargs.get('level', logging.INFO))
except ModuleNotFoundError:
pass # Ignore writing to journald
Journald.log(string, level=int(str(kwargs.get('level', logging.INFO))))
# Finally, print the log unless we skipped it based on level.
# We use sys.stdout.write()+flush() instead of print() to try and
# fix issue #94
sys.stdout.write(f"{string}\n")
sys.stdout.flush()
sys.stdout.flush()