# Source code for production_configuration.interpolation_handler

"""Handle interpolation between multiple StatisticalUncertaintyEvaluator instances."""

import logging

import astropy.units as u
import numpy as np
from scipy.interpolate import griddata

from simtools.production_configuration.derive_production_statistics import (
    ProductionStatisticsDerivator,
)

# Public API of this module: only the handler class is exported by ``from ... import *``.
__all__ = ["InterpolationHandler"]


class InterpolationHandler:
    """
    Calculate the required events for production via interpolation from a grid.

    This class provides methods to interpolate production statistics across a grid of
    parameter values (azimuth, zenith, NSB, offset) and energy.

    :meth:`interpolate` performs two interpolations:

    * an energy-independent one on the per-evaluator sums of the production
      statistics, over the (azimuth, zenith, nsb, offset) grid, and
    * an energy-dependent one on the per-bin production statistics, over the
      (energy, azimuth, zenith, nsb, offset) grid.

    Before interpolating, :meth:`_remove_flat_dimensions` drops dimensions in which
    all evaluators share (almost) the same value.
    """

    def __init__(self, evaluators, metrics: dict, grid_points_production: list):
        """
        Initialize the InterpolationHandler.

        Parameters
        ----------
        evaluators : list
            List of StatisticalUncertaintyEvaluator instances.
        metrics : dict
            Dictionary of metrics to use for production statistics.
        grid_points_production : list
            List of grid points for interpolation, each being a dictionary with keys
            'azimuth', 'zenith_angle', 'nsb', 'offset' etc. Each of those keys maps
            to a dictionary with a 'value' entry.
        """
        self._logger = logging.getLogger(__name__)
        self.evaluators = evaluators
        self.metrics = metrics
        self.grid_points_production = grid_points_production

        self._initialize_derivators()
        self._extract_grid_properties()
        self.data, self.grid_points = self._build_data_array()

        self.interpolated_production_statistics = None
        self.interpolated_production_statistics_with_energy = None
        # Boolean mask over the (azimuth, zenith, nsb, offset) columns that are kept;
        # set by _prepare_energy_independent_data().
        self._non_flat_mask = None

    def _initialize_derivators(self):
        """
        Create one ProductionStatisticsDerivator per evaluator and derive its statistics.

        Sets ``production_statistics`` (per energy bin, ``return_sum=False``) and
        ``production_statistics_sum`` (``return_sum=True``), both aligned with
        ``self.evaluators``.
        """
        self.derive_production_statistics = [
            ProductionStatisticsDerivator(evaluator, self.metrics)
            for evaluator in self.evaluators
        ]
        self.production_statistics = [
            derivator.derive_statistics(return_sum=False)
            for derivator in self.derive_production_statistics
        ]
        self.production_statistics_sum = [
            derivator.derive_statistics(return_sum=True)
            for derivator in self.derive_production_statistics
        ]

    def _extract_grid_properties(self):
        """
        Extract grid coordinates and energy grids from the evaluators.

        Reads ``grid_point[1]``, ``grid_point[2]`` and ``grid_point[4]`` of each evaluator
        as azimuth, zenith and offset, converted to degrees. ``grid_point[3]`` (NSB) is
        used as-is. The energy grid of each evaluator is the midpoint of
        ``bin_edges_low[:-1]`` and ``bin_edges_high[:-1]``. A warning is logged when the
        energy grids differ between evaluators; only the first one is used as the
        energy axis of the query grid.
        """
        self.azimuths = [e.grid_point[1].to(u.deg).value for e in self.evaluators]
        self.zeniths = [e.grid_point[2].to(u.deg).value for e in self.evaluators]
        self.nsbs = [e.grid_point[3] for e in self.evaluators]
        self.offsets = [e.grid_point[4].to(u.deg).value for e in self.evaluators]
        self.energy_grids = [
            (e.data["bin_edges_low"][:-1] + e.data["bin_edges_high"][:-1]) / 2
            for e in self.evaluators
        ]
        self.energy_thresholds = np.array([e.energy_threshold for e in self.evaluators])

        # Check if energy grids are consistent
        if self.energy_grids and not all(
            np.array_equal(self.energy_grids[0], grid) for grid in self.energy_grids
        ):
            self._logger.warning(
                "Energy grids are not identical across evaluators. "
                "Using the first evaluator's energy grid for interpolation."
            )

    def _build_data_array(self):
        """
        Flatten the per-evaluator, per-energy-bin statistics into one point cloud.

        Each evaluator contributes one row per energy bin with columns
        ``[energy (TeV), azimuth, zenith, nsb, offset]``.

        Returns
        -------
        np.ndarray
            The energy-dependent production statistics, sorted by energy.
        np.ndarray
            The corresponding grid points, in the same order as the data.
        """
        if not self.evaluators:
            return np.array([]), np.array([])

        flat_grid_points = []
        flat_data_list = []
        for energy_grid, production_statistics, az, zen, nsb, offset in zip(
            self.energy_grids,
            self.production_statistics,
            self.azimuths,
            self.zeniths,
            self.nsbs,
            self.offsets,
        ):
            n_bins = len(energy_grid)
            flat_grid_points.append(
                np.column_stack(
                    [
                        energy_grid.to(u.TeV).value,
                        np.full(n_bins, az),
                        np.full(n_bins, zen),
                        np.full(n_bins, nsb),
                        np.full(n_bins, offset),
                    ]
                )
            )
            flat_data_list.append(production_statistics)

        flat_grid_points = np.vstack(flat_grid_points)
        flat_data = np.hstack(flat_data_list)

        sorted_indices = np.argsort(flat_grid_points[:, 0])
        return flat_data[sorted_indices], flat_grid_points[sorted_indices]

    def _remove_flat_dimensions(self, grid_points, threshold=1e-6):
        """
        Identify and remove flat dimensions (dimensions with no variance).

        Parameters
        ----------
        grid_points : np.ndarray
            2-D array of grid points (one row per point) to analyze.
        threshold : float, optional
            Columns whose variance is not above this value are considered flat,
            by default 1e-6.

        Returns
        -------
        tuple
            (reduced_grid_points, non_flat_mask). If every column is flat, all columns
            are kept and the mask is all True.
        """
        if grid_points.size == 0:
            return grid_points, np.array([], dtype=bool)

        variance = np.var(grid_points, axis=0)
        non_flat_mask = variance > threshold
        if not np.any(non_flat_mask):
            self._logger.warning(
                "All dimensions are flat. Keeping all dimensions for interpolation."
            )
            return grid_points, np.ones_like(variance, dtype=bool)

        return grid_points[:, non_flat_mask], non_flat_mask

    def build_grid_points_no_energy(self):
        """
        Build grid points without energy dimension.

        Returns
        -------
        tuple
            (production_statistics, grid_points_no_energy). The first element is a list
            of floats (one summed statistic per evaluator). The second is an
            ``(n_evaluators, 4)`` array of (azimuth, zenith, nsb, offset). Both are empty
            arrays when there are no evaluators.
        """
        if not self.evaluators:
            self._logger.error("No evaluators available for grid point building.")
            return np.array([]), np.array([])

        flat_data_list = []
        flat_grid_points = []
        for production_statistics_sum, az, zen, nsb, offset in zip(
            self.production_statistics_sum,
            self.azimuths,
            self.zeniths,
            self.nsbs,
            self.offsets,
        ):
            flat_data_list.append(float(production_statistics_sum.value))
            flat_grid_points.append(np.array([[az, zen, nsb, offset]]))

        return flat_data_list, np.vstack(flat_grid_points)

    def _prepare_energy_independent_data(self):
        """
        Prepare data for energy-independent interpolation.

        Also stores the flat-dimension mask in ``self._non_flat_mask``, so that the
        production grid points can be reduced the same way.

        Returns
        -------
        tuple
            (production_statistic, grid_points_no_energy) with flat dimensions removed.
        """
        production_statistic, grid_points_no_energy = self.build_grid_points_no_energy()
        production_statistic = np.array(production_statistic, dtype=float)
        grid_points_no_energy, non_flat_mask = self._remove_flat_dimensions(grid_points_no_energy)
        self._non_flat_mask = non_flat_mask  # Store for later use
        return production_statistic, grid_points_no_energy

    def _production_points_array(self):
        """
        Return the production grid points as an ``(n, 4)`` array.

        The columns are (azimuth, zenith, nsb, offset), taken from the ``'value'`` entries
        of ``grid_points_production``. An empty list yields shape ``(0, 4)``.
        """
        rows = [
            [
                point["azimuth"]["value"],
                point["zenith_angle"]["value"],
                point["nsb"]["value"],
                point["offset"]["value"],
            ]
            for point in self.grid_points_production
        ]
        # reshape keeps the array 2-D even when there are no production points.
        return np.array(rows).reshape(-1, 4)

    def _prepare_production_grid_points(self):
        """
        Convert grid_points_production to a format suitable for interpolation.

        Returns
        -------
        np.ndarray
            Production grid points reduced by ``self._non_flat_mask``. If the mask has
            not been computed yet (``None``), all four columns are returned.
        """
        production_grid_points = self._production_points_array()
        if self._non_flat_mask is None:
            return production_grid_points
        return production_grid_points[:, self._non_flat_mask]

    @staticmethod
    def _build_energy_query_grid(energies, production_points):
        """
        Combine energies and production points into a points-major query grid.

        Row ``i * len(energies) + j`` is ``[energies[j], *production_points[i]]``. All
        energies of one production point are therefore contiguous, and reshaping the
        interpolated result to ``(n_points, n_energies)`` gives one spectrum per row.

        Parameters
        ----------
        energies : array-like
            1-D energy values (already unit-stripped).
        production_points : array-like
            2-D array with one production grid point per row.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_points * n_energies, 1 + n_columns)``.
        """
        energies = np.asarray(energies, dtype=float).ravel()
        production_points = np.asarray(production_points, dtype=float)
        return np.column_stack(
            [
                np.tile(energies, len(production_points)),
                np.repeat(production_points, len(energies), axis=0),
            ]
        )

    def _perform_interpolation(self, grid_points, values, query_points, method="linear"):
        """
        Perform interpolation using griddata.

        Parameters
        ----------
        grid_points : np.ndarray
            Grid points for interpolation.
        values : np.ndarray
            Values at the grid points.
        query_points : np.ndarray
            Query points for interpolation.
        method : str, optional
            Interpolation method, by default "linear".

        Returns
        -------
        np.ndarray
            Interpolated values. Points outside the convex hull of ``grid_points`` are
            NaN (``fill_value=np.nan``). Axes are rescaled to unit range before
            interpolating (``rescale=True``).
        """
        self._logger.debug("Grid points shape: %s", grid_points.shape)
        self._logger.debug("Values shape: %s", values.shape)
        self._logger.debug("Query points shape: %s", query_points.shape)

        return griddata(
            grid_points,
            values,
            query_points,
            method=method,
            fill_value=np.nan,
            rescale=True,
        )

    def _perform_interpolation_with_energy(self):
        """
        Perform energy-dependent interpolation.

        The data grid and the query grid are reduced with the same flat-dimension mask,
        so their dimensionality always matches.

        Returns
        -------
        np.ndarray
            Array of shape ``(1, n_production_points, n_energy_bins)``. Entry ``[0, i]``
            is the energy-dependent statistic interpolated at production point ``i``,
            on the first evaluator's energy grid. An empty array is returned when there
            are no evaluators.
        """
        if not self.energy_grids:
            self._logger.error("No evaluators available for energy-dependent interpolation.")
            return np.array([])

        # Data grid columns: [energy, azimuth, zenith, nsb, offset].
        grid_points_energy, energy_mask = self._remove_flat_dimensions(self.grid_points)

        production_points = self._production_points_array()
        energies = np.atleast_1d(self.energy_grids[0].to(u.TeV).value)

        # Build the full 5-column query in the same column order as the data grid, then
        # drop exactly the columns that were dropped from the data grid.
        energy_query_grid = self._build_energy_query_grid(energies, production_points)
        energy_query_grid = energy_query_grid[:, energy_mask]

        self._logger.debug("Grid points with energy shape: %s", grid_points_energy.shape)
        self._logger.debug("Data shape: %s", self.data.shape)
        self._logger.debug("Energy query grid shape: %s", energy_query_grid.shape)

        interpolated_values = self._perform_interpolation(
            grid_points_energy, self.data, energy_query_grid
        )
        # The query grid is points-major, so each row is one production point's spectrum.
        reshaped = interpolated_values.reshape(len(production_points), len(energies))
        return np.array([reshaped])

    def interpolate(self) -> np.ndarray:
        """
        Interpolate production statistics at the grid points specified in grid_points_production.

        This method performs two types of interpolation:

        1. Energy-independent interpolation using the sum of production statistics.
           The result is stored in ``interpolated_production_statistics``.
        2. Energy-dependent interpolation for each energy bin. The result is stored in
           ``interpolated_production_statistics_with_energy``.

        Returns
        -------
        np.ndarray
            Energy-independent interpolated values at the query points. An empty array
            is returned when there are no evaluators.
        """
        if not self.evaluators:
            self._logger.error("No evaluators available for interpolation.")
            return np.array([])

        # Energy-independent interpolation
        production_statistic, grid_points_no_energy = self._prepare_energy_independent_data()
        reduced_production_grid_points = self._prepare_production_grid_points()
        self.interpolated_production_statistics = self._perform_interpolation(
            grid_points_no_energy, production_statistic, reduced_production_grid_points
        )

        # Energy-dependent interpolation
        self.interpolated_production_statistics_with_energy = (
            self._perform_interpolation_with_energy()
        )

        return self.interpolated_production_statistics

    def plot_comparison(self, grid_point_index=0):
        """
        Plot a comparison between interpolated production statistics and reconstructed events.

        Parameters
        ----------
        grid_point_index : int, optional
            Index of the production grid point to plot, by default 0. An out-of-range
            index falls back to 0 with a warning.

        Returns
        -------
        matplotlib.axes.Axes
            The Axes object containing the plot.
        """
        import matplotlib.pyplot as plt  # pylint: disable=C0415

        if not self.evaluators:
            self._logger.error("No evaluators available for plotting.")
            _, ax = plt.subplots()
            ax.text(0.5, 0.5, "No data available", ha="center", va="center")
            return ax

        # Use first evaluator for energy bins
        first_data = self.evaluators[0].data
        midpoints = (first_data["bin_edges_low"][:-1] + first_data["bin_edges_high"][:-1]) / 2

        interpolated = self.interpolated_production_statistics_with_energy
        n_points = (
            len(interpolated[0]) if interpolated is not None and len(interpolated) > 0 else 0
        )
        if n_points == 0:
            # Nothing to plot from the interpolation (interpolate() not run or no points).
            self._logger.warning(
                "No energy-dependent interpolated statistics available; "
                "call interpolate() before plotting."
            )
        elif not -n_points <= grid_point_index < n_points:
            self._logger.warning(
                "Invalid grid point index %s. Using index 0 instead.", grid_point_index
            )
            grid_point_index = 0

        _, ax = plt.subplots()

        if n_points > 0:
            ax.plot(
                midpoints,
                interpolated[0][grid_point_index],
                label="Interpolated Production Statistics",
            )

        reconstructed_event_histogram, _ = np.histogram(
            first_data["event_energies_reco"],
            bins=first_data["bin_edges_low"],
        )
        ax.plot(midpoints, reconstructed_event_histogram, label="Reconstructed Events")

        ax.legend()
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel("Energy (TeV)")
        ax.set_ylabel("Event Count")
        ax.set_title("Comparison of Interpolated and Reconstructed Events")

        return ax