Source code for production_configuration.generate_production_grid

"""
Module defines the `GridGeneration` class.

Used to generate a grid of simulation points based on flexible axes definitions such as
azimuth, zenith angle, night-sky background, and camera offset.
The module handles axis binning, scaling, and interpolation of energy thresholds, viewcone,
and radius limits from a lookup table.
Additionally, it allows for converting between Altitude/Azimuth and Right Ascension/Declination
coordinates. The resulting grid points are saved to a file.
"""

import json
import logging

import numpy as np
from astropy import units as u
from astropy.coordinates import AltAz, EarthLocation, SkyCoord
from astropy.table import Table
from astropy.units import Quantity
from scipy.interpolate import griddata


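# Illustrative sketch of an ``axes`` definition of the form consumed by ``GridGeneration``
# below. The axis names "zenith_angle", "azimuth", and "nsb" are the ones referenced by the
# lookup-table interpolation; the ranges, binning, and units here are assumed example values,
# not part of the module.
example_axes_definition = {
    "zenith_angle": {"range": (20, 60), "binning": 3, "scaling": "1/cos", "units": "deg"},
    "azimuth": {"range": (0, 180), "binning": 3, "units": "deg"},
    "nsb": {"range": (1, 5), "binning": 2},
}

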
class GridGeneration:
    """
    Defines and generates a grid of simulation points based on flexible axes definitions.

    This class generates a grid of points for a simulation based on parameters such as
    azimuth, zenith angle, night-sky background, and camera offset, taking into account
    axis definitions, scaling, and units, and interpolating values for simulations from
    a lookup table.
    """

    def __init__(
        self,
        axes: dict,
        coordinate_system: str = "zenith_azimuth",
        observing_location=None,
        observing_time=None,
        lookup_table: str | None = None,
        telescope_ids: list | None = None,
    ):
        """
        Initialize the grid with the given axes and coordinate system.

        Parameters
        ----------
        axes : dict
            Dictionary where each key is the axis name and the value is a dictionary
            defining the axis properties (range, binning, scaling, etc.).
        coordinate_system : str, optional
            The coordinate system for the grid generation (default is 'zenith_azimuth').
        observing_location : EarthLocation, optional
            The location of the observation (latitude, longitude, height).
        observing_time : Time, optional
            The time of the observation. If None, coordinate conversion to RA/Dec
            is not possible.
        lookup_table : str, optional
            Path to the lookup table file (ECSV format).
        telescope_ids : list of int, optional
            List of telescope IDs to get the limits for.
        """
        self._logger = logging.getLogger(__name__)
        self.axes = axes["axes"] if "axes" in axes else axes
        self.coordinate_system = coordinate_system
        self.observing_location = (
            observing_location
            if observing_location is not None
            else EarthLocation(lat=0.0 * u.deg, lon=0.0 * u.deg, height=0 * u.m)
        )
        self.observing_time = observing_time
        self.lookup_table = lookup_table
        self.telescope_ids = telescope_ids

        # Store target values for each axis
        self.target_values = self._generate_target_values()

        # Interpolated limits are filled only when a lookup table is provided
        self.interpolated_limits = {}
        if self.lookup_table:
            self._apply_lookup_table_limits()

    def _generate_target_values(self):
        """
        Generate target axis values and store them as Quantities.

        Returns
        -------
        dict
            Dictionary of target values for each axis, stored as Quantity objects.
        """
        target_values = {}
        for axis_name, axis in self.axes.items():
            axis_range = axis["range"]
            binning = axis["binning"]
            scaling = axis.get("scaling", "linear")
            units = axis.get("units", None)

            if axis_name == "azimuth":
                # Use circular binning for azimuth
                values = self.create_circular_binning(axis_range, binning)
            elif scaling == "log":
                # Log scaling
                values = np.logspace(np.log10(axis_range[0]), np.log10(axis_range[1]), binning)
            elif scaling == "1/cos":
                # 1/cos scaling: bins are equidistant in 1/cos(angle)
                cos_min = np.cos(np.radians(axis_range[0]))
                cos_max = np.cos(np.radians(axis_range[1]))
                inv_cos_values = np.linspace(1 / cos_min, 1 / cos_max, binning)
                values = np.degrees(np.arccos(1 / inv_cos_values))
            else:
                # Linear scaling
                values = np.linspace(axis_range[0], axis_range[1], binning)

            if units:
                values = values * u.Unit(units)
            else:
                # Ensure every axis is stored as a Quantity (dimensionless if no units given)
                values = values * u.dimensionless_unscaled

            target_values[axis_name] = values
        return target_values

    def _apply_lookup_table_limits(self):
        """Apply limits from the lookup table and interpolate values."""
        lookup_table = Table.read(self.lookup_table, format="ascii.ecsv")
        matching_rows = [
            row for row in lookup_table if set(self.telescope_ids) == set(row["telescope_ids"])
        ]
        if not matching_rows:
            raise ValueError(
                f"No matching rows in the lookup table for telescope_ids: {self.telescope_ids}"
            )

        def extract_array(field, transform=lambda x: x):
            return np.array([transform(row[field]) for row in matching_rows])

        zeniths = extract_array("zenith")
        azimuths = extract_array("azimuth", lambda x: x % 360)
        nsb_values = extract_array("nsb", lambda x: 1 if x == "dark" else 5)
        lower_energy_thresholds = extract_array("lower_energy_threshold")
        upper_radius_thresholds = extract_array("upper_radius_threshold")
        viewcone_radii = extract_array("viewcone_radius")

        # Wrap azimuths and repeat the other coordinates accordingly
        azimuths_wrapped = np.concatenate([azimuths + shift for shift in (0, 360, -360)])

        def repeat_3(arr):
            """Repeat an array three times."""
            return np.tile(arr, 3)

        points = np.column_stack(
            (
                repeat_3(zeniths),
                azimuths_wrapped,
                repeat_3(nsb_values),
            )
        )

        target_grid = (
            np.array(
                np.meshgrid(
                    self.target_values["zenith_angle"].value,
                    self.target_values["azimuth"].value,
                    self.target_values["nsb"].value,
                    indexing="ij",
                )
            )
            .reshape(3, -1)
            .T
        )

        def interpolate(values):
            return griddata(
                points, repeat_3(values), target_grid, method="linear", fill_value=np.nan
            ).reshape(
                len(self.target_values["zenith_angle"]),
                len(self.target_values["azimuth"]),
                len(self.target_values["nsb"]),
            )

        self.interpolated_limits = {
            "energy": interpolate(lower_energy_thresholds),
            "radius": interpolate(upper_radius_thresholds),
            "viewcone": interpolate(viewcone_radii),
        }

    def create_circular_binning(self, azimuth_range, num_bins):
        """
        Create bin centers for azimuth angles, handling circular wraparound (0 deg to 360 deg).

        Parameters
        ----------
        azimuth_range : tuple
            (min_azimuth, max_azimuth), can wrap around 0 deg.
        num_bins : int
            Number of bins.

        Returns
        -------
        np.ndarray
            Array of bin centers.
        """
        azimuth_min, azimuth_max = azimuth_range
        azimuth_min %= 360  # Normalize to [0, 360)
        azimuth_max %= 360

        clockwise_distance = (azimuth_max - azimuth_min) % 360
        counterclockwise_distance = (azimuth_min - azimuth_max) % 360

        if clockwise_distance <= counterclockwise_distance:
            bin_centers = (
                np.linspace(azimuth_min, azimuth_min + clockwise_distance, num_bins, endpoint=True)
                % 360
            )
        else:
            bin_centers = (
                np.linspace(
                    azimuth_min, azimuth_min - counterclockwise_distance, num_bins, endpoint=True
                )
                % 360
            )

        return bin_centers
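
    # Worked example of the wraparound handling above (illustrative input values):
    # ``create_circular_binning((350, 10), 3)`` follows the short arc through 0 deg, since the
    # clockwise distance (20 deg) is smaller than the counterclockwise one (340 deg), and
    # returns bin centers [350., 0., 10.].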

    def generate_grid(self) -> list[dict]:
        """
        Generate the grid based on the required axes and include interpolated limits.

        Takes energy threshold, viewcone, and radius from the interpolated lookup table.

        Returns
        -------
        list of dict
            A list of grid points, each represented as a dictionary with axis names as keys
            and axis values as values. Axis values may include units where defined.
        """
        value_arrays = [value.value for value in self.target_values.values()]
        units = [value.unit for value in self.target_values.values()]

        grid = np.meshgrid(*value_arrays, indexing="ij")
        combinations = np.vstack(list(map(np.ravel, grid))).T

        grid_points = []
        for combination in combinations:
            grid_point = {
                key: Quantity(combination[i], units[i])
                for i, key in enumerate(self.target_values.keys())
            }

            if self.interpolated_limits:
                # Indices of this grid point along the interpolation axes
                zenith_idx = np.searchsorted(
                    self.target_values["zenith_angle"].value, grid_point["zenith_angle"].value
                )
                azimuth_idx = np.searchsorted(
                    self.target_values["azimuth"].value, grid_point["azimuth"].value
                )
                nsb_idx = np.searchsorted(
                    self.target_values["nsb"].value, grid_point["nsb"].value
                )

                if "energy" in self.interpolated_limits:
                    energy_lower = self.interpolated_limits["energy"][
                        zenith_idx, azimuth_idx, nsb_idx
                    ]
                    grid_point["energy_threshold"] = {"lower": energy_lower * u.TeV}

                if "radius" in self.interpolated_limits:
                    radius_value = self.interpolated_limits["radius"][
                        zenith_idx, azimuth_idx, nsb_idx
                    ]
                    grid_point["radius"] = radius_value * u.m

                if "viewcone" in self.interpolated_limits:
                    viewcone_value = self.interpolated_limits["viewcone"][
                        zenith_idx, azimuth_idx, nsb_idx
                    ]
                    grid_point["viewcone"] = viewcone_value * u.deg

            grid_points.append(grid_point)

        return grid_points
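
    # A single grid point returned by ``generate_grid`` is a dictionary of the form
    # (values and units illustrative, following the axis definitions):
    #     {
    #         "zenith_angle": <Quantity in deg>,
    #         "azimuth": <Quantity in deg>,
    #         "nsb": <Quantity>,
    #         "energy_threshold": {"lower": <Quantity in TeV>},
    #         "radius": <Quantity in m>,
    #         "viewcone": <Quantity in deg>,
    #     }
    # where the last three entries are present only when a lookup table was interpolated.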

    def convert_altaz_to_radec(self, alt, az):
        """
        Convert Altitude/Azimuth (AltAz) coordinates to Right Ascension/Declination (RA/Dec).

        Parameters
        ----------
        alt : astropy.units.Quantity
            Altitude angle (angular units, e.g. degrees).
        az : astropy.units.Quantity
            Azimuth angle (angular units, e.g. degrees).

        Returns
        -------
        SkyCoord
            SkyCoord object containing the RA/Dec coordinates.

        Raises
        ------
        ValueError
            If observing_time is not set.
        """
        if self.observing_time is None:
            raise ValueError(
                "Observing time is not set. "
                "Please provide an observing_time to convert coordinates."
            )
        alt_rad = alt.to(u.rad)
        az_rad = az.to(u.rad)
        aa = AltAz(
            alt=alt_rad,
            az=az_rad,
            location=self.observing_location,
            obstime=self.observing_time,
        )
        skycoord = SkyCoord(aa)
        return skycoord.icrs  # Return RA/Dec in ICRS frame

    def convert_coordinates(self, grid_points: list[dict]) -> list[dict]:
        """
        Convert the grid points to RA/Dec coordinates if necessary.

        Parameters
        ----------
        grid_points : list of dict
            List of grid points, each represented as a dictionary with axis names as keys
            and values.

        Returns
        -------
        list of dict
            The grid points with converted RA/Dec coordinates.
        """
        if self.coordinate_system == "ra_dec":
            for point in grid_points:
                if "zenith_angle" in point and "azimuth" in point:
                    alt = (90.0 * u.deg) - point.pop("zenith_angle")
                    az = point.pop("azimuth")
                    radec = self.convert_altaz_to_radec(alt, az)
                    point["ra"] = radec.ra.deg * u.deg
                    point["dec"] = radec.dec.deg * u.deg
        return grid_points

    def serialize_grid_points(self, grid_points, output_file=None):
        """Serialize the grid output and save to a file or print to the console."""
        cleaned_points = []
        for point in grid_points:
            cleaned_point = {}
            for key, value in point.items():
                if isinstance(value, dict):
                    # Nested dictionaries (e.g. energy_threshold limits)
                    cleaned_point[key] = {
                        k: self.serialize_quantity(v) for k, v in value.items()
                    }
                else:
                    cleaned_point[key] = self.serialize_quantity(value)
            cleaned_points.append(cleaned_point)

        output_data = json.dumps(cleaned_points, indent=4)

        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(output_data)
            self._logger.info(f"Output saved to {output_file}")
        else:
            self._logger.info(output_data)
        return output_data
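
    # Each serialized grid point is plain JSON, with Quantities expanded by
    # ``serialize_quantity`` into ``{"value": <number>, "unit": <string>}``, e.g.
    # ``{"zenith_angle": {"value": 20.0, "unit": "deg"}, ...}`` (values illustrative).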

    def serialize_quantity(self, value):
        """Serialize a Quantity into a dictionary of value and unit."""
        if isinstance(value, u.Quantity):
            return {"value": value.value, "unit": str(value.unit)}
        self._logger.warning(f"Unsupported type {type(value)} for serialization. Returning as is.")
        return value
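

# Illustrative usage sketch; the observing time and output path below are assumed
# placeholders, and this block is not part of the module itself. It builds a small grid from
# ``example_axes_definition`` above, converts it to RA/Dec, and serializes it to JSON.
# Passing ``lookup_table`` (an ECSV file) and ``telescope_ids`` would additionally attach
# interpolated energy-threshold, radius, and viewcone limits to each grid point.
if __name__ == "__main__":
    from astropy.time import Time

    grid_gen = GridGeneration(
        axes=example_axes_definition,
        coordinate_system="ra_dec",
        observing_time=Time("2025-01-01T00:00:00"),  # placeholder observing time
    )
    grid_points = grid_gen.convert_coordinates(grid_gen.generate_grid())
    grid_gen.serialize_grid_points(grid_points, output_file="grid_points.json")  # placeholder path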