"""Model data writer module."""
import json
import logging
from pathlib import Path
import astropy.units as u
import numpy as np
import yaml
from astropy.io.registry.base import IORegistryError
import simtools.utils.general as gen
from simtools.data_model import validate_data
from simtools.io_operations import io_handler
__all__ = ["ModelDataWriter"]
class JsonNumpyEncoder(json.JSONEncoder):
    """
    JSON encoder for numpy values and astropy units.

    Plugs into ``json.dump`` / ``json.dumps`` through the ``cls`` argument and
    converts objects the standard encoder rejects into plain python types.
    """

    def default(self, o):
        """
        Return a JSON-serializable replacement for ``o``.

        Parameters
        ----------
        o: object
            Object the standard JSON encoder could not serialize.

        Returns
        -------
        float, int, list, str, bool or None
            Python float/int for numpy scalars, (nested) list for numpy arrays,
            unit string for astropy units (None for the dimensionless unit),
            python bool for numpy booleans.

        Raises
        ------
        TypeError
            Raised by the base class for any type not handled here.
        """
        # numpy scalars and arrays: first matching entry wins.
        numpy_conversions = (
            (np.floating, float),
            (np.integer, int),
            (np.ndarray, lambda array: array.tolist()),
        )
        for numpy_type, to_python in numpy_conversions:
            if isinstance(o, numpy_type):
                return to_python(o)

        unit_types = (u.core.CompositeUnit, u.core.IrreducibleUnit, u.core.Unit)
        if isinstance(o, unit_types):
            # The dimensionless unit is written as null rather than as a string.
            if o != u.dimensionless_unscaled:
                return str(o)
            return None

        if np.issubdtype(type(o), np.bool_):
            return bool(o)

        # Anything else: defer to the base class.
        return super().default(o)
[docs]
class ModelDataWriter:
    """
    Writer for simulation model data and metadata.

    Parameters
    ----------
    product_data_file: str
        Name of output file.
    product_data_format: str
        Format of output file.
    args_dict: Dictionary
        Dictionary with configuration parameters.
    """

    def __init__(self, product_data_file=None, product_data_format=None, args_dict=None):
        """Initialize model data writer."""
        self._logger = logging.getLogger(__name__)
        self.io_handler = io_handler.IOHandler()
        if args_dict is not None:
            self.io_handler.set_paths(
                output_path=args_dict.get("output_path", None),
                use_plain_output_path=args_dict.get("use_plain_output_path", False),
            )
        try:
            self.product_data_file = self.io_handler.get_output_file(
                file_name=product_data_file, dir_type="simtools-result"
            )
        except TypeError:
            # A TypeError from the io handler leaves the writer without an output file
            # (presumably raised when no file name was given - TODO confirm).
            self.product_data_file = None
        self.product_data_format = self._astropy_data_format(product_data_format)

    @staticmethod
    def dump(
        args_dict, output_file=None, metadata=None, product_data=None, validate_schema_file=None
    ):
        """
        Write model data and metadata (as static method).

        Parameters
        ----------
        args_dict: dict
            Dictionary with configuration parameters (including output file name and path).
        output_file: string or Path
            Name of output file (args["output_file"] is used if this parameter is not set).
        metadata: dict
            Metadata to be written.
        product_data: astropy Table
            Model data to be written
        validate_schema_file: str
            Schema file used in validation of output data.
        """
        writer = ModelDataWriter(
            product_data_file=(
                args_dict.get("output_file", None) if output_file is None else output_file
            ),
            product_data_format=args_dict.get("output_file_format", "ascii.ecsv"),
            args_dict=args_dict,
        )
        # Validation runs only if a schema file is given AND "skip_output_validation"
        # is explicitly set to a false value; a missing key means validation is skipped.
        if validate_schema_file is not None and not args_dict.get("skip_output_validation", True):
            # NOTE(review): validate_and_transform is not defined in the part of this
            # class visible here - confirm it exists before relying on this branch.
            product_data = writer.validate_and_transform(
                product_data=product_data,
                validate_schema_file=validate_schema_file,
            )
        writer.write(metadata=metadata, product_data=product_data)

    def write(self, product_data=None, metadata=None):
        """
        Write model data and metadata.

        Nothing is written if ``product_data`` is None. Dictionaries are written as
        model-parameter-style json when the output file has a ``.json`` suffix; any
        other product is written through its own ``write`` method.

        Parameters
        ----------
        product_data: astropy Table or dict
            Model data to be written
        metadata: dict
            Metadata to be written. Attached to the ``meta`` attribute of the product;
            not written for dictionary products, which have no ``meta`` attribute.

        Raises
        ------
        IORegistryError
            if data writing was not successful.
        """
        if product_data is None:
            return

        is_dict_product = isinstance(product_data, dict)
        if metadata is not None:
            if is_dict_product:
                # A plain dict has no 'meta' attribute; accessing it raised an
                # AttributeError before the json branch below could be reached.
                self._logger.warning(
                    "Metadata is not written for dictionary-type product data."
                )
            else:
                product_data.meta.update(gen.change_dict_keys_case(metadata, False))

        self._logger.info(f"Writing data to {self.product_data_file}")
        if is_dict_product and Path(self.product_data_file).suffix == ".json":
            self.write_dict_to_model_parameter_json(self.product_data_file, product_data)
            return
        try:
            product_data.write(
                self.product_data_file, format=self.product_data_format, overwrite=True
            )
        except IORegistryError:
            self._logger.error(f"Error writing model data to {self.product_data_file}.")
            raise

    @staticmethod
    def write_dict_to_model_parameter_json(file_name, data_dict):
        """
        Write dictionary to model-parameter-style json file.

        The file is written with an indentation of 4, keys in insertion order, and a
        trailing newline. Numpy values and astropy units are converted by
        JsonNumpyEncoder.

        Parameters
        ----------
        file_name : str
            Name of output file.
        data_dict : dict
            Dictionary to be written.

        Raises
        ------
        FileNotFoundError
            if data writing was not successful.
        """
        try:
            with open(file_name, "w", encoding="UTF-8") as file:
                json.dump(data_dict, file, indent=4, sort_keys=False, cls=JsonNumpyEncoder)
                file.write("\n")
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"Error writing model data to {file_name}") from exc

    @staticmethod
    def _astropy_data_format(product_data_format):
        """
        Ensure conformance with astropy data format naming.

        Parameters
        ----------
        product_data_format: string
            format identifier

        Returns
        -------
        string
            "ascii.ecsv" if the identifier is "ecsv"; otherwise the identifier unchanged.
        """
        if product_data_format == "ecsv":
            product_data_format = "ascii.ecsv"
        return product_data_format