Source code for simulation.src.simulation_evaluation.src.speaker.speakers.speaker

"""Base class for all other speakers."""
import bisect
import functools
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Tuple

from gazebo_simulation.msg import CarState as CarStateMsg
from simulation_evaluation.msg import Speaker as SpeakerMsg
from simulation_groundtruth.msg import LabeledPolygon as LabeledPolygonMsg
from simulation_groundtruth.msg import Lane as LaneMsg
from simulation_groundtruth.msg import Section as SectionMsg

from simulation.utils.geometry import Line, Polygon, Pose, Vector
from simulation.utils.road.sections.line_tuple import LineTuple

[docs]@dataclass class LabeledSpeakerPolygon: id_: int desc: str frame: Polygon height: float
[docs]class Speaker: """Base class for all speakers. This class is part of the evaluation pipeline used to automatically evaluate the cars behavior in simulation. It converts information about the cars state and the predefined groundtruth into SpeakerMsg's which serve as inputs for the state machines. Information is passed to the speaker by calling the :py:func:`Speaker.listen` function with a new CarState msg. Output can be retrieved in form of Speaker msgs by calling the .speak function. Attributes: sections (Callable[[], List[SectionMsg]]): List of all sections as SectionMsgs get_lanes (Callable[[int], LaneMsg]): Get LaneMsg for a given section get_obstacles (Callable[[int], List[LabeledPolygonMsg]]): Get ObstacleMsg for a given section get_intersection (Callable[[int], Any]): Get intersections for a given section """ def __init__( self, *, section_proxy: Callable[[], List[SectionMsg]], lane_proxy: Callable[[int], LaneMsg], obstacle_proxy: Callable[[int], List[LabeledPolygonMsg]] = None, surface_marking_proxy: Callable[[int], List[LabeledPolygonMsg]] = None, intersection_proxy: Callable[[int], Any] = None, sign_proxy: Callable[[int], List[LabeledPolygonMsg]] = None, ): """Initialize speaker with funtions that can be queried for groundtruth information. Args: section_proxy: Returns all sections when called. lane_proxy: Returns a LaneMsg for each section. obstacle_proxy: Optional function which returns obstacles in a section. surface_marking_proxy: Optional function which returns surface_markings in a section. intersection_proxy: Optional function which returns an IntersectionMsg for a section. (Only makes sense if the section is an intersection.) """ self.sections = section_proxy().sections self.get_lanes = lane_proxy self.get_obstacles = obstacle_proxy self.get_surface_markings = surface_marking_proxy self.get_intersection = intersection_proxy self.get_traffic_signs = sign_proxy
[docs] def listen(self, msg: CarStateMsg): """Receive information about current observations and update internal values.""" # Store car frame self.car_frame = Polygon(msg.frame) self.car_pose = Pose(msg.pose) self.car_speed = abs(Vector(msg.twist.linear))
[docs] def speak(self) -> List[SpeakerMsg]: """Speak about the current observations.""" return []
@functools.cached_property def middle_line(self) -> Line: """Complete middle line.""" # Get the middle line of each seaction middle_line_pieces = (self.get_road_lines( for sec in self.sections) # Sum it up (and start with empty line) return sum(middle_line_pieces, Line()) @functools.cached_property def left_line(self) -> Line: """Complete left line.""" # Get the left line of each seaction left_line_pieces = (self.get_road_lines( for sec in self.sections) # Sum it up (and start with empty line) return sum(left_line_pieces, Line()) @functools.cached_property def right_line(self) -> Line: """Complete right line.""" # Get the right line of each seaction right_line_pieces = (self.get_road_lines( for sec in self.sections) # Sum it up (and start with empty line) return sum(right_line_pieces, Line()) @property def arc_length(self): """Position of the car projected on the middle line (== Length driven so far).""" return self.middle_line.project(self.car_pose.position) @functools.cached_property def section_intervals(self) -> List[Tuple[float, float]]: """Get (start,end) intervals of all sections.""" # First extract the individual lengths of each section lengths = (self.get_road_lines( for sec in self.sections) # The accumulate function computes an inclusive prefix sum prefix = list(itertools.accumulate(lengths)) beginnings = [0] + prefix[:-1] endings = prefix return list(zip(beginnings, endings)) @property def current_section(self): """Get the current section. Returns: SectionMsg of the current. """ section_beginnings = list(interval[0] for interval in self.section_intervals) # Find the current section: the first section which starts before self.arc_length # (This will also return a value if arc_length is outside of the sections) idx = max(bisect.bisect_left(section_beginnings, self.arc_length) - 1, 0) return self.sections[idx]
[docs] @functools.lru_cache def get_road_lines(self, section_id: int) -> LineTuple: """Request and return the road lines of a single section. Args: section_id (int): id of the section Returns: LineTuple of the section as a named tuple. """ msg = self.get_lanes(section_id).lane_msg return LineTuple(Line(msg.left_line), Line(msg.middle_line), Line(msg.right_line))
[docs] def _get_labeled_polygons( self, proxy: Callable[[int], List[LabeledPolygonMsg]], section_id: int ) -> List[LabeledSpeakerPolygon]: """Get all objects inside a section returned by a service. Args: proxy: Service receiver function section_id: id of the section Returns: List of tuples containing id, description, polygon, and height. """ assert isinstance(section_id, int) response = proxy(section_id) return [ LabeledSpeakerPolygon(r.type, r.desc, Polygon(r.frame), r.height) for r in response.polygons ]
[docs] def get_obstacles_in_section(self, section_id: int) -> List[LabeledSpeakerPolygon]: """Get all obstacles in a section.""" return self._get_labeled_polygons(self.get_obstacles, section_id)
[docs] def get_traffic_signs_in_section(self, section_id: int) -> List[LabeledSpeakerPolygon]: """Get all traffic_signs inside section.""" return self._get_labeled_polygons(self.get_traffic_signs, section_id)
[docs] def get_surface_markings_in_section( self, section_id: int ) -> List[LabeledSpeakerPolygon]: """Get all surface_markings inside a section.""" return self._get_labeled_polygons(self.get_surface_markings, section_id)
[docs] def get_interval_for_polygon(self, *polygons: Iterable[Polygon]) -> Tuple[float, float]: """Get start and end points of the polygons. Args: polygon: The polygon. Returns: Start and end point for each polygon as a tuple """ projections = list( self.middle_line.project(p) for polygon in polygons for p in polygon.get_points() ) return (min(projections), max(projections))
[docs] def car_is_inside( self, *polygons: Iterable[Polygon], min_wheel_count: int = None ) -> bool: """Check if the car is currently inside one of the polygons. The car can also be in the union of the provided polygons. Args: polygons: The polygons. min_wheel_count: If provided it is enough for a given number of wheels to be inside the polygons (e.g. 3 wheels inside parking spot) """ if min_wheel_count: return self.wheel_count_inside(*polygons, in_total=True) >= min_wheel_count return ( sum(p.intersection(self.car_frame).area for p in polygons) / self.car_frame.area > 0.99 )
[docs] def wheel_count_inside( self, *polygons: Iterable[Polygon], in_total: bool = False ) -> int: """Get the maximum number of wheels inside one of the polygons. Args: polygons: The polygons. in_total: If true, the number of wheels are summed up """ if len(polygons) == 0: return False frame_points = set(self.car_frame.get_points()) # Loop over each point of the car frame and add it 1 to inside # if the point is in any of the polygons if in_total: inside = ( any(polygon.contains(point) for polygon in polygons) for point in frame_points ) return sum(inside) else: inside = ( sum(polygon.contains(point) for point in frame_points) for polygon in polygons ) return max(inside)
[docs] def car_overlaps_with(self, *polygons: Iterable[Polygon]) -> bool: """Decide if the car overlaps with any of the polygons. Args: polygons: The polygons. """ # True if any of the intersections are positive return any(p.intersection(self.car_frame).area > 0 for p in polygons)