# Source code for refit.task

from __future__ import annotations

from abc import abstractmethod
import asyncio
import time
import typing as t

from termcolor import colored

from .mixins.apt import AptMixin
from .mixins.docker import DockerMixin
from .mixins.file import FileMixin
from .mixins.path import PathMixin
from .mixins.python import PythonMixin
from .mixins.template import TemplateMixin

if t.TYPE_CHECKING:
    from .host import Host
    from .registry import HostRegistry

class Task(
    AptMixin,
    DockerMixin,
    FileMixin,
    PathMixin,
    PythonMixin,
    TemplateMixin,
):
    """
    Base class for a unit of work executed against a single host.

    Subclasses implement :meth:`run`. :meth:`create` instantiates the task
    once per host matching :attr:`tags` and runs them concurrently.
    """

    # Tags passed to the host registry by ``create`` to select hosts.
    tags: t.Iterable[str] = ["all"]
    # Task *classes* (not instances); ``Concurrent`` instantiates each one
    # for its own host.
    sub_tasks: t.Iterable[t.Type[Task]] = []

    def __init__(self, host_class: t.Type[Host]):
        self.host_class = host_class

    @classmethod
    async def create(
        cls, host_registry: HostRegistry, environment: str
    ) -> None:
        """
        Creates and runs a task for all matching hosts.

        A connection pool is started for each matching host, the task's
        ``entrypoint`` is run for every host concurrently, and the pools are
        closed afterwards - even if one of the tasks raises.

        :param host_registry: Registry queried for host classes matching
            ``cls.tags`` and ``environment``.
        :param environment: Environment name forwarded to the registry.
        """
        # Materialise once: the host classes are iterated three times below,
        # and a one-shot iterator would make the later passes silently no-op.
        host_classes = list(
            host_registry.get_host_classes(
                tags=cls.tags, environment=environment
            )
        )
        for host_class in host_classes:
            host_class.start_connection_pool()
        try:
            await asyncio.gather(
                *[
                    cls(host_class=host_class).entrypoint()
                    for host_class in host_classes
                ]
            )
        finally:
            # Release the pools even when a task fails.
            for host_class in host_classes:
                host_class.close_connection_pool()

    async def entrypoint(self) -> None:
        """
        Kicks off the task, along with printing some info.

        Prints a cyan banner with the task's class name and the host
        address, then awaits :meth:`run`.
        """
        message = f"{self.__class__.__name__} [{self.host_class.address}]"
        # Pad both sides with dashes towards a ~100 character banner;
        # messages of 100+ characters get no padding.
        line = "-" * max(0, (100 - len(message)) // 2)
        print(colored(f"{line} {message} {line}", "cyan"))
        await self.run()

    # NOTE(review): Task does not list ABC among its visible bases, so this
    # is only enforced if one of the mixins supplies an ABCMeta metaclass.
    @abstractmethod
    async def run(self) -> None:
        """
        Override in subclasses. This is what does the actual work in the
        task, and is awaited when the Task is run.
        """
        pass

    async def raw(self, command: str, raise_exception=True):
        """
        Execute a raw shell command on the remote server.

        :param command: Shell command to run.
        :param raise_exception: When True, raise if the command exits with a
            non-zero status.
        :returns: The result object produced by the host connection.
        """
        return await self._execute_command(command, raise_exception)

    ###########################################################################

    def _print_command(self, command: str) -> None:
        print(colored(command, "green"))

    async def _execute_command(self, command: str, raise_exception=True):
        """
        Runs the command on the host.

        :param command: Shell command to run on the host.
        :param raise_exception: When True, raise if the command exits with a
            non-zero status.
        :returns: The result of ``connection.run``.
        :raises Exception: If ``raise_exception`` is True and the command's
            exit status is not 0.
        """
        # Monotonic clock: immune to wall-clock adjustments while timing.
        started_at = time.monotonic()
        connection = await self.host_class.get_connection()
        # Assumes an asyncssh-style connection whose ``run`` returns an object
        # exposing ``stdout``/``stderr``/``exit_status`` (the attributes read
        # below) - confirm against the Host connection type.
        result = await connection.run(
            command,
            # check=True
        )
        took = round(time.monotonic() - started_at, 4)

        self._print_command(f"Running: {command}")
        stdout = colored(result.stdout, "magenta")
        stderr = colored(result.stderr, "red")
        print(f"Took: {took} seconds\n{stdout}\n{stderr}\n")

        # Any non-zero status is a failure, not just 1 (e.g. 127 means
        # "command not found").
        if raise_exception and result.exit_status != 0:
            raise Exception(
                f"Command - {command} returned {result.exit_status} "
                "result code!"
            )
        return result
class Concurrent(Task):
    """
    Bundles several tasks to be run concurrently.

    Each class in ``sub_tasks`` is instantiated for this task's host, and
    the resulting ``run`` coroutines are awaited together via
    ``asyncio.gather``.
    """

    async def run(self):
        # Sub-tasks' ``run`` is called directly (not ``entrypoint``), so no
        # per-sub-task banner is printed.
        pending = []
        for task_class in self.sub_tasks:
            instance = task_class(host_class=self.host_class)
            pending.append(instance.run())
        await asyncio.gather(*pending)
def new_gathered_task(tasks: t.Iterable[t.Type[Task]]) -> t.Type[Concurrent]:
    """
    Task definitions are classes, not instances, hence why we require this.

    Builds a new :class:`Concurrent` subclass whose ``sub_tasks`` are the
    given Task classes. The new class is named by joining their class names
    with ``+``.

    :param tasks: An iterable of Task classes to execute.
    :returns: A new ``Concurrent`` subclass.
    """
    # Materialise first. The name join below consumes the iterable, and
    # ``sub_tasks`` is re-iterated by every ``Concurrent.run`` call (one per
    # host). A generator would leave the gathered task doing nothing.
    task_classes = list(tasks)
    name = "+".join(task.__name__ for task in task_classes)
    return type(name, (Concurrent,), {"sub_tasks": task_classes})