import logging
from dataclasses import dataclass, fields
from typing import List
from typing import TYPE_CHECKING
import numpy as np
from rtree import index
if TYPE_CHECKING:
from sbp_env.engine import Engine
from sbp_env.planners.basePlanner import Planner
from sbp_env.samplers.baseSampler import Sampler
from sbp_env.utils.planner_registry import PlannerDataPack, SamplerDataPack
@dataclass
class PlanningOptions:
    """Typed bag of settings that configure one planning session.

    Formerly implemented as ``MagicDict``. Besides the declared dataclass
    fields, arbitrary additional key/value pairs can be attached through
    :meth:`add_option`; they live in ``extra_options`` and are merged into the
    result of :meth:`as_dict`.
    """

    # --- required settings -------------------------------------------------
    planner_data_pack: "PlannerDataPack"
    sampler_data_pack: "SamplerDataPack"
    skip_optimality: bool
    showSampledPoint: bool
    scaling: float
    goalBias: float
    max_number_nodes: int
    radius: float
    ignore_step_size: bool
    always_refresh: bool
    rrdt_proposal_distribution: str
    start_pt: str
    goal_pt: str

    # --- optional settings -------------------------------------------------
    # ``epsilon`` and ``goal_radius`` (and ``radius`` when passed as ``None``)
    # are filled in by ``compute_default_values``.
    epsilon: float = None
    goal_radius: float = None
    as_radian: bool = False
    simplify_solution: bool = True
    no_display: bool = False
    first_solution: bool = False
    output_dir: str = None
    save_output: str = None
    rover_arm_robot_lengths: str = None
    sampler: "Sampler" = None
    planner: "Planner" = None
    engine: "Engine" = None

    def __post_init__(self):
        # Plain instance attribute (not a dataclass field), therefore it is
        # not reported by ``fields(self)``.
        self.extra_options = {}

    def add_option(self, **kwargs):
        """Attach extra, undeclared options given as keyword arguments."""
        self.extra_options.update(kwargs)

    def update(self, dictionary: dict):
        """Assign every ``key -> value`` pair of ``dictionary`` as an attribute.

        :param dictionary: mapping of attribute names to their new values
        """
        for name, value in dictionary.items():
            setattr(self, name, value)

    def compute_default_values(self):
        """Fill in the values that are still ``None`` with derived defaults.

        ``epsilon`` defaults to 1.5 when ``as_radian`` is set and 10.0
        otherwise, ``radius`` to ``1.1 * epsilon`` and ``goal_radius`` to two
        thirds of ``radius``. Values that are already set are left untouched.
        """
        if self.epsilon is None:
            if self.as_radian:
                self.epsilon = 1.5
            else:
                self.epsilon = 10.0
        if self.radius is None:
            self.radius = self.epsilon * 1.1
        if self.goal_radius is None:
            self.goal_radius = 2 / 3 * self.radius

    def as_dict(self):
        """Return the declared fields and the extra options as one dictionary.

        Extra options are applied last, so they win over a declared field
        that has the same name.

        :return: a new dictionary holding every option
        """
        merged = dict(
            (spec.name, getattr(self, spec.name)) for spec in fields(self)
        )
        merged.update(self.extra_options)
        return merged
class Colour:
    """Namespace of named RGB colour triplets used for drawing."""

    ALPHA_CK = (255, 0, 255)
    white = (255, 255, 255)
    black = (20, 20, 40)
    red = (255, 0, 0)
    blue = (0, 0, 255)
    path_blue = (26, 128, 178)
    green = (0, 150, 0)
    cyan = (20, 200, 200)
    orange = (255, 160, 16)

    @staticmethod
    def cAlpha(colour, alpha):
        """Return ``colour`` extended with an alpha channel.

        :param colour: the base colour (an iterable of channel values)
        :param alpha: the alpha value to append
        :return: a new list made of the channels of ``colour`` followed by
            ``alpha``; the input is not modified
        """
        return [*colour, alpha]
