# Source code for simtools.production_configuration.interpolation_handler

"""Interpolates between instances of StatisticalErrorEvaluator using EventScaler."""

import astropy.units as u
import numpy as np
from scipy.interpolate import griddata

from simtools.production_configuration.event_scaler import EventScaler

# Public names exported by ``from <module> import *``.
__all__ = ["InterpolationHandler"]


class InterpolationHandler:
    """Handle interpolation between multiple StatisticalErrorEvaluator instances.

    Every evaluator contributes one point in (azimuth, zenith, NSB, offset) plus
    a set of energy bins. The scaled event counts of all evaluators are flattened
    into one scattered data set over the 5-D space
    (energy [TeV], azimuth [deg], zenith [deg], nsb, offset [deg]), which is then
    interpolated with :func:`scipy.interpolate.griddata`.

    Parameters
    ----------
    evaluators : list
        StatisticalErrorEvaluator instances. Each must expose ``grid_point``
        (entries 1, 2 and 4 are angle quantities, entry 3 is the NSB value),
        ``data`` with energy quantities under ``"bin_edges_low"`` and
        ``"bin_edges_high"``, and ``energy_threshold``.
    metrics : dict
        Metrics forwarded to every EventScaler.
    """

    def __init__(self, evaluators, metrics: dict):
        self.evaluators = evaluators
        self.metrics = metrics
        self.event_scalers = [EventScaler(e, self.metrics) for e in self.evaluators]

        # Per-evaluator grid coordinates; angles are stored in degrees.
        self.azimuths = [e.grid_point[1].to(u.deg).value for e in self.evaluators]
        self.zeniths = [e.grid_point[2].to(u.deg).value for e in self.evaluators]
        self.nsbs = [e.grid_point[3] for e in self.evaluators]
        self.offsets = [e.grid_point[4].to(u.deg).value for e in self.evaluators]

        # Bin centres, excluding the last bin of each evaluator.
        # NOTE(review): presumably chosen to match the length returned by
        # scale_events(return_sum=False) — confirm.
        self.energy_grids = [
            (e.data["bin_edges_low"][:-1] + e.data["bin_edges_high"][:-1]) / 2
            for e in self.evaluators
        ]
        self.scaled_events = [
            scaler.scale_events(return_sum=False) for scaler in self.event_scalers
        ]
        self.energy_thresholds = np.array([e.energy_threshold for e in self.evaluators])
        self.data, self.grid_points = self._build_data_array()

    @staticmethod
    def _evaluator_coordinates(evaluator):
        """Return ``(azimuth, zenith, nsb, offset)`` of an evaluator, angles in degrees."""
        grid_point = evaluator.grid_point
        return (
            grid_point[1].to(u.deg).value,
            grid_point[2].to(u.deg).value,
            grid_point[3],
            grid_point[4].to(u.deg).value,
        )

    def _build_data_array(self):
        """
        Flatten the scaled events of all evaluators into scattered data points.

        Each energy bin of each evaluator becomes one row
        ``(energy [TeV], azimuth [deg], zenith [deg], nsb, offset [deg])``.
        Rows are sorted by energy.

        Returns
        -------
        np.ndarray
            Scaled event values, one per row of the grid points.
        np.ndarray
            The corresponding grid points, shape (n, 5).
        """
        flat_grid_points = []
        flat_data_list = []
        for evaluator, energy_grid, scaled_events in zip(
            self.evaluators, self.energy_grids, self.scaled_events
        ):
            n_bins = len(energy_grid)
            az, zen, nsb, offset = self._evaluator_coordinates(evaluator)
            flat_grid_points.append(
                np.column_stack(
                    [
                        energy_grid.to(u.TeV).value,
                        np.full(n_bins, az),
                        np.full(n_bins, zen),
                        np.full(n_bins, nsb),
                        np.full(n_bins, offset),
                    ]
                )
            )
            flat_data_list.append(scaled_events)

        flat_grid_points = np.vstack(flat_grid_points)
        flat_data = np.hstack(flat_data_list)

        # Sort the grid points and corresponding data by energy.
        sorted_indices = np.argsort(flat_grid_points[:, 0])
        return flat_data[sorted_indices], flat_grid_points[sorted_indices]

    def _remove_flat_dimensions(self, grid_points):
        """
        Identify and remove flat dimensions (columns with no variance).

        Parameters
        ----------
        grid_points : np.ndarray
            Array of shape (n, d).

        Returns
        -------
        np.ndarray
            ``grid_points`` restricted to the non-flat columns.
        np.ndarray
            Boolean mask of length d, True for the kept columns.
        """
        variance = np.var(grid_points, axis=0)
        non_flat_mask = variance > 1e-6  # Threshold for determining flatness
        return grid_points[:, non_flat_mask], non_flat_mask

    def interpolate(self, query_points: np.ndarray) -> np.ndarray:
        """
        Interpolate the number of simulated events given query points.

        Parameters
        ----------
        query_points : np.ndarray
            Array of query points with shape (n, 5), where n is the number of
            points, and 5 represents (energy, azimuth, zenith, nsb, offset).
            A single point may also be given as a 1-D array of length 5.

        Returns
        -------
        np.ndarray
            Interpolated values at the query points (NaN outside the convex hull
            of the grid points).
        """
        # Promote a single 1-D point to shape (1, 5) so column masking works.
        query_points = np.atleast_2d(query_points)
        reduced_grid_points, non_flat_mask = self._remove_flat_dimensions(self.grid_points)
        reduced_query_points = query_points[:, non_flat_mask]

        # Interpolate using the reduced dimensions; griddata cannot triangulate
        # along dimensions with zero extent.
        return griddata(
            reduced_grid_points,
            self.data,
            reduced_query_points,
            method="linear",
            fill_value=np.nan,
            rescale=True,
        )

    def interpolate_energy_threshold(self, query_point: np.ndarray) -> float:
        """
        Interpolate the energy threshold for a given grid point.

        Only the azimuth, zenith, NSB and offset coordinates are used; the energy
        coordinate of ``query_point`` is ignored.

        Parameters
        ----------
        query_point : np.ndarray
            Grid point (energy, azimuth, zenith, NSB, offset), given either as a
            1-D array of length 5 or as a 2-D array whose first row is used.

        Returns
        -------
        float
            Interpolated energy threshold (NaN outside the evaluator grid). If all
            evaluators share the same (azimuth, zenith, NSB, offset), the mean of
            their thresholds is returned.
        """
        # Accept a plain 1-D point as well as the (1, 5)-shaped input.
        query_point = np.atleast_2d(query_point)

        flat_grid_points = np.array(
            [self._evaluator_coordinates(e) for e in self.evaluators]
        )
        flat_energy_thresholds = np.array([e.energy_threshold for e in self.evaluators])

        reduced_grid_points, non_flat_mask = self._remove_flat_dimensions(flat_grid_points)
        if not non_flat_mask.any():
            # Nothing varies (e.g. a single evaluator): griddata cannot work in
            # zero dimensions, so return the common threshold directly.
            return np.mean(flat_energy_thresholds).item()

        # Prepend False so that the energy column of the query point is skipped.
        full_non_flat_mask = np.concatenate(([False], non_flat_mask))
        reduced_query_point = query_point[0][full_non_flat_mask]

        interpolated_threshold = griddata(
            reduced_grid_points,
            flat_energy_thresholds,
            reduced_query_point,
            method="linear",
            fill_value=np.nan,
            # NOTE(review): interpolate() uses rescale=True; kept False here to
            # preserve existing results.
            rescale=False,
        )
        return interpolated_threshold.item()

    def plot_comparison(self, evaluator):
        """
        Plot a comparison between the scaled, interpolated, and reconstructed events.

        The handler's own grid points are left untouched.

        Parameters
        ----------
        evaluator : StatisticalErrorEvaluator
            The evaluator for which to plot the comparison.
        """
        import matplotlib.pyplot as plt  # pylint: disable=import-outside-toplevel

        midpoints = 0.5 * (evaluator.data["bin_edges_high"] + evaluator.data["bin_edges_low"])
        az, zen, nsb, offset = self._evaluator_coordinates(evaluator)
        n_points = len(midpoints)
        # Query in the same plain-number units as self.grid_points (TeV, deg),
        # built locally so the interpolation grid is not overwritten.
        query_points = np.column_stack(
            [
                midpoints.to(u.TeV).value,
                np.full(n_points, az),
                np.full(n_points, zen),
                np.full(n_points, nsb),
                np.full(n_points, offset),
            ]
        )
        interpolated_events = self.interpolate(query_points)

        plt.plot(midpoints, evaluator.scaled_events, label="Scaled")
        plt.plot(midpoints, interpolated_events, label="Interpolated")

        reconstructed_event_histogram, _ = np.histogram(
            evaluator.data["event_energies_reco"], bins=evaluator.data["bin_edges_low"]
        )
        plt.plot(midpoints[:-1], reconstructed_event_histogram, label="Reconstructed")

        plt.legend()
        plt.xscale("log")
        plt.xlabel("Energy (Midpoint of Bin Edges)")
        plt.ylabel("Event Count")
        plt.title("Comparison of Simulated, scaled, and reconstructed events")
        plt.show()