Source code for metaheuristic_designer.checkpointer
"""
Module for checkpointing and resuming optimisation runs.
"""
from __future__ import annotations
import logging
from pickle import PicklingError
from typing import TYPE_CHECKING, Optional
import cloudpickle
import os
import time
from .reporters import create_reporter
from .reporter import Reporter
if TYPE_CHECKING:
from .algorithm import Algorithm
logger = logging.getLogger(__name__)
[docs]
class Checkpointer:
"""Periodically save and restore the state of an optimisation run.
The checkpointer can be triggered by iteration count, elapsed
wall-clock time, or both. It writes the entire
:class:`Algorithm` object to disk using ``cloudpickle``, so that
an interrupted run can be resumed later without losing progress.
Parameters
----------
checkpoint_file : str
Path to the file where the checkpoint will be saved (e.g.,
``"run.pkl"``).
iteration_frequency : int, optional
Save a checkpoint every *n* iterations.
time_frequency : float, optional
Save a checkpoint when at least this many seconds have
elapsed since the last save.
Notes
-----
At least one of *iteration_frequency* or *time_frequency* must
be provided; otherwise the checkpointer does nothing and logs a
warning.
"""
def __init__(self, checkpoint_file: str, iteration_frequency: Optional[int] = None, time_frequency: Optional[float] = None):
self.checkpoint_file = checkpoint_file
if iteration_frequency is None and time_frequency is None:
logger.warning("Checkpointing frequency not configured. No checkpoints will happen.")
self.iteration_frequency = iteration_frequency
self.time_frequency = time_frequency
self.timer = time.time()
[docs]
def restart(self):
"""Reset the internal timer so that time-based checkpoints are measured from this moment onward."""
self.timer = time.time()
[docs]
def checkpoint(self, algorithm: Algorithm):
"""Evaluate whether a checkpoint should be saved, and perform the save
if necessary.
Parameters
----------
algorithm : Algorithm
The running algorithm instance.
"""
iterations = algorithm.stopping_condition.iterations
saving_iteration = self.iteration_frequency is not None and iterations % self.iteration_frequency == 0 and iterations > 0
saving_time = self.time_frequency is not None and time.time() - self.timer > self.time_frequency
if not (saving_iteration or saving_time):
return
if saving_time:
self.timer = time.time()
self.save(algorithm)
[docs]
def save(self, algorithm: Algorithm):
"""Serialize the algorithm to disk using cloudpickle.
A temporary file is written first and atomically moved to the
final location, preventing corruption if the process crashes
mid-write. The reporter, parallel flag, and
the checkpointer itself are temporarily removed before pickling
to avoid serialisation issues, and then restored.
Parameters
----------
algorithm : Algorithm
The algorithm to save.
"""
# Temporarily remove problematic components for serialization.
reporter = algorithm.reporter
is_parallel = algorithm.parallel
checkpointer = algorithm.checkpointer
algorithm.reporter = None
algorithm.parallel = False
algorithm.checkpointer = None
# Store checkpoint to a temp file without overwriting the previous one yet
try:
tmp_file = self.checkpoint_file + ".tmp"
with open(tmp_file, "wb") as f:
cloudpickle.dump(algorithm, f, protocol=5)
# Once we know the checkpoint has finished writing we can replace the preivous one.
os.replace(tmp_file, self.checkpoint_file)
except (OSError, PermissionError, PicklingError, TypeError, MemoryError) as e:
logger.error("Failed to save checkpoint: %s", e)
finally:
# Restore dropped attributes
algorithm.reporter = reporter
algorithm.parallel = is_parallel
algorithm.checkpointer = checkpointer
[docs]
def load(
self,
file_name: Optional[str] = None,
reporter: Reporter | str = "silent",
parallel: bool = False,
) -> Algorithm:
"""Restore a previously saved algorithm from a checkpoint file.
Parameters
----------
file_name : str, optional
Path to the checkpoint file. If not provided, the path
given at construction is used.
reporter : Reporter or str, optional
Reporter to attach to the restored algorithm (a
:class:`Reporter` instance or a string like ``"tqdm"``,
``"silent"``). Default is ``"silent"``.
parallel : bool, optional
Whether parallel evaluation should be enabled after
restoration. Default is ``False``.
Returns
-------
Algorithm
The deserialized algorithm, ready to continue from where it
was saved. Ensure you run the algorithm with `.resume()` so
data is not lost.
"""
if file_name is None:
file_name = self.checkpoint_file
with open(file_name, "rb") as f:
algorithm: Algorithm = cloudpickle.load(f)
if isinstance(reporter, str):
reporter = create_reporter(reporter)
algorithm.reporter = reporter
algorithm.checkpointer = self
algorithm.parallel = parallel
return algorithm