"""Module to handle interaction with DB."""
import logging
from collections import defaultdict
from pathlib import Path
from simtools.data_model import validate_data
from simtools.db.mongo_db import MongoDBHandler
from simtools.io import io_handler
from simtools.simtel import simtel_table_reader
from simtools.utils import names, value_conversion
from simtools.version import resolve_version_to_latest_patch
[docs]
class DatabaseHandler:
"""
DatabaseHandler provides the interface to the DB.
Note the two types of version variables used in this class:
- db_simulation_model_version (from db_config): version of the simulation model database
- model_version (from production_tables): version of the model contained in the database
Parameters
----------
db_config: dict
Dictionary with the DB configuration.
"""
ALLOWED_FILE_EXTENSIONS = [".dat", ".txt", ".lis", ".cfg", ".yml", ".yaml", ".ecsv"]
production_table_cached = {}
model_parameters_cached = {}
model_versions_cached = {}
def __init__(self, db_config=None):
"""Initialize the DatabaseHandler class."""
self._logger = logging.getLogger(__name__)
self.db_config = MongoDBHandler.validate_db_config(db_config)
self.io_handler = io_handler.IOHandler()
self.mongo_db_handler = MongoDBHandler(db_config) if self.db_config else None
self.db_name = (
MongoDBHandler.get_db_name(
db_simulation_model_version=self.db_config.get("db_simulation_model_version"),
model_name=self.db_config.get("db_simulation_model"),
)
if self.db_config
else None
)
[docs]
def get_db_name(self, db_name=None, db_simulation_model_version=None, model_name=None):
"""Build DB name from configuration."""
if db_name:
return db_name
if db_simulation_model_version and model_name:
return MongoDBHandler.get_db_name(
db_simulation_model_version=db_simulation_model_version,
model_name=model_name,
)
if not (db_simulation_model_version or model_name):
return self.db_name
return None
[docs]
def print_connection_info(self):
"""Print the connection information."""
if self.mongo_db_handler:
self.mongo_db_handler.print_connection_info(self.db_name)
else:
self._logger.info("No database defined.")
[docs]
def is_remote_database(self):
"""
Check if the database is remote.
Check for domain pattern like "cta-simpipe-protodb.zeuthen.desy.de"
Returns
-------
bool
True if the database is remote, False otherwise.
"""
return bool(self.mongo_db_handler and self.mongo_db_handler.is_remote_database())
[docs]
def generate_compound_indexes_for_databases(
self, db_name, db_simulation_model, db_simulation_model_version
):
"""
Generate compound indexes for several databases.
Parameters
----------
db_name: str
Name of the database.
db_simulation_model: str
Name of the simulation model.
db_simulation_model_version: str
Version of the simulation model.
"""
self.mongo_db_handler.generate_compound_indexes_for_databases(
db_name, db_simulation_model, db_simulation_model_version
)
[docs]
def get_model_parameter(
self,
parameter,
site,
array_element_name,
parameter_version=None,
model_version=None,
):
"""
Get a single model parameter using model or parameter version.
Note that this function should not be called in a loop for many parameters,
as it each call queries the database.
Parameters
----------
parameter: str
Name of the parameter.
site: str
Site name.
array_element_name: str
Name of the array element model.
parameter_version: str
Version of the parameter.
model_version: str
Version of the model.
Returns
-------
dict containing the parameter
"""
collection_name = names.get_collection_name_from_parameter_name(parameter)
if model_version:
if isinstance(model_version, list):
raise ValueError(
"Only one model version can be passed to get_model_parameter, not a list."
)
model_version = resolve_version_to_latest_patch(
model_version, self.get_model_versions(collection_name)
)
production_table = self.read_production_table_from_db(collection_name, model_version)
array_element_list = self._get_array_element_list(
array_element_name, site, production_table, collection_name
)
for array_element in reversed(array_element_list):
parameter_version = (
production_table["parameters"].get(array_element, {}).get(parameter)
)
if parameter_version:
array_element_name = array_element
break
query = {
"parameter_version": parameter_version,
"parameter": parameter,
}
if array_element_name:
query["instrument"] = array_element_name
if site:
query["site"] = site
return self._read_db(query=query, collection_name=collection_name)
[docs]
def get_model_parameters(self, site, array_element_name, collection, model_version):
"""
Get model parameters using the model version.
Queries parameters for design and for the specified array element (if necessary).
Parameters
----------
site: str
Site name.
array_element_name: str
Name of the array element model (e.g. LSTN-01, MSTx-FlashCam, ILLN-01).
model_version: str, list
Version(s) of the model.
collection: str
Collection of array element (e.g. telescopes, calibration_devices).
Returns
-------
dict containing the parameters
"""
pars = {}
model_version = resolve_version_to_latest_patch(
model_version, self.get_model_versions(collection)
)
production_table = self.read_production_table_from_db(collection, model_version)
array_element_list = self._get_array_element_list(
array_element_name, site, production_table, collection
)
for array_element in array_element_list:
pars.update(
self._get_parameter_for_model_version(
array_element, model_version, site, collection, production_table
)
)
return pars
[docs]
def get_model_parameters_for_all_model_versions(self, site, array_element_name, collection):
"""
Get model parameters for all model versions.
Parameters
----------
site: str
Site name.
array_element_name: str
Name of the array element model (e.g. LSTN-01, MSTx-FlashCam, ILLN-01).
collection: str
Collection of array element (e.g. telescopes, calibration_devices).
Returns
-------
dict containing the parameters with model version as first key
"""
pars = defaultdict(dict)
for _model_version in self.get_model_versions(collection):
try:
parameter_data = self.get_model_parameters(
site, array_element_name, collection, _model_version
)
pars[_model_version].update(parameter_data)
except KeyError:
self._logger.debug(
f"Skipping model version {_model_version} - {array_element_name} not found"
)
continue
return pars
def _get_parameter_for_model_version(
self, array_element, model_version, site, collection, production_table
):
"""
Get parameters for a specific model version and array element.
Uses caching wherever possible.
"""
cache_key, cache_dict = self._read_cache(
DatabaseHandler.model_parameters_cached,
names.validate_site_name(site) if site else None,
array_element,
model_version,
collection,
)
if cache_dict:
self._logger.debug(f"Found {array_element} in cache (key: {cache_key})")
return cache_dict
self._logger.debug(f"Did not find {array_element} in cache (key: {cache_key})")
try:
parameter_version_table = production_table["parameters"][array_element]
except KeyError: # allow missing array elements (parameter dict is checked later)
return {}
DatabaseHandler.model_parameters_cached[cache_key] = self._read_db(
query=self._get_query_from_parameter_version_table(
parameter_version_table, array_element, site
),
collection_name=collection,
)
return DatabaseHandler.model_parameters_cached[cache_key]
[docs]
def get_collection(self, collection_name, db_name=None):
"""
Get a collection from the DB.
Parameters
----------
collection_name: str
Name of the collection.
db_name: str
Name of the DB.
Returns
-------
pymongo.collection.Collection
The collection from the DB.
"""
db_name = db_name or self.db_name
return self.mongo_db_handler.get_collection(collection_name, db_name)
[docs]
def get_collections(self, db_name=None, model_collections_only=False):
"""
List of collections in the DB.
Parameters
----------
db_name: str
Database name.
model_collections_only: bool
If True, only return model collections (i.e. exclude fs.files, fs.chunks)
Returns
-------
list
List of collection names
"""
return self.mongo_db_handler.get_collections(
db_name or self.db_name, model_collections_only
)
[docs]
def export_model_file(
self,
parameter,
site,
array_element_name,
model_version=None,
parameter_version=None,
export_file_as_table=False,
):
"""
Export single model file from the DB identified by the parameter name.
The parameter can be identified by model or parameter version.
Files can be exported as astropy tables (ecsv format).
Parameters
----------
parameter: str
Name of the parameter.
site: str
Site name.
array_element_name: str
Name of the array element model (e.g. MSTN, SSTS).
parameter_version: str
Version of the parameter.
model_version: str
Version of the model.
export_file_as_table: bool
If True, export the file as an astropy table (ecsv format).
Returns
-------
astropy.table.Table or None
If export_file_as_table is True
"""
parameters = self.get_model_parameter(
parameter,
site,
array_element_name,
parameter_version=parameter_version,
model_version=model_version,
)
self.export_model_files(parameters=parameters, dest=self.io_handler.get_output_directory())
if export_file_as_table:
return simtel_table_reader.read_simtel_table(
parameter,
self.io_handler.get_output_directory().joinpath(parameters[parameter]["value"]),
)
return None
[docs]
def export_model_files(self, parameters=None, file_names=None, dest=None, db_name=None):
"""
Export models files from the DB to given directory.
The files to be exported can be specified by file_name or retrieved from the database
using the parameters dictionary.
Parameters
----------
parameters: dict
Dict of model parameters
file_names: list, str
List (or string) of file names to export
dest: str or Path
Location where to write the files to.
Returns
-------
file_id: dict of GridOut._id
Dict of database IDs of files.
"""
db_name = db_name or self.db_name
if file_names:
file_names = [file_names] if not isinstance(file_names, list) else file_names
elif parameters:
file_names = [
info["value"]
for info in parameters.values()
if info and info.get("file") and info["value"] is not None
]
instance_ids = {}
for file_name in file_names:
if Path(dest).joinpath(file_name).exists():
instance_ids[file_name] = "file exists"
else:
file_path_instance = self.mongo_db_handler.get_file_from_db(db_name, file_name)
self._write_file_from_db_to_disk(db_name, dest, file_path_instance)
instance_ids[file_name] = file_path_instance._id # pylint: disable=protected-access
return instance_ids
def _get_query_from_parameter_version_table(
self, parameter_version_table, array_element_name, site
):
"""Return query based on parameter version table."""
query_dict = {
"$or": [
{"parameter": param, "parameter_version": version}
for param, version in parameter_version_table.items()
],
}
# 'xSTX-design' is a placeholder to ignore 'instrument' field in query.
if array_element_name and array_element_name != "xSTx-design":
query_dict["instrument"] = array_element_name
if site:
query_dict["site"] = site
return query_dict
def _read_db(self, query, collection_name):
"""
Query MongoDB.
Parameters
----------
query: dict
Query to execute.
collection_name: str
Collection name.
Returns
-------
dict containing the parameters
Raises
------
ValueError
if query returned no results.
"""
posts = self.mongo_db_handler.query_db(query, collection_name, self.db_name)
parameters = {}
for post in posts:
par_now = post["parameter"]
parameters[par_now] = post
parameters[par_now]["entry_date"] = self.mongo_db_handler.get_entry_date_from_document(
post
)
return {k: parameters[k] for k in sorted(parameters)}
[docs]
def read_production_table_from_db(self, collection_name, model_version):
"""
Read production table for the given collection from MongoDB.
Parameters
----------
collection_name: str
Name of the collection.
model_version: str
Version of the model.
Raises
------
ValueError
if query returned no results.
"""
model_version = resolve_version_to_latest_patch(
model_version, self.get_model_versions(collection_name)
)
try:
return DatabaseHandler.production_table_cached[
self._cache_key(None, None, model_version, collection_name)
]
except KeyError:
pass
query = {"model_version": model_version, "collection": collection_name}
post = self.mongo_db_handler.find_one(query, "production_tables", self.db_name)
if not post:
raise ValueError(f"The following query returned zero results: {query}")
return {
"collection": post["collection"],
"model_version": post["model_version"],
"parameters": post["parameters"],
"design_model": post.get("design_model", {}),
"entry_date": self.mongo_db_handler.get_entry_date_from_document(post),
}
[docs]
def get_model_versions(self, collection_name="telescopes"):
"""
Get list of model versions from the DB with caching.
Parameters
----------
collection_name: str
Name of the collection.
Returns
-------
list
List of model versions
"""
if collection_name not in DatabaseHandler.model_versions_cached:
collection = self.get_collection("production_tables", db_name=self.db_name)
DatabaseHandler.model_versions_cached[collection_name] = sorted(
{post["model_version"] for post in collection.find({"collection": collection_name})}
)
return DatabaseHandler.model_versions_cached[collection_name]
[docs]
def get_array_elements(self, model_version, collection="telescopes"):
"""
Get list array elements for a given model version and collection from the DB.
Parameters
----------
model_version: str
Version of the model.
collection: str
Which collection to get the array elements from:
i.e. telescopes, calibration_devices.
Returns
-------
list
Sorted list of all array elements found in collection
"""
model_version = resolve_version_to_latest_patch(
model_version, self.get_model_versions(collection)
)
production_table = self.read_production_table_from_db(collection, model_version)
return sorted([entry for entry in production_table["parameters"] if "-design" not in entry])
[docs]
def get_design_model(self, model_version, array_element_name, collection="telescopes"):
"""
Get the design model used for a given array element and a given model version.
Parameters
----------
model_version: str
Version of the model.
array_element_name: str
Name of the array element model (e.g. MSTN, SSTS).
collection: str
Which collection to get the array elements from:
i.e. telescopes, calibration_devices.
Returns
-------
str
Design model for a given array element.
"""
model_version = resolve_version_to_latest_patch(
model_version, self.get_model_versions(collection)
)
production_table = self.read_production_table_from_db(collection, model_version)
try:
return production_table["design_model"][array_element_name]
except KeyError:
# for eg. array_element_name == 'LSTN-design' returns 'LSTN-design'
return array_element_name
[docs]
def get_array_elements_of_type(self, array_element_type, model_version, collection):
"""
Get array elements of a certain type (e.g. 'LSTN') for a DB collection.
Does not return 'design' models.
Parameters
----------
array_element_type: str
Type of the array element (e.g. LSTN, MSTS).
model_version: str
Version of the model.
collection: str
Which collection to get the array elements from:
i.e. telescopes, calibration_devices.
Returns
-------
list
Sorted list of all array element names found in collection
"""
model_version = resolve_version_to_latest_patch(
model_version, self.get_model_versions(collection)
)
production_table = self.read_production_table_from_db(collection, model_version)
all_array_elements = production_table["parameters"]
return sorted(
[
entry
for entry in all_array_elements
if entry.startswith(array_element_type) and "-design" not in entry
]
)
[docs]
def get_simulation_configuration_parameters(
self, simulation_software, site, array_element_name, model_version
):
"""
Get simulation configuration parameters from the DB.
Parameters
----------
simulation_software: str
Name of the simulation software.
site: str
Site name.
array_element_name: str
Name of the array element model (e.g. MSTN, SSTS).
model_version: str
Version of the model.
Returns
-------
dict containing the parameters
Raises
------
ValueError
if simulation_software is not valid.
"""
if simulation_software == "corsika":
return self.get_model_parameters(
None,
None,
model_version=model_version,
collection="configuration_corsika",
)
if simulation_software == "sim_telarray":
return (
self.get_model_parameters(
site,
array_element_name,
model_version=model_version,
collection="configuration_sim_telarray",
)
if site and array_element_name
else {}
)
raise ValueError(f"Unknown simulation software: {simulation_software}")
def _write_file_from_db_to_disk(self, db_name, path, file):
"""
Extract a file from MongoDB and write it to disk.
Parameters
----------
db_name: str
the name of the DB with files of tabulated data
path: str or Path
The path to write the file to
file: GridOut
A file instance returned by GridFS find_one
"""
self.mongo_db_handler.write_file_from_db_to_disk(db_name, path, file)
[docs]
def get_ecsv_file_as_astropy_table(self, file_name, db_name=None):
"""
Read contents of an ECSV file from the database and return it as an Astropy Table.
Files are not written to disk.
Parameters
----------
file_name: str
The name of the ECSV file.
db_name: str
The name of the database.
Returns
-------
astropy.table.Table
The contents of the ECSV file as an Astropy Table.
"""
return self.mongo_db_handler.get_ecsv_file_as_astropy_table(
file_name, db_name or self.db_name
)
[docs]
def add_production_table(self, production_table, db_name=None):
"""
Add a production table to the DB.
Parameters
----------
production_table: dict
The production table to add to the DB.
db_name: str
the name of the DB.
"""
self._logger.debug(f"Adding production for {production_table.get('collection')} to the DB")
self.mongo_db_handler.insert_one(
production_table, "production_tables", db_name or self.db_name
)
DatabaseHandler.production_table_cached.clear()
DatabaseHandler.model_versions_cached.clear()
[docs]
def add_new_parameter(
self,
par_dict,
db_name=None,
collection_name="telescopes",
file_prefix=None,
):
"""
Add a new parameter dictionary to the DB.
A new document will be added to the DB, with all fields taken from the input parameters.
Parameter dictionaries are validated before submission using the corresponding schema.
Parameters
----------
par_dict: dict
dictionary with parameter data
db_name: str
the name of the DB
collection_name: str
The name of the collection to add a parameter to.
file_prefix: str or Path
where to find files to upload to the DB
"""
par_dict = validate_data.DataValidator.validate_model_parameter(par_dict)
db_name = db_name or self.db_name
par_dict["value"], _base_unit, _ = value_conversion.get_value_unit_type(
value=par_dict["value"], unit_str=par_dict.get("unit", None)
)
par_dict["unit"] = _base_unit if _base_unit else None
files_to_add_to_db = set()
if par_dict["file"] and par_dict["value"]:
if file_prefix is None:
raise FileNotFoundError(
"The location of the file to upload, "
f"corresponding to the {par_dict['parameter']} parameter, must be provided."
)
file_path = Path(file_prefix).joinpath(par_dict["value"])
files_to_add_to_db.add(f"{file_path}")
self._logger.debug(
f"Adding a new entry to DB {db_name} and collection {collection_name}:\n{par_dict}"
)
self.mongo_db_handler.insert_one(par_dict, collection_name, db_name)
for file_to_insert_now in files_to_add_to_db:
self._logger.debug(f"Will also add the file {file_to_insert_now} to the DB")
self.insert_file_to_db(file_to_insert_now, db_name)
self._reset_parameter_cache()
[docs]
def insert_file_to_db(self, file_name, db_name=None):
"""
Insert a file to the DB.
Parameters
----------
file_name: str or Path
The name of the file to insert (full path).
db_name: str
the name of the DB
Returns
-------
file_id: GridOut._id
If the file exists, return its GridOut._id, otherwise insert the file and return
its newly created DB GridOut._id.
"""
return self.mongo_db_handler.insert_file_to_db(file_name, db_name or self.db_name)
def _cache_key(self, site=None, array_element_name=None, model_version=None, collection=None):
"""
Create a cache key for the parameter cache dictionaries.
Parameters
----------
site: str
Site name.
array_element_name: str
Array element name.
model_version: str
Model version.
collection: str
DB collection name.
Returns
-------
str
Cache key.
"""
return "-".join(
part for part in [model_version, collection, site, array_element_name] if part
)
def _read_cache(
self, cache_dict, site=None, array_element_name=None, model_version=None, collection=None
):
"""
Read parameters from cache.
Parameters
----------
cache_dict: dict
Cache dictionary.
site: str
Site name.
array_element_name: str
Array element name.
model_version: str
Model version.
collection: str
DB collection name.
Returns
-------
str
Cache key.
"""
cache_key = self._cache_key(site, array_element_name, model_version, collection)
try:
return cache_key, cache_dict[cache_key]
except KeyError:
return cache_key, None
def _reset_parameter_cache(self):
"""Reset the cache for the parameters."""
DatabaseHandler.model_parameters_cached.clear()
DatabaseHandler.model_versions_cached.clear()
def _get_array_element_list(self, array_element_name, site, production_table, collection):
"""
Return list of array elements for DB queries (add design model if needed).
Design model is added if found in the production table.
Parameters
----------
array_element_name: str
Name of the array element.
site: str
Site name.
production_table: dict
Production table.
collection: str
collection of array element (e.g. telescopes, calibration_devices).
Returns
-------
list
List of array elements
"""
if collection == "configuration_corsika":
return ["xSTx-design"] # placeholder to ignore 'instrument' field in query.
if collection == "sites":
return [f"OBS-{site}"]
if names.is_design_type(array_element_name):
return [array_element_name]
if collection == "configuration_sim_telarray":
# get design model from 'telescope' or 'calibration_device' production tables
production_table = self.read_production_table_from_db(
names.get_collection_name_from_array_element_name(array_element_name),
production_table["model_version"],
)
try:
return [
production_table["design_model"][array_element_name],
array_element_name,
]
except KeyError as exc:
raise KeyError(
f"Failed generated array element list for db query for {array_element_name}"
) from exc