Source code for ai2_kit.core.checkpoint

from typing import TypeVar, Union, Callable, NamedTuple
from threading import Lock
import functools
import cloudpickle
import os
import inspect

from .log import get_logger
from .util import to_awaitable, ensure_dir

logger = get_logger(__name__)


[docs] class FnInfo(NamedTuple): fn_name: str args: tuple kwargs: dict call_site: str
KeyFn = Callable[[FnInfo], str] EMPTY = object()
[docs] class CheckpointService: def __init__(self): self.checkpoint_dir = None self._lock = Lock()
[docs] def set_checkpoint_dir(self, checkpoint_dir: str): self.checkpoint_dir = checkpoint_dir try: os.makedirs(checkpoint_dir, exist_ok=True) except FileExistsError: logger.error("ai2-kit use directory as checkpoint since 0.17.0") raise
[docs] def apply_checkpoint(self, key_fn: Union[str, KeyFn], disable = False): call_site = inspect.getframeinfo(inspect.stack()[1][0]) T = TypeVar('T', bound=Callable) def _checkpoint(fn: T) -> T: @functools.wraps(fn) def wrapper(*args, **kwargs): fn_info = FnInfo( fn_name=fn.__name__, args=args, kwargs=kwargs, call_site=f'{call_site.filename}:{call_site.lineno}', ) key = key_fn if isinstance(key_fn, str) else key_fn(fn_info) if disable or self.checkpoint_dir is None: return fn(*args, **kwargs) ret = self._get_checkpoint(key) if ret is not EMPTY: logger.info(f"Hit checkpoint: {key}") return ret ret = fn(*args, **kwargs) if inspect.isawaitable(ret): async def _wrap_fn(): _ret = await ret self._set_checkpoint(key, _ret, fn_info, True) return _ret return _wrap_fn() else: self._set_checkpoint(key, ret, fn_info, False) return ret return wrapper # type: ignore return _checkpoint
def _get_checkpoint(self, key: str): assert self.checkpoint_dir is not None try: with self._lock: checkpoint_file = os.path.join(self.checkpoint_dir, key) if not os.path.exists(checkpoint_file): return EMPTY with open(checkpoint_file, 'rb') as f: value = cloudpickle.load(f) if value['is_awaitable']: return to_awaitable(value['return']) else: return value['return'] except Exception as e: logger.exception(f"Fail to get checkpoint: {key}") return EMPTY def _set_checkpoint(self, key: str, value, info: FnInfo, is_awaitable: bool = False): assert self.checkpoint_dir is not None try: with self._lock: # args, kwargs may contain unpickable objects _checkpoint_data = { 'return': value, 'is_awaitable': is_awaitable, 'info': { 'fn_name': info.fn_name, 'call_site': info.call_site, } } checkpoint_file = os.path.join(self.checkpoint_dir, key) ensure_dir(checkpoint_file) with open(checkpoint_file, 'wb') as f: cloudpickle.dump(_checkpoint_data, f) except Exception as e: logger.exception('Fail to set checkpoint')
checkpoint_service = CheckpointService()
[docs] def set_checkpoint_dir(checkpoint_dir: str): checkpoint_service.set_checkpoint_dir(checkpoint_dir)
[docs] def apply_checkpoint(key_fn: Union[str, KeyFn], disable = False): """ apply checkpoint for function. Note: This checkpoint implementation doesn't support multiprocess. To support multiple process we need to have a dedicated background process to read/write checkpoint, which will require message queue (e.g. nanomsg or nng) to implement it. Example: >>> set_checkpoint_file('/tmp/test.ckpt') >>> task_fn = lambda a, b: a + b >>> checkpoint('task_1+2')(task_fn)(1, 2) """ return checkpoint_service.apply_checkpoint(key_fn, disable)