Source code for refit.task

from __future__ import annotations

from abc import abstractmethod
import asyncio
import time
import typing as t

from termcolor import colored

from .mixins.apt import AptMixin
from .mixins.docker import DockerMixin
from .mixins.file import FileMixin
from .mixins.path import PathMixin
from .mixins.python import PythonMixin
from .mixins.template import TemplateMixin

if t.TYPE_CHECKING:
    from .host import Host
    from .registry import HostRegistry


[docs]class Task( AptMixin, DockerMixin, FileMixin, PathMixin, PythonMixin, TemplateMixin ): tags: t.Iterable[str] = ["all"] sub_tasks: t.Iterable[Task] = [] def __init__(self, host_class: t.Type[Host]): self.host_class = host_class
[docs] @classmethod async def create( cls, host_registry: HostRegistry, environment: str ) -> None: """ Creates and runs a task for all matching hosts. """ host_classes = host_registry.get_host_classes( tags=cls.tags, environment=environment ) for host_class in host_classes: host_class.start_connection_pool() await asyncio.gather( *[ cls(host_class=host_class).entrypoint() for host_class in host_classes ] ) for host_class in host_classes: host_class.close_connection_pool()
[docs] async def entrypoint(self) -> None: """ Kicks off the task, along with printing some info. """ message = f"{self.__class__.__name__} [{self.host_class.address}]" line_length = int((100 - len(message)) / 2) line = "".join(["-" for i in range(line_length)]) print(colored(f"{line} {message} {line}", "cyan")) await self.run()
###########################################################################
[docs] @abstractmethod async def run(self) -> None: """ Override in subclasses. This is what does the actual work in the task, and is awaited when the Task is run. """ pass
###########################################################################
[docs] async def raw(self, command: str, raise_exception=True): """ Execute a raw shell command on the remote server. """ return await self._execute_command(command, raise_exception)
########################################################################### def _print_command(self, command: str) -> None: print(colored(f"{command}", "green")) async def _execute_command(self, command: str, raise_exception=True): """ Runs the command on the host. """ started_at = time.time() connection = await self.host_class.get_connection() result = await connection.run( command, # check=True ) finished_at = time.time() took = round(finished_at - started_at, 4) self._print_command(f"Running: {command}") stdout = colored(result.stdout, "magenta") stderr = colored(result.stderr, "red") print(f"Took: {took} seconds\n{stdout}\n{stderr}\n") if (result.exit_status == 1) and raise_exception: raise Exception(f"Command - {command} returned 1 result code!") return result
[docs]class Concurrent(Task): """ Bundles several tasks to be run concurrently. """
[docs] async def run(self): await asyncio.gather( *[ task(host_class=self.host_class).run() for task in self.sub_tasks ] )
[docs]def new_gathered_task(tasks: t.Iterable[t.Type[Task]]) -> t.Type[Concurrent]: """ Task definitions are classes, not instances, hence why we require this. :param tasks: A list of Task classes to execute. """ name = "+".join([task.__name__ for task in tasks]) return type(name, (Concurrent,), {"sub_tasks": tasks})