Source code for production_configuration.derive_production_statistics_handler
"""
Derives the required statistics for a requested set of production parameters through interpolation.
This module provides the `ProductionStatisticsHandler` class, which manages the workflow for
deriving the required number of events for a simulation production using pre-defined metrics.
The module includes functionality to:
- Initialize evaluators for statistical uncertainty calculations based on input parameters.
- Perform interpolation using the initialized evaluators to estimate production statistics at a
query point.
- Write the results of the interpolation to an output file.
"""
import itertools
import json
import logging
from pathlib import Path
import astropy.units as u
import numpy as np
from simtools.production_configuration.calculate_statistical_uncertainties_grid_point import (
StatisticalUncertaintyEvaluator,
)
from simtools.production_configuration.interpolation_handler import InterpolationHandler
from simtools.utils.general import collect_data_from_file
[docs]
class ProductionStatisticsHandler:
    """
    Handles the workflow for deriving production statistics.

    This class manages the evaluation of statistical uncertainties from DL2 MC event files
    and performs interpolation to estimate the required number of events for a simulation
    production at a specified query point.
    """

    def __init__(self, args_dict):
        """
        Initialize the manager with the provided arguments.

        Parameters
        ----------
        args_dict : dict
            Dictionary of command-line arguments. The keys ``output_file`` and
            ``metrics_file`` are read unconditionally; ``output_path`` is optional and
            falls back to the current directory when missing or empty/None.
        """
        self.args = args_dict
        self.logger = logging.getLogger(__name__)
        # The key may be present with a None value (e.g. an unset command-line option);
        # ``Path(None)`` raises TypeError, so treat None like a missing key.
        self.output_path = Path(self.args.get("output_path") or ".")
        self.output_filepath = self.output_path.joinpath(str(self.args["output_file"]))
        self.metrics = collect_data_from_file(self.args["metrics_file"])
        # Filled by initialize_evaluators().
        self.evaluator_instances = []
        # Stays None until something assigns it; plot_production_statistics_comparison()
        # relies on it being set.
        self.interpolation_handler = None

    def initialize_evaluators(self):
        """
        Initialize StatisticalUncertaintyEvaluator instances for the given zeniths/offsets.

        One evaluator is created for every (zenith, camera offset) combination whose input
        file exists. The file name is built from ``args['file_name_template']`` using the
        integer zenith value only. Each evaluator has ``calculate_metrics()`` called on it
        and is appended to ``self.evaluator_instances``.

        If ``base_path``, ``zeniths`` or ``camera_offsets`` is missing or empty, warnings
        are logged and nothing is read.
        """
        # Use .get() so that an absent key is reported via the warnings below instead of
        # surfacing as a KeyError.
        base_path = self.args.get("base_path")
        zeniths = self.args.get("zeniths")
        camera_offsets = self.args.get("camera_offsets")

        if not (base_path and zeniths and camera_offsets):
            self.logger.warning("No files read")
            self.logger.warning(f"Base Path: {base_path}")
            self.logger.warning(f"Zeniths: {zeniths}")
            self.logger.warning(f"Camera offsets: {camera_offsets}")
            return

        base_dir = Path(base_path)
        for zenith, offset in itertools.product(zeniths, camera_offsets):
            file_name = self.args["file_name_template"].format(zenith=int(zenith))
            file_path = base_dir.joinpath(file_name)

            if not file_path.exists():
                self.logger.warning(f"File not found: {file_path}. Skipping.")
                continue

            evaluator = StatisticalUncertaintyEvaluator(
                file_path,
                metrics=self.metrics,
                # Five-element grid point: zenith at index 2, camera offset (converted to
                # degrees) at index 4; the remaining entries are left as None.
                grid_point=(None, None, zenith, None, offset * u.deg),
            )
            evaluator.calculate_metrics()
            self.evaluator_instances.append(evaluator)

    def write_output(self, production_statistics):
        """
        Write the derived event statistics to a file.

        The output is a JSON document with the keys ``query_point`` (taken from the
        arguments) and ``production_statistics``. Parent directories of the output file
        are created if needed.

        Parameters
        ----------
        production_statistics : numpy.ndarray or list or scalar
            Derived statistics. Objects providing ``tolist()`` are converted with it;
            anything else is handed to the JSON encoder unchanged.
        """
        statistics = production_statistics
        if hasattr(statistics, "tolist"):
            statistics = statistics.tolist()

        output_data = {
            "query_point": self.args["query_point"],
            "production_statistics": statistics,
        }
        self.output_filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(self.output_filepath, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=4)
        self.logger.info(f"Output saved to {self.output_filepath}")
        self.logger.info(
            f"production statistics for grid point "
            f"{self.args['query_point']}: {production_statistics}"
        )

    def plot_production_statistics_comparison(self):
        """
        Plot the derived event statistics.

        Saves the figure returned by the interpolation handler's ``plot_comparison()`` as
        ``production_statistics_comparison.png`` in the output directory, creating the
        directory if needed. ``self.interpolation_handler`` must have been set
        beforehand; it is None directly after construction.
        """
        ax = self.interpolation_handler.plot_comparison()
        plot_path = self.output_path.joinpath("production_statistics_comparison.png")
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        ax.figure.savefig(plot_path)
        self.logger.info(f"Plot saved to {plot_path}")

    def run(self):
        """
        Run the scaling and interpolation workflow.

        Logs the main inputs, initializes the evaluators, performs the interpolation,
        optionally plots the comparison (when ``plot_production_statistics`` is set in the
        arguments) and writes the result to the output file.
        """
        self.logger.info(f"Zeniths: {self.args['zeniths']}")
        self.logger.info(f"Camera offsets: {self.args['camera_offsets']}")
        self.logger.info(f"Query Point: {self.args['query_point']}")
        self.logger.info(f"Metrics File: {self.args['metrics_file']}")

        self.initialize_evaluators()
        # NOTE(review): perform_interpolation() is not defined in the visible part of
        # this class - confirm it is provided elsewhere, otherwise this call raises
        # AttributeError.
        production_statistics = self.perform_interpolation()
        if self.args.get("plot_production_statistics"):
            self.plot_production_statistics_comparison()
        self.write_output(production_statistics)