# Source code for ai2_kit.core.resource_manager

from typing import Dict, List, Tuple, Union, Optional, Sequence
import copy

from .artifact import Artifact, ArtifactMap
from .executor import Executor, ExecutorMap, create_executor


ArtifactOrKey = Union[Artifact, str]

class ResourceManager:
    """Registry of executors and artifacts.

    Executors are created lazily from ``executor_configs`` on first request and
    cached by name; artifacts are looked up by key from the ``artifacts`` map.
    """

    @property
    def default_executor(self) -> Executor:
        """The executor registered under the default name (created on first access)."""
        return self.get_executor()

    def __init__(self,
                 executor_configs: ExecutorMap,
                 artifacts: ArtifactMap,
                 default_executor: str,
                 ) -> None:
        """
        :param executor_configs: executor configurations keyed by executor name.
        :param artifacts: artifacts keyed by artifact key. NOTE: the Artifact
            objects are updated in place -- a missing ``executor`` is set to the
            default executor's name and a missing ``key`` to its key in the map.
        :param default_executor: name of the executor used when none is specified.
        :raises ValueError: if ``default_executor`` has no entry in ``executor_configs``.
        """
        self._executor_configs = executor_configs
        self._default_executor_name = default_executor
        self._executors: Dict[str, Executor] = dict()

        # Create (and init) the default executor eagerly so that a bad
        # configuration fails here rather than at first use.
        default = self.get_executor()

        # fill in default values
        for key, artifact in artifacts.items():
            if artifact.executor is None:
                artifact.executor = default.name
            if artifact.key is None:
                artifact.key = key
        self._artifacts = artifacts

    def get_executor(self, name: Optional[str] = None) -> Executor:
        """Return the executor called ``name``, creating and initializing it on first request.

        :param name: executor name; ``None`` selects the default executor.
        :raises ValueError: if no configuration exists for ``name``.
        """
        if name is None:
            name = self._default_executor_name
        config = self._executor_configs.get(name)
        if config is None:
            raise ValueError(f'Executor with name {name} is not found!')
        if name not in self._executors:
            # Cache only after init() returns, so a failed init is not cached.
            executor = create_executor(config, name)
            executor.init()
            self._executors[name] = executor
        return self._executors[name]

    def get_artifact(self, key: str) -> Artifact:
        """Return the artifact registered under ``key``.

        :raises KeyError: if no artifact has that key (raised deliberately so
            that misconfiguration fails fast).
        """
        try:
            return self._artifacts[key]
        except KeyError:
            raise KeyError(f'Artifact with key {key} is not found!') from None

    def get_artifacts(self, keys: Sequence[str]) -> List[Artifact]:
        """Return the artifacts for ``keys``, in order.

        :raises KeyError: if any key is not registered.
        """
        return [self.get_artifact(key) for key in keys]

    def resolve_artifact(self, artifact: ArtifactOrKey) -> List[Artifact]:
        """Expand one artifact (or artifact key) into one new Artifact per resolved path.

        The default executor's ``resolve_artifact`` produces the paths. Each result
        copies ``format``, gets a deep copy of ``attrs``, has ``includes`` cleared and
        is bound to the default executor.

        :param artifact: an Artifact, or the key of a registered artifact.
        :raises KeyError: if ``artifact`` is a key that is not registered.
        :raises ValueError: if the artifact resolves to no paths.
        """
        # TODO: support cross executor data resolve in the future
        # NOTE: artifact.executor is currently ignored; the default executor is always used.
        if isinstance(artifact, str):
            artifact = self.get_artifact(artifact)
        executor = self.default_executor
        paths = executor.resolve_artifact(artifact)
        result = [Artifact.of(
            url=path,
            format=artifact.format,
            includes=None,  # has been consumed
            attrs=copy.deepcopy(artifact.attrs),
            executor=executor.name,
        ) for path in paths]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not result:
            raise ValueError(f'artifact {artifact} is invalid')
        return result

    def resolve_artifacts(self, artifacts: Sequence[ArtifactOrKey]) -> List[Artifact]:
        """Resolve each artifact (or key) and concatenate the results in order."""
        # flat map
        return [x for a in artifacts for x in self.resolve_artifact(a)]