class Node:
    r"""Represents a node inside a tree.

    Two nodes compare equal when their positions have the same shape and the
    same values; the hash is derived from the position as well.

    :ivar pos: position, a.k.a., the configuration :math:`q \in C` that this node
        represents
    :ivar cost: a positive float that represents the cost of this node
    :ivar parent: parent node
    :ivar children: children nodes
    :ivar is_start: a flag to indicate this is the start node
    :ivar is_goal: a flag to indicate this is the goal node

    :vartype pos: :class:`numpy.ndarray`
    :vartype cost: float
    :vartype parent: :class:`Node`
    :vartype children: a list of :class:`Node`
    :vartype is_start: bool
    :vartype is_goal: bool
    """

    def __init__(self, pos: np.ndarray):
        """
        :param pos: configuration of this node
        """
        # np.array copies the input, so the node owns its position
        self.pos: np.ndarray = np.array(pos)
        self.cost: float = 0
        self.parent = None
        self.children = []
        self.is_start = False
        self.is_goal = False

    def __getitem__(self, x):
        """Index directly into the position array."""
        return self.pos[x]

    def __len__(self):
        """Return the length of the position array."""
        return len(self.pos)

    def __repr__(self):
        return f"{self.__class__.__name__}<{self.pos}>"

    def __eq__(self, other):
        """Compare by position.

        :param other: the object to compare against
        :return: ``True`` only if ``other`` is a node whose position has the
            same shape and values as this one
        """
        if not isinstance(other, self.__class__):
            return False
        # array_equal never broadcasts: positions of different shapes are
        # simply unequal (instead of raising or matching through
        # broadcasting), and the result is a plain ``bool``.
        return np.array_equal(self.pos, other.pos)

    def __hash__(self):
        """Hash of the position values; equal nodes hash equally."""
        return hash(tuple(self.pos))
class Stats:
    r"""
    Stores statistics of a planning problem instance.

    A single shared instance can be managed through :meth:`build_instance`,
    :meth:`get_instance`, :meth:`has_instance` and :meth:`clear_instance`.

    :ivar invalid_samples_connections: the number of invalid samples due to
        intermediate connections being invalid
    :ivar invalid_samples_obstacles: the number of invalid samples due to
        the sampled configurations is invalid, i.e., :math:`q \in C_\text{obs}`
    :ivar valid_sample: the number of valid samples, i.e., :math:`q \in C_\text{free}`
    :ivar sampledNodes: temporarily list of the recently sampled configurations
    :ivar showSampledPoint: A flag to denote whether we should store the list of
        recently sampled configurations
    :ivar sampler_success: UNDOCUMENTED
    :ivar sampler_success_all: UNDOCUMENTED
    :ivar sampler_fail: UNDOCUMENTED
    :ivar visible_cnt: the number of calls to visibility test in the collision checker
    :ivar feasible_cnt: the number of calls to feasibility test in the collision checker
    :ivar lsampler_restart_counter: UNDOCUMENTED (counter, starts at 0)
    :ivar lsampler_randomwalk_counter: UNDOCUMENTED (counter, starts at 0)

    :type invalid_samples_connections: int
    :type invalid_samples_obstacles: int
    :type valid_sample: int
    :type sampledNodes: List[:class:`Node`]
    :type showSampledPoint: bool
    :type sampler_success: int
    :type sampler_success_all: int
    :type sampler_fail: int
    :type visible_cnt: int
    :type feasible_cnt: int
    """

    # the shared instance; name-mangled to ``_Stats__global_stats_instance``
    __global_stats_instance = None

    @classmethod
    def has_instance(cls) -> bool:
        """Tell whether the shared instance has been built."""
        return cls.__global_stats_instance is not None

    @classmethod
    def build_instance(cls, **kwargs) -> "Stats":
        """Create the shared instance.

        :param kwargs: forwarded to the :class:`Stats` constructor
        :return: the newly created shared instance
        :raises ValueError: if a shared instance already exists
        """
        if cls.__global_stats_instance is not None:
            raise ValueError(
                f"There are already an existing instance of Stats! "
                f"{cls.get_instance()}"
            )
        cls.__global_stats_instance = Stats(**kwargs)
        return cls.__global_stats_instance

    @classmethod
    def get_instance(cls) -> "Stats":
        """Return the shared instance.

        :raises ValueError: if no shared instance has been built
        """
        if cls.__global_stats_instance is None:
            raise ValueError("Stats instance had not been built yet!")
        return cls.__global_stats_instance

    @classmethod
    def clear_instance(cls) -> None:
        """Forget the shared instance (if any)."""
        cls.__global_stats_instance = None

    def __init__(self, showSampledPoint=True):
        """
        :param showSampledPoint: whether sampled positions should be recorded
            by :meth:`add_sampled_node`
        """
        self.invalid_samples_connections = 0
        self.invalid_samples_obstacles = 0
        self.valid_sample = 0
        self.sampledNodes = []
        self.showSampledPoint = showSampledPoint
        self.sampler_success = 0
        self.sampler_success_all = 0
        self.sampler_fail = 0
        self.visible_cnt = 0
        self.feasible_cnt = 0
        self.lsampler_restart_counter = 0
        self.lsampler_randomwalk_counter = 0

    def add_invalid(self, obs):
        """Count one invalid sample.

        :param obs: truthy to count it as invalid because of an obstacle,
            falsy to count it as invalid because of a connection
        """
        if obs:
            self.invalid_samples_obstacles += 1
        else:
            self.invalid_samples_connections += 1

    def add_free(self):
        """
        Increment the free sampled point counter
        """
        self.valid_sample += 1

    def add_sampled_node(self, pos: np.ndarray):
        """Add a sampled node position

        :param pos: the position of a sampled node
        """
        # if pygame is not enabled, skip showing sampled point
        if not self.showSampledPoint:
            return
        self.sampledNodes.append(pos)

    def __repr__(self):
        """Return ``Stats<name=value|...>`` over the instance attributes.

        Only attributes stored on the instance are listed (sorted by name).
        The class-level shared-instance slot is deliberately left out:
        printing it from within the shared instance itself would recurse
        without end.
        """
        state = vars(self)
        return "Stats<{}>".format(
            "|".join(
                f"{attr}={state[attr]}"
                for attr in sorted(state)
                if not attr.startswith("__") and not callable(state[attr])
            )
        )
