# Source code for sbp_env.samplers.baseSampler

from abc import ABC
from typing import Tuple, Callable

import numpy as np

from ..utils.common import PlanningOptions, Stats
from ..visualiser import VisualiserSwitcher, BaseSamplerVisualiser


# noinspection PyAttributeOutsideInit
class Sampler(ABC):
    """Abstract base sampler.

    Defines the hook methods that some samplers, but not necessarily all of
    them, make use of. In this base class the reporting hooks do nothing;
    derived samplers override the ones they need.
    """

    #: The return type of :func:`get_next_pos`: the sampled position, a
    #: callable to report success, and a callable to report failure.
    GetNextPosReturnType = Tuple[np.ndarray, Callable, Callable]

    def __init__(self, **kwargs):
        """Create the sampler and attach a visualiser to it.

        :param kwargs: if it contains ``sampler_data_pack``, that value is
            passed to ``VisualiserSwitcher.sampler_clname`` (together with this
            instance) and the result becomes :attr:`visualiser`; otherwise a
            plain ``BaseSamplerVisualiser`` is used.
        """
        super().__init__()
        self.use_radian = False
        if "sampler_data_pack" not in kwargs:
            # Without a data-pack this sampler does not need to be visualised
            # (e.g. it is nested inside another sampler).
            self.visualiser = BaseSamplerVisualiser()
            return
        self.visualiser = VisualiserSwitcher.sampler_clname(
            sampler_instance=self, sampler_data_pack=kwargs["sampler_data_pack"]
        )

    def init(self, use_radian: bool = False, **kwargs):
        """The delayed **initialisation** method.

        :param use_radian: whether this sampler should return values in
            radians (as opposed to Euclidean values)
        :param kwargs: must contain ``args``, the planning options object. Its
            ``start_pt`` and ``goal_pt`` attributes (documented as
            :class:`~utils.common.Node`) supply the start and goal
            configurations through their ``.pos``. The complete ``kwargs``
            (including ``args``) is forwarded to the visualiser's ``init``.
        :raises KeyError: if ``args`` is not supplied
        """
        planning_args = kwargs["args"]
        self.args = planning_args
        self.start_pos = planning_args.start_pt.pos
        self.goal_pos = planning_args.goal_pt.pos
        self.use_radian = use_radian
        self.visualiser.init(**kwargs)

    def get_next_pos(self, **kwargs):
        """Retrieve the next sampled position.

        Derived samplers are expected to override this.

        :return: a sampled position, a callable to report success, and a
            callable to report failure (see :attr:`GetNextPosReturnType`)
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def get_valid_next_pos(self):
        """Keep sampling until the collision checker accepts a position.

        Uses :meth:`get_next_pos` internally. Every draw is recorded as a
        sampled node in ``Stats``. An infeasible draw is reported to its own
        failure callback with ``obstacle=True`` and counted as invalid before
        the next attempt. This loops indefinitely if no feasible position is
        ever produced.

        :return: the feasible position together with its success and failure
            callbacks
        """
        while True:
            candidate, on_success, on_fail = self.get_next_pos()
            Stats.get_instance().add_sampled_node(candidate)
            if not self.args.engine.cc.feasible(candidate):
                # The draw hit an obstacle: notify the sampler, record it, retry.
                on_fail(pos=candidate, obstacle=True)
                Stats.get_instance().add_invalid(obs=True)
                continue
            return candidate, on_success, on_fail

    def set_use_radian(self, value: bool = True):
        """Set whether this sampler uses radians.

        :param value: the value to set
        """
        self.use_radian = value

    def report_success(self, **kwargs):
        """Report to the sampler that the last sample was successful.

        Does nothing here; the behaviour is sampler dependent.

        :param kwargs: passed through to the derived class
        """

    def report_fail(self, **kwargs):
        """Report to the sampler that the last sample was unsuccessful.

        Does nothing here; the behaviour is sampler dependent.

        :param kwargs: passed through to the derived class
        """

    def add_tree_node(self, **kwargs):
        """Report to the sampler the last node that was added to the tree.

        Does nothing here; derived samplers may override it.

        :param kwargs: passed through to the derived class
        """

    def add_sample_line(self, **kwargs):
        """Report to the sampler the entire line that was sampled last time.

        Does nothing here; derived samplers may override it.

        :param kwargs: passed through to the derived class
        """

    @property
    def name(self) -> str:
        """The class name of this sampler instance."""
        return type(self).__name__