"""Model data writer module."""
import json
import logging
from pathlib import Path
import astropy.units as u
import numpy as np
import yaml
from astropy.io.registry.base import IORegistryError
import simtools.utils.general as gen
from simtools.constants import MODEL_PARAMETER_SCHEMA_PATH
from simtools.data_model import validate_data
from simtools.data_model.metadata_collector import MetadataCollector
from simtools.io_operations import io_handler
from simtools.utils import names, value_conversion
# Public API of this module: only the writer class is exported.
__all__ = ["ModelDataWriter"]
class JsonNumpyEncoder(json.JSONEncoder):
    """
    JSON encoder for numpy and astropy types.

    Objects the standard encoder cannot serialize (numpy scalars and arrays,
    astropy units, numpy booleans) are converted to plain Python values;
    everything else is delegated to ``json.JSONEncoder.default``.
    """

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``."""
        # numpy floats/ints become Python numbers, arrays become (nested) lists.
        numpy_conversions = (
            (np.floating, float),
            (np.integer, int),
            (np.ndarray, lambda array: array.tolist()),
        )
        for numpy_type, convert in numpy_conversions:
            if isinstance(o, numpy_type):
                return convert(o)

        unit_types = (u.core.CompositeUnit, u.core.IrreducibleUnit, u.core.Unit)
        if isinstance(o, unit_types):
            # Dimensionless units are written as JSON null.
            if o != u.dimensionless_unscaled:
                return str(o)
            return None

        if np.issubdtype(type(o), np.bool_):
            return bool(o)

        return super().default(o)
[docs]
class ModelDataWriter:
"""
Writer for simulation model data and metadata.
Parameters
----------
product_data_file: str
Name of output file.
product_data_format: str
Format of output file ("ecsv" is translated to the astropy name "ascii.ecsv").
output_path: str or Path
Path to output file.
use_plain_output_path: bool
Use plain output path.
args_dict: dict
Dictionary with configuration parameters. Its "output_path" and
"use_plain_output_path" entries override the corresponding keyword arguments.
"""
def __init__(
    self,
    product_data_file=None,
    product_data_format=None,
    output_path=None,
    use_plain_output_path=True,
    args_dict=None,
):
    """
    Initialize model data writer.

    Sets up the IO handler (output paths), resolves the output file name and
    normalizes the data format name.
    """
    self._logger = logging.getLogger(__name__)
    self.io_handler = io_handler.IOHandler()
    self.schema_dict = {}

    # Entries in the configuration dictionary take precedence over keyword arguments.
    if args_dict is not None:
        output_path = args_dict.get("output_path", output_path)
        use_plain_output_path = args_dict.get("use_plain_output_path", use_plain_output_path)

    if output_path is not None:
        self.io_handler.set_paths(
            output_path=output_path, use_plain_output_path=use_plain_output_path
        )

    try:
        resolved_file = self.io_handler.get_output_file(file_name=product_data_file)
    except TypeError:
        # NOTE(review): presumably raised when no file name is given - confirm in IOHandler.
        resolved_file = None
    self.product_data_file = resolved_file

    self.product_data_format = self._astropy_data_format(product_data_format)
[docs]
@staticmethod
def dump(
    args_dict, output_file=None, metadata=None, product_data=None, validate_schema_file=None
):
    """
    Write model data and metadata (as static method).

    A writer is configured from ``args_dict``. The data are optionally validated
    against a schema file and then written together with the metadata.

    Parameters
    ----------
    args_dict: dict
        Dictionary with configuration parameters (including output file name and path).
        The keys "output_file", "output_file_format" (default "ascii.ecsv") and
        "skip_output_validation" (default True) are read here.
    output_file: string or Path
        Name of output file (args_dict["output_file"] is used if this parameter is not set).
    metadata: dict
        Metadata to be written.
    product_data: astropy Table
        Model data to be written
    validate_schema_file: str
        Schema file used in validation of output data. Validation only runs if this
        is given and args_dict["skip_output_validation"] is set to a false value.
    """
    if output_file is None:
        output_file = args_dict.get("output_file", None)

    writer = ModelDataWriter(
        product_data_file=output_file,
        product_data_format=args_dict.get("output_file_format", "ascii.ecsv"),
        args_dict=args_dict,
    )

    run_validation = validate_schema_file is not None and not args_dict.get(
        "skip_output_validation", True
    )
    if run_validation:
        product_data = writer.validate_and_transform(
            product_data_table=product_data,
            validate_schema_file=validate_schema_file,
        )

    writer.write(metadata=metadata, product_data=product_data)
[docs]
@staticmethod
def dump_model_parameter(
    parameter_name,
    value,
    instrument,
    model_version,
    output_file,
    output_path=None,
    use_plain_output_path=False,
    metadata_input_dict=None,
):
    """
    Generate DB-style model parameter dict and write it to json file.

    If ``metadata_input_dict`` is given, a metadata file is written as well.

    Parameters
    ----------
    parameter_name: str
        Name of the parameter.
    value: any
        Value of the parameter.
    instrument: str
        Name of the instrument.
    model_version: str
        Version of the model.
    output_file: str
        Name of output file.
    output_path: str or Path
        Path to output file. If None, the metadata file is written to the directory
        the IO handler resolved for the json output file.
    use_plain_output_path: bool
        Use plain output path.
    metadata_input_dict: dict
        Input to metadata collector. The caller's dictionary is not modified; a copy
        extended by "output_file" and "output_file_format" is used.

    Returns
    -------
    dict
        Validated parameter dictionary (the object passed to the json writer).
    """
    writer = ModelDataWriter(
        product_data_file=output_file,
        product_data_format="json",
        args_dict=None,
        output_path=output_path,
        use_plain_output_path=use_plain_output_path,
    )
    _json_dict = writer.get_validated_parameter_dict(
        parameter_name, value, instrument, model_version
    )
    writer.write_dict_to_model_parameter_json(output_file, _json_dict)

    if metadata_input_dict is not None:
        # Work on a copy so that the caller's dictionary is left untouched.
        metadata_args = dict(metadata_input_dict)
        metadata_args["output_file"] = output_file
        metadata_args["output_file_format"] = Path(output_file).suffix.lstrip(".")
        # output_path may be a str or None; a bare str/None does not support '/'.
        if output_path is not None:
            metadata_dir = Path(output_path)
        else:
            # Same resolution (get_output_file(output_file)) as used for the json file.
            metadata_dir = Path(writer.product_data_file).parent
        writer.write_metadata_to_yml(
            metadata=MetadataCollector(args_dict=metadata_args).get_top_level_metadata(),
            yml_file=metadata_dir / f"{Path(output_file).stem}",
        )
    return _json_dict
[docs]
def get_validated_parameter_dict(self, parameter_name, value, instrument, model_version):
    """
    Get validated parameter dictionary.

    Reads the parameter schema, assembles a DB-style model parameter
    dictionary and passes it to ``validate_and_transform``.

    Parameters
    ----------
    parameter_name: str
        Name of the parameter (also selects the schema file).
    value: any
        Value of the parameter; value and unit are separated with
        ``value_conversion.split_value_and_unit``.
    instrument: str
        Name of the instrument (a site name or an array element name).
    model_version: str
        Version of the model.

    Returns
    -------
    dict
        Validated parameter dictionary.
    """
    self._logger.debug(f"Getting validated parameter dictionary for {instrument}")
    schema_file = self._read_model_parameter_schema(parameter_name)

    # The instrument is either a site (e.g. 'North') or an array element
    # (e.g. 'LSTN-01'); for the latter the site is derived from the name.
    try:
        site = names.validate_site_name(instrument)
    except ValueError:
        site = names.get_site_from_array_element_name(instrument)

    try:
        applicable = self._get_parameter_applicability(instrument)
    except ValueError:
        applicable = True  # Default to True (expect that this field goes in future)

    parameter_value, parameter_unit = value_conversion.split_value_and_unit(value)
    parameter_type = self._get_parameter_type()
    is_file = self._parameter_is_a_file()

    # Key order is kept as is: the json writer does not sort keys.
    parameter_dict = {
        "parameter": parameter_name,
        "instrument": instrument,
        "site": site,
        "version": model_version,
        "value": parameter_value,
        "unit": parameter_unit,
        "type": parameter_type,
        "applicable": applicable,
        "file": is_file,
    }
    return self.validate_and_transform(
        product_data_dict=parameter_dict,
        validate_schema_file=schema_file,
        is_model_parameter=True,
    )
def _read_model_parameter_schema(self, parameter_name):
    """
    Read model parameter schema and store it in ``self.schema_dict``.

    Parameters
    ----------
    parameter_name: str
        Name of the parameter.

    Returns
    -------
    Path
        Schema file ("<parameter_name>.schema.yml" in MODEL_PARAMETER_SCHEMA_PATH).

    Raises
    ------
    FileNotFoundError
        If the schema file does not exist.
    """
    schema_file = MODEL_PARAMETER_SCHEMA_PATH / f"{parameter_name}.schema.yml"
    try:
        schema = gen.collect_data_from_file(file_name=schema_file)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Schema file not found: {schema_file}") from exc
    self.schema_dict = schema
    return schema_file
def _get_parameter_type(self):
"""
Return parameter type from schema.
Returns
-------
str
Parameter type
"""
_parameter_type = []
for data in self.schema_dict["data"]:
_parameter_type.append(data["type"])
return _parameter_type if len(_parameter_type) > 1 else _parameter_type[0]
def _parameter_is_a_file(self):
"""
Check if parameter is a file.
Returns
-------
bool
True if parameter is a file.
"""
try:
return self.schema_dict["data"][0]["type"] == "file"
except (KeyError, IndexError):
pass
return False
def _get_parameter_applicability(self, telescope_name):
"""
Check if a parameter is applicable for a given telescope using schema files.
First check for exact telescope name (e.g., LSTN-01), if not listed in the schema
use telescope type (LSTN).
Parameters
----------
telescope_name: str
Telescope name (e.g., LSTN-01)
Returns
-------
bool
True if parameter is applicable to telescope.
"""
try:
if telescope_name in self.schema_dict["instrument"]["type"]:
return True
except KeyError as exc:
self._logger.error("Schema file does not contain 'instrument:type' key.")
raise exc
return (
names.get_array_element_type_from_name(telescope_name)
in self.schema_dict["instrument"]["type"]
)
def _get_unit_from_schema(self):
"""
Return unit(s) from schema dict.
Returns
-------
str or list
Parameter unit(s)
"""
try:
unit_list = []
for data in self.schema_dict["data"]:
unit_list.append(data["unit"] if data["unit"] != "dimensionless" else None)
return unit_list if len(unit_list) > 1 else unit_list[0]
except (KeyError, IndexError):
pass
return None
[docs]
def write(self, product_data=None, metadata=None):
"""
Write model data and metadata.
Parameters
----------
product_data: astropy Table
Model data to be written
metadata: dict
Metadata to be written.
Raises
------
FileNotFoundError
if data writing was not successful.
"""
if product_data is None:
return
if metadata is not None:
product_data.meta.update(gen.change_dict_keys_case(metadata, False))
self._logger.info(f"Writing data to {self.product_data_file}")
if isinstance(product_data, dict) and Path(self.product_data_file).suffix == ".json":
self.write_dict_to_model_parameter_json(self.product_data_file, product_data)
return
try:
product_data.write(
self.product_data_file, format=self.product_data_format, overwrite=True
)
except IORegistryError:
self._logger.error(f"Error writing model data to {self.product_data_file}.")
raise
[docs]
def write_dict_to_model_parameter_json(self, file_name, data_dict):
    """
    Write dictionary to model-parameter-style json file.

    The dictionary is first passed through ``prepare_data_dict_for_writing``
    (which modifies it in place). It is then written with ``JsonNumpyEncoder``
    (4-space indent, keys unsorted), followed by a trailing newline.

    Parameters
    ----------
    file_name : str
        Name of output file (resolved to the full output path by the IO handler).
    data_dict : dict
        Data dictionary.

    Raises
    ------
    FileNotFoundError
        if data writing was not successful.
    """
    data_dict = ModelDataWriter.prepare_data_dict_for_writing(data_dict)
    # Resolve the output path once and reuse it for logging, writing and errors.
    output_file = self.io_handler.get_output_file(file_name)
    self._logger.info(f"Writing data to {output_file}")
    try:
        with open(output_file, "w", encoding="UTF-8") as file:
            json.dump(data_dict, file, indent=4, sort_keys=False, cls=JsonNumpyEncoder)
            file.write("\n")
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Error writing model data to {output_file}") from exc
[docs]
@staticmethod
def prepare_data_dict_for_writing(data_dict):
"""
Prepare data dictionary for writing to json file.
Ensure sim_telarray style lists as strings.
Replace "None" with "null" for unit field.
Parameters
----------
data_dict: dict
Dictionary with lists.
Returns
-------
dict
Dictionary with lists converted to strings.
"""
try:
data_dict["value"] = gen.convert_list_to_string(data_dict["value"])
data_dict["unit"] = gen.convert_list_to_string(data_dict["unit"], comma_separated=True)
data_dict["type"] = gen.convert_list_to_string(
data_dict["type"], comma_separated=True, collapse_list=True
)
if isinstance(data_dict["unit"], str):
data_dict["unit"] = data_dict["unit"].replace("None", "null")
except KeyError:
pass
return data_dict
@staticmethod
def _astropy_data_format(product_data_format):
"""
Ensure conformance with astropy data format naming.
Parameters
----------
product_data_format: string
format identifier
"""
if product_data_format == "ecsv":
product_data_format = "ascii.ecsv"
return product_data_format