def update_progress(progress: int, total_num: int, num_of_blocks: int = 10):
    """Print a single-line textual progress bar to standard output.

    Nothing is printed unless the root logger is enabled for ``INFO``. The bar
    is redrawn in place (carriage return, no newline); a newline is emitted
    once ``progress`` equals ``total_num``.

    :param progress: the current progress
    :param total_num: the total count for the progress
    :param num_of_blocks: number of blocks for the progress bar
    """
    root_logger = logging.getLogger()
    if not root_logger.isEnabledFor(logging.INFO):
        return

    percentage = progress / total_num
    bar = "#" * int(percentage * num_of_blocks)
    line = (
        f"\r[{bar:<{num_of_blocks}}] "
        f"{progress}/{total_num} {percentage * 100:0.1f}%"
    )
    print(line, end="")
    if percentage == 1:
        # finished: move on to a fresh line
        print()
class BFS:
    """Breadth-first walk over connected nodes.

    Nodes are expected to be hashable and to expose their neighbours through
    an ``edges`` attribute. Only nodes contained in the set of valid nodes are
    ever visited, and each of them at most once.
    """

    def __init__(self, node, validNodes):
        """
        :param node: the starting node
        :param validNodes: the set of valid nodes that this BFS will transverse
        """
        self.visitedNodes = set()
        self.validNodes = set(validNodes)
        # FIFO queue of candidates; may contain duplicates and invalid nodes
        self.next_node_to_visit = [node]
        # node found by has_next() that next() has not handed out yet
        self.next_node = None

    def visit_node(self, node):
        """Mark ``node`` as visited and enqueue its neighbours.

        :param node: the node to visit
        """
        self.visitedNodes.add(node)
        self.next_node_to_visit.extend(node.edges)
        self.next_node = node

    def has_next(self) -> bool:
        """Check whether there's a next node for the BFS search. This function
        also performs the actual computation of finding next available node.

        :return: whether a next node is available
        """
        if self.next_node is not None:
            # a node was already located and is waiting to be consumed
            return True

        pending = self.next_node_to_visit
        while pending:
            candidate = pending.pop(0)
            if candidate not in self.visitedNodes and candidate in self.validNodes:
                self.visit_node(candidate)
                return True
        # the queue ran dry without yielding an unvisited, valid node
        return False

    def next(self):
        """Get the next node

        :return: the next node from the BFS search
        :raises StopIteration: when no node is left to visit
        """
        if not self.has_next():
            raise StopIteration("No more node to visit")
        node, self.next_node = self.next_node, None
        return node
