Source code for visualization.plot_array_layout

#!/usr/bin/python3
"""Plot array elements for a layout."""

from collections import Counter
from typing import NamedTuple

import astropy.units as u
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Column
from matplotlib.collections import PatchCollection

from simtools.utils import geometry as transf
from simtools.utils import names
from simtools.visualization import legend_handlers as leg_h


[docs] class PlotBounds(NamedTuple): """Axis-aligned bounds for the layout in meters. Attributes ---------- x_lim : tuple[float, float] Min/max for x (meters). y_lim : tuple[float, float] Min/max for y (meters). """ x_lim: tuple[float, float] y_lim: tuple[float, float]
[docs] def plot_array_layout( telescopes, show_tel_label=False, axes_range=None, marker_scaling=1.0, background_telescopes=None, grayed_out_elements=None, highlighted_elements=None, legend_location="best", bounds_mode="exact", padding=0.1, x_lim=None, y_lim=None, ): """ Plot telescope array layout. Parameters ---------- telescopes : Table Telescope data table. show_tel_label : bool Show telescope labels (default False). axes_range : float or None Axis range, auto if None. marker_scaling : float Marker size scale factor. background_telescopes : Table or None Optional background telescope table. grayed_out_elements : list or None List of telescope names to plot as gray circles. highlighted_elements : list or None List of telescope names to plot with red circles around them. legend_location : str Location of the legend (default "best"). Returns ------- fig : Figure Matplotlib figure object. Other Parameters ---------------- bounds_mode : {"symmetric", "exact"} Controls axis limits calculation. "symmetric" uses +-R where R is the padded maximum extent (default), while "exact" uses individual x/y min/max bounds. padding : float Fractional padding applied around computed extents (used for both modes). x_lim, y_lim : tuple(float, float), optional Explicit axis limits in meters. If provided, these override axes_range and bounds_mode for the respective axis. If only one is provided, the other axis is derived per mode. """ fig, ax = plt.subplots(1) # If explicit limits are provided (one or both), filter patches accordingly filter_x = x_lim filter_y = y_lim patches, plot_range, highlighted_patches, bounds = get_patches( ax, telescopes, show_tel_label, axes_range, marker_scaling, grayed_out_elements, highlighted_elements, filter_x_lim=filter_x, filter_y_lim=filter_y, ) plot_range, bounds = _get_patches_for_background_telescopes( ax, background_telescopes, axes_range, marker_scaling, bounds_mode, plot_range, bounds, filter_x_lim=filter_x, filter_y_lim=filter_y, ) if legend_location != "no_legend": update_legend(ax, telescopes, grayed_out_elements, legend_location) x_lim, y_lim = _get_axis_limits( axes_range, bounds_mode, padding, plot_range, bounds, x_lim, y_lim ) finalize_plot(ax, patches, "Easting [m]", "Northing [m]", x_lim, y_lim, highlighted_patches) return fig
def _get_axis_limits( axes_range, bounds_mode, padding, plot_range, bounds, x_lim_override=None, y_lim_override=None, ): """Get axis limits based on mode and padding.""" def _derive_axis(axis: str) -> tuple[float, float]: if bounds_mode == "exact": if axis == "x": span = bounds.x_lim[1] - bounds.x_lim[0] pad = padding * span return (bounds.x_lim[0] - pad, bounds.x_lim[1] + pad) span = bounds.y_lim[1] - bounds.y_lim[0] pad = padding * span return (bounds.y_lim[0] - pad, bounds.y_lim[1] + pad) # symmetric sym = plot_range padf = max(0.0, min(1.0, float(padding))) if padding is not None else 0.0 sym *= 1.0 + padf return (-sym, sym) # Highest priority: explicit overrides (per axis) if x_lim_override is not None or y_lim_override is not None: x_lim = x_lim_override if x_lim_override is not None else _derive_axis("x") y_lim = y_lim_override if y_lim_override is not None else _derive_axis("y") return x_lim, y_lim if axes_range is not None: return (-axes_range, axes_range), (-axes_range, axes_range) # Derive both axes using selected mode return _derive_axis("x"), _derive_axis("y") def _get_patches_for_background_telescopes( ax, background_telescopes, axes_range, marker_scaling, bounds_mode, plot_range, bounds, filter_x_lim=None, filter_y_lim=None, ): """Get background telescope patches and update plot range/bounds.""" if background_telescopes is None: return plot_range, bounds bg_patches, bg_range, _, bg_bounds = get_patches( ax, background_telescopes, False, axes_range, marker_scaling, None, None, filter_x_lim=filter_x_lim, filter_y_lim=filter_y_lim, ) ax.add_collection(PatchCollection(bg_patches, match_original=True, alpha=0.1)) if axes_range is None: if bounds_mode == "symmetric": plot_range = max(plot_range, bg_range) else: bounds = PlotBounds( x_lim=( min(bounds.x_lim[0], bg_bounds.x_lim[0]), max(bounds.x_lim[1], bg_bounds.x_lim[1]), ), y_lim=( min(bounds.y_lim[0], bg_bounds.y_lim[0]), max(bounds.y_lim[1], bg_bounds.y_lim[1]), ), ) return plot_range, bounds def _apply_limits_filter(telescopes, pos_x, pos_y, filter_x_lim, filter_y_lim): """Filter telescope table and positions by optional axis limits.""" if filter_x_lim is None and filter_y_lim is None: return telescopes, pos_x, pos_y px = np.asarray(pos_x.to_value(u.m)) py = np.asarray(pos_y.to_value(u.m)) mask = np.ones(px.shape, dtype=bool) if filter_x_lim is not None: mask &= (px >= float(filter_x_lim[0])) & (px <= float(filter_x_lim[1])) if filter_y_lim is not None: mask &= (py >= float(filter_y_lim[0])) & (py <= float(filter_y_lim[1])) if mask.size and mask.any(): return telescopes[mask], pos_x[mask], pos_y[mask] # No telescopes within limits return telescopes[:0], pos_x[:0], pos_y[:0]
[docs] def get_patches( ax, telescopes, show_tel_label, axes_range, marker_scaling, grayed_out_elements=None, highlighted_elements=None, filter_x_lim=None, filter_y_lim=None, ): """ Get plot patches and axis range. Parameters ---------- ax : Axes Matplotlib axes object. telescopes : Table Telescope data table. show_tel_label : bool Show telescope labels. axes_range : float or None Axis range, auto if None. marker_scaling : float Marker size scale factor. grayed_out_elements : list or None List of telescope names to plot as gray circles. highlighted_elements : list or None List of telescope names to plot with red circles around them. Returns ------- patches : list List of telescope patches. axes_range : float Calculated or input symmetric axis range (meters). highlighted_patches : list List of highlighted telescope patches. bounds : PlotBounds Min/max for x and y in meters. """ pos_x, pos_y = get_positions(telescopes) tel_table, pos_x, pos_y = _apply_limits_filter( telescopes, pos_x, pos_y, filter_x_lim, filter_y_lim ) tel_table["pos_x_rotated"] = Column(pos_x) tel_table["pos_y_rotated"] = Column(pos_y) patches, radii, highlighted_patches = create_patches( tel_table, marker_scaling, show_tel_label, ax, grayed_out_elements, highlighted_elements ) if len(radii) == 0: r = 0.0 else: radii_q = u.Quantity(radii) r = float(np.nanmax(radii_q).to_value(u.m)) if len(pos_x) == 0: bounds = PlotBounds(x_lim=(0.0, 0.0), y_lim=(0.0, 0.0)) if axes_range: return patches, axes_range, highlighted_patches, bounds return patches, 0.0, highlighted_patches, bounds x_min = float(np.nanmin(pos_x).to_value(u.m)) - r x_max = float(np.nanmax(pos_x).to_value(u.m)) + r y_min = float(np.nanmin(pos_y).to_value(u.m)) - r y_max = float(np.nanmax(pos_y).to_value(u.m)) + r bounds = PlotBounds(x_lim=(x_min, x_max), y_lim=(y_min, y_max)) if axes_range: return patches, axes_range, highlighted_patches, bounds max_x = max(abs(x_min), abs(x_max)) max_y = max(abs(y_min), abs(y_max)) updated_axes_range = max(max_x, max_y) * 1.1 return patches, updated_axes_range, highlighted_patches, bounds
[docs] @u.quantity_input(x=u.m, y=u.m, radius=u.m) def get_telescope_patch(tel_type, x, y, radius, is_grayed_out=False): """ Create patch for a telescope. Parameters ---------- tel_type: str Telescope type. x : Quantity X position. y : Quantity Y position. radius : Quantity Telescope radius. is_grayed_out : bool Whether to plot telescope in gray. Returns ------- patch : Patch Circle or rectangle patch. """ config = leg_h.get_telescope_config(tel_type) x, y, r = x.to(u.m), y.to(u.m), radius.to(u.m) color = "gray" if is_grayed_out else config["color"] fill_flag = True if is_grayed_out else bool(config.get("filled", True)) if config.get("shape", "circle") == "square": return mpatches.Rectangle( ((x - r / 2).value, (y - r / 2).value), width=r.value, height=r.value, fill=fill_flag, color=color, ) if config.get("shape") == "hexagon": return mpatches.RegularPolygon( (x.value, y.value), numVertices=6, radius=r.value * np.sqrt(3) / 2, orientation=np.pi / 6, fill=fill_flag, color=color, ) return mpatches.Circle( (x.value, y.value), radius=r.value, fill=fill_flag, alpha=0.5 if is_grayed_out else 1.0, color=color, )
[docs] def get_positions(telescopes): """ Get X/Y positions depending on coordinate system. For ground coordinates, rotates the positions by 90 degrees. Returns ------- x_rot, y_rot : Quantity Position coordinates. """ if "position_x" in telescopes.colnames: x, y = telescopes["position_x"], telescopes["position_y"] locale_rotate_angle = 90 * u.deg elif "utm_east" in telescopes.colnames: x, y = telescopes["utm_east"], telescopes["utm_north"] locale_rotate_angle = 0 * u.deg else: raise ValueError("Missing required position columns.") return transf.rotate(x, y, locale_rotate_angle) if locale_rotate_angle != 0 else (x, y)
[docs] def create_patches( telescopes, scale, show_label, ax, grayed_out_elements=None, highlighted_elements=None ): """ Create telescope patches and labels. Parameters ---------- telescopes : Table Telescope data table. scale : float Marker size scale factor. show_label : bool Show telescope labels. ax : Axes Matplotlib axes object. grayed_out_elements : list or None List of telescope names to plot as gray circles. highlighted_elements : list or None List of telescope names to plot with red circles around them. Returns ------- patches : list Shape patches. radii : list Telescope radii. highlighted_patches : list List of highlighted telescope patches. """ patches, radii, highlighted_patches = [], [], [] fontsize, scale_factor = (4, 2) if len(telescopes) > 30 else (8, 1) grayed_out_set = set(grayed_out_elements) if grayed_out_elements else set() highlighted_set = set(highlighted_elements) if highlighted_elements else set() for tel in telescopes: name = get_telescope_name(tel) radius = get_sphere_radius(tel) radii.append(radius) try: tel_type = names.get_array_element_type_from_name(name) except ValueError: tel_type = None is_grayed_out = name in grayed_out_set is_highlighted = name in highlighted_set patches.append( get_telescope_patch( tel_type, tel["pos_x_rotated"], tel["pos_y_rotated"], scale_factor * radius * scale, is_grayed_out=is_grayed_out, ) ) if is_highlighted: highlight_patch = mpatches.Circle( (tel["pos_x_rotated"].value, tel["pos_y_rotated"].value), radius=(scale_factor * radius * scale * 4).value, fill=False, color="red", linewidth=1, ) highlighted_patches.append(highlight_patch) if show_label: ax.text( tel["pos_x_rotated"].value, tel["pos_y_rotated"].value + scale_factor * radius.value, name, ha="center", va="bottom", fontsize=fontsize * 0.8, ) return patches, radii, highlighted_patches
[docs] def get_telescope_name(tel): """ Get telescope name. Returns ------- name : str Telescope name or fallback identifier. """ if "telescope_name" in tel.colnames: return tel["telescope_name"] if "asset_code" in tel.colnames and "sequence_number" in tel.colnames: return f"{tel['asset_code']}-{tel['sequence_number']}" return f"tel_{tel.index}"
[docs] def get_sphere_radius(tel): """ Get telescope sphere radius. Returns ------- radius : Quantity Radius with units. """ return tel["sphere_radius"] if "sphere_radius" in tel.colnames else 10.0 * u.m
[docs] def update_legend(ax, telescopes, grayed_out_elements=None, legend_location="best"): """Add legend for telescope types and counts.""" grayed_out_set = set(grayed_out_elements) if grayed_out_elements else set() types = [] for tel in telescopes: tel_name = get_telescope_name(tel) if tel_name not in grayed_out_set: types.append(names.get_array_element_type_from_name(tel_name)) counts = Counter(types) objs, labels = [], [] handler_map = {} for telescope_type in names.get_list_of_array_element_types(): if counts[telescope_type]: objs.append(telescope_type) labels.append(f"{telescope_type} ({counts[telescope_type]})") class BaseLegendHandlerWrapper: # pylint: disable=R0903 """Wrapper for BaseLegendHandler to use in legend.""" def __init__(self, tel_type): self.tel_type = tel_type def legend_artist(self, legend, orig_handle, fontsize, handlebox): handler = leg_h.BaseLegendHandler(self.tel_type) return handler.legend_artist(legend, orig_handle, fontsize, handlebox) handler_map[telescope_type] = BaseLegendHandlerWrapper(telescope_type) ax.legend(objs, labels, handler_map=handler_map, prop={"size": 11}, loc=legend_location)
[docs] def finalize_plot( ax, patches, x_title, y_title, x_lim=None, y_lim=None, highlighted_patches=None, ): """Finalize plot appearance and limits.""" ax.add_collection(PatchCollection(patches, match_original=True)) if highlighted_patches: ax.add_collection(PatchCollection(highlighted_patches, match_original=True)) ax.set(xlabel=x_title, ylabel=y_title) ax.tick_params(labelsize=8) ax.axis("square") if x_lim is not None and y_lim is not None: ax.set_xlim(*x_lim) ax.set_ylim(*y_lim) plt.tight_layout()