class Tree:
    """
    A tree representation that stores nodes and edges.

    Vertices are kept in an R-tree spatial index (``self.V``); edges are kept
    in a dictionary (``self.E``) that maps each child to its parent.
    """

    def __init__(self, dimension: int):
        """
        :param dimension: A positive integer that represents :math:`d`,
            the dimensionality of the C-space.
        """
        properties = index.Property()
        properties.dimension = dimension
        self.V = index.Index(interleaved=True, properties=properties)
        self.E = {}  # edges in form E[child] = parent

    def add_vertex(self, v: Node, pos: np.ndarray) -> None:
        """Add a new vertex to this tree

        :param v: Node to be added
        :param pos: The configuration :math:`q` that corresponds to the node ``v``
        """
        # A two-component position is passed on as-is; any other position is
        # repeated twice to form the coordinates handed to the index.
        coordinates = tuple(pos) if len(pos) == 2 else np.tile(pos, 2)
        # every entry shares id 0; the node itself travels as the payload
        self.V.insert(0, coordinates, v)

    def add_edge(self, child: Node, parent: Node) -> None:
        """Add a new edge to this tree

        :param child: The child node of the edge
        :param parent: The parent node of the edge
        """
        self.E[child] = parent

    def nearby(self, x: np.ndarray, n: int) -> List[Node]:
        """Find ``n`` many nearby nodes that are closest to a given position

        :param x: Position
        :param n: Max number of results
        :return: the result of the index's nearest-neighbour query, with the
            stored nodes as items (``get_nearest`` consumes it with ``next``)
        """
        query = np.tile(x, 2)
        return self.V.nearest(query, num_results=n, objects="raw")

    def get_nearest(self, x: np.ndarray) -> Node:
        """Get the closest node

        :param x: Position
        :return: the closest node
        """
        closest = self.nearby(x, 1)
        return next(closest)
# def connect_to_point(self, tree, x_a, x_b):
# """
# Connect vertex x_a in tree to vertex x_b
# :param tree: int, tree to which to add edge
# :param x_a: tuple, vertex
# :param x_b: tuple, vertex
# :return: bool, True if able to add edge, False if prohibited by an obstacle
# """
# if self.V.count(x_b) == 0 and self.X.collision_free(x_a, x_b, self.r):
# self.add_vertex(tree, x_b)
# self.add_edge(tree, x_b, x_a)
# return True
# return False
# def can_connect_to_goal(self, tree):
# """
# Check if the goal can be connected to the graph
# :param tree: rtree of all Vertices
# :return: True if can be added, False otherwise
# """
# x_nearest = self.get_nearest(tree, self.x_goal)
# if self.x_goal in self.E and x_nearest in self.E[self.x_goal]:
# # tree is already connected to goal using nearest vertex
# return True
# if self.X.collision_free(x_nearest, self.x_goal, self.r): # check if obstacle-free
# return True
# return False
try:
    from functools import cached_property
except ImportError:  # functools.cached_property is unavailable

    class cached_property:
        """Fallback that mirrors the behaviour of ``functools.cached_property``.

        The wrapped method is evaluated on the first attribute access and its
        result is stored in the ``__dict__`` of that very instance. As this is
        a non-data descriptor, later accesses read the stored value directly.
        Each instance therefore owns its cached value: nothing is shared
        between instances that merely compare equal, and the descriptor keeps
        no reference to the instances.
        """

        def __init__(self, func):
            """
            :param func: the method whose result is to be cached
            """
            self.func = func
            # default attribute name; refined by __set_name__ when the
            # descriptor is assigned inside a class body
            self.attrname = func.__name__
            self.__doc__ = func.__doc__

        def __set_name__(self, owner, name):
            """Record the attribute name the descriptor is bound to."""
            self.attrname = name

        def __get__(self, obj, objtype=None):
            """Return the cached value, computing it on first access.

            :param obj: the instance, or ``None`` for access on the class
            :param objtype: the owning class (unused)
            :return: the descriptor itself for class-level access, otherwise
                the (cached) result of the wrapped method
            """
            if obj is None:
                # accessed on the class: there is no instance to evaluate
                return self
            try:
                cache = obj.__dict__
            except AttributeError:
                # no per-instance __dict__ (e.g. __slots__): nowhere to
                # store the value, so compute it on every access
                return self.func(obj)
            try:
                return cache[self.attrname]
            except KeyError:
                value = self.func(obj)
                cache[self.attrname] = value
                return value