Source code for db.db_handler

"""Module to handle interaction with DB."""

import logging
import re
from collections import defaultdict
from pathlib import Path
from threading import Lock

import gridfs
import jsonschema
from bson.objectid import ObjectId
from packaging.version import Version
from pymongo import MongoClient

from simtools.data_model import validate_data
from simtools.io_operations import io_handler
from simtools.simtel import simtel_table_reader
from simtools.utils import names, value_conversion

__all__ = ["DatabaseHandler"]

logging.getLogger("pymongo").setLevel(logging.WARNING)


# pylint: disable=unsubscriptable-object
# The above comment is because pylint does not know that DatabaseHandler.db_client is subscriptable


jsonschema_db_dict = {
    "$schema": "https://json-schema.org/draft/2020-12/schema#",
    "type": "object",
    "description": "MongoDB configuration",
    "properties": {
        "db_server": {"type": "string", "description": "DB server address"},
        "db_api_port": {
            "type": "integer",
            "minimum": 1,
            "maximum": 65535,
            "default": 27017,
            "description": "Port to use",
        },
        "db_api_user": {"type": "string", "description": "API username"},
        "db_api_pw": {"type": "string", "description": "Password for the API user"},
        "db_api_authentication_database": {
            "type": ["string", "null"],
            "default": "admin",
            "description": "DB with user info (optional)",
        },
        "db_simulation_model": {
            "type": "string",
            "description": "Name of simulation model database",
        },
    },
    "required": ["db_server", "db_api_port", "db_api_user", "db_api_pw", "db_simulation_model"],
}


[docs] class DatabaseHandler: """ DatabaseHandler provides the interface to the DB. Parameters ---------- mongo_db_config: dict Dictionary with the MongoDB configuration (see jsonschema_db_dict for details). """ ALLOWED_FILE_EXTENSIONS = [".dat", ".txt", ".lis", ".cfg", ".yml", ".yaml", ".ecsv"] db_client = None production_table_cached = {} model_parameters_cached = {} def __init__(self, mongo_db_config=None): """Initialize the DatabaseHandler class.""" self._logger = logging.getLogger(__name__) self.mongo_db_config = self._validate_mongo_db_config(mongo_db_config) self.io_handler = io_handler.IOHandler() self.list_of_collections = {} self._set_up_connection() self._find_latest_simulation_model_db() def _set_up_connection(self): """Open the connection to MongoDB.""" if self.mongo_db_config and DatabaseHandler.db_client is None: lock = Lock() with lock: DatabaseHandler.db_client = self._open_mongo_db() def _validate_mongo_db_config(self, mongo_db_config): """Validate the MongoDB configuration.""" if mongo_db_config is None or all(value is None for value in mongo_db_config.values()): return None try: jsonschema.validate(instance=mongo_db_config, schema=jsonschema_db_dict) return mongo_db_config except jsonschema.exceptions.ValidationError as err: raise ValueError("Invalid MongoDB configuration") from err def _open_mongo_db(self): """ Open a connection to MongoDB and return the client to read/write to the DB with. Returns ------- A PyMongo DB client Raises ------ KeyError If the DB configuration is invalid """ direct_connection = self.mongo_db_config["db_server"] in ( "localhost", "simtools-mongodb", "mongodb", ) return MongoClient( self.mongo_db_config["db_server"], port=self.mongo_db_config["db_api_port"], username=self.mongo_db_config["db_api_user"], password=self.mongo_db_config["db_api_pw"], authSource=self.mongo_db_config.get("db_api_authentication_database") if self.mongo_db_config.get("db_api_authentication_database") else "admin", directConnection=direct_connection, ssl=not direct_connection, tlsallowinvalidhostnames=True, tlsallowinvalidcertificates=True, ) def _find_latest_simulation_model_db(self): """ Find the latest released version of the simulation model and update the DB config. This is indicated by adding "LATEST" to the name of the simulation model database (field "db_simulation_model" in the database configuration dictionary). Only released versions are considered, pre-releases are ignored. Raises ------ ValueError If the "LATEST" version is requested but no versions are found in the DB. """ try: db_simulation_model = self.mongo_db_config["db_simulation_model"] if not db_simulation_model.endswith("LATEST"): return except TypeError: # db_simulation_model is None return prefix = db_simulation_model.replace("LATEST", "") list_of_db_names = self.db_client.list_database_names() filtered_list_of_db_names = [s for s in list_of_db_names if s.startswith(prefix)] versioned_strings = [] version_pattern = re.compile( rf"{re.escape(prefix)}v?(\d+)-(\d+)-(\d+)(?:-([a-zA-Z0-9_.]+))?" ) for s in filtered_list_of_db_names: match = version_pattern.search(s) # A version is considered a pre-release if it contains a '-' character (re group 4) if match and match.group(4) is None: version_str = match.group(1) + "." + match.group(2) + "." + match.group(3) version = Version(version_str) versioned_strings.append((s, version)) if versioned_strings: latest_string, _ = max(versioned_strings, key=lambda x: x[1]) self.mongo_db_config["db_simulation_model"] = latest_string self._logger.info( f"Updated the DB simulation model to the latest version {latest_string}" ) else: raise ValueError("Found LATEST in the DB name but no matching versions found in DB.")
[docs] def get_model_parameter( self, parameter, site, array_element_name, parameter_version=None, model_version=None, ): """ Get a single model parameter using model or parameter version. Note that this function should not be called in a loop for many parameters, as it each call queries the database. Parameters ---------- parameter: str Name of the parameter. site: str Site name. array_element_name: str Name of the array element model. parameter_version: str Version of the parameter. model_version: str Version of the model. Returns ------- dict containing the parameter """ collection_name = names.get_collection_name_from_parameter_name(parameter) if model_version: production_table = self._read_production_table_from_mongo_db( collection_name, model_version ) array_element_list = self._get_array_element_list( array_element_name, site, production_table, collection_name ) for array_element in reversed(array_element_list): parameter_version = ( production_table["parameters"].get(array_element, {}).get(parameter) ) if parameter_version: array_element_name = array_element break query = { "parameter_version": parameter_version, "parameter": parameter, } if array_element_name: query["instrument"] = array_element_name if site: query["site"] = site return self._read_mongo_db(query=query, collection_name=collection_name)
[docs] def get_model_parameters(self, site, array_element_name, collection, model_version): """ Get model parameters using the model version. Queries parameters for design and for the specified array element (if necessary). Parameters ---------- site: str Site name. array_element_name: str Name of the array element model (e.g. LSTN-01, MSTx-FlashCam, ILLN-01). model_version: str, list Version(s) of the model. collection: str Collection of array element (e.g. telescopes, calibration_devices). Returns ------- dict containing the parameters """ pars = {} production_table = self._read_production_table_from_mongo_db(collection, model_version) array_element_list = self._get_array_element_list( array_element_name, site, production_table, collection ) for array_element in array_element_list: pars.update( self._get_parameter_for_model_version( array_element, model_version, site, collection, production_table ) ) return pars
[docs] def get_model_parameters_for_all_model_versions(self, site, array_element_name, collection): """ Get model parameters for all model versions. Parameters ---------- site: str Site name. array_element_name: str Name of the array element model (e.g. LSTN-01, MSTx-FlashCam, ILLN-01). collection: str Collection of array element (e.g. telescopes, calibration_devices). Returns ------- dict containing the parameters with model version as first key """ pars = defaultdict(dict) for _model_version in self.get_model_versions(collection): try: parameter_data = self.get_model_parameters( site, array_element_name, collection, _model_version ) pars[_model_version].update(parameter_data) except KeyError: self._logger.debug( f"Skipping model version {_model_version} - {array_element_name} not found" ) continue return pars
def _get_parameter_for_model_version( self, array_element, model_version, site, collection, production_table ): cache_key, cache_dict = self._read_cache( DatabaseHandler.model_parameters_cached, names.validate_site_name(site) if site else None, array_element, model_version, collection, ) if cache_dict: self._logger.debug(f"Found {array_element} in cache (key: {cache_key})") return cache_dict self._logger.debug(f"Did not find {array_element} in cache (key: {cache_key})") try: parameter_version_table = production_table["parameters"][array_element] except KeyError: # allow missing array elements (parameter dict is checked later) return {} DatabaseHandler.model_parameters_cached[cache_key] = self._read_mongo_db( query=self._get_query_from_parameter_version_table( parameter_version_table, array_element, site ), collection_name=collection, ) return DatabaseHandler.model_parameters_cached[cache_key]
[docs] def get_collection(self, db_name, collection_name): """ Get a collection from the DB. Parameters ---------- db_name: str Name of the DB. collection_name: str Name of the collection. Returns ------- pymongo.collection.Collection The collection from the DB. """ db_name = self._get_db_name(db_name) return DatabaseHandler.db_client[db_name][collection_name]
[docs] def get_collections(self, db_name=None, model_collections_only=False): """ List of collections in the DB. Parameters ---------- db_name: str Database name. model_collections_only: bool If True, only return model collections (i.e. exclude fs.files, fs.chunks) Returns ------- list List of collection names """ db_name = db_name or self._get_db_name() if db_name not in self.list_of_collections: self.list_of_collections[db_name] = DatabaseHandler.db_client[ db_name ].list_collection_names() collections = self.list_of_collections[db_name] if model_collections_only: return [collection for collection in collections if not collection.startswith("fs.")] return collections
[docs] def export_model_file( self, parameter, site, array_element_name, model_version=None, parameter_version=None, export_file_as_table=False, ): """ Export single model file from the DB identified by the parameter name. The parameter can be identified by model or parameter version. Files can be exported as astropy tables (ecsv format). Parameters ---------- parameter: str Name of the parameter. site: str Site name. array_element_name: str Name of the array element model (e.g. MSTN, SSTS). parameter_version: str Version of the parameter. model_version: str Version of the model. export_file_as_table: bool If True, export the file as an astropy table (ecsv format). """ parameters = self.get_model_parameter( parameter, site, array_element_name, parameter_version=parameter_version, model_version=model_version, ) self.export_model_files(parameters=parameters, dest=self.io_handler.get_output_directory()) if export_file_as_table: return simtel_table_reader.read_simtel_table( parameter, self.io_handler.get_output_directory().joinpath(parameters[parameter]["value"]), ) return None
[docs] def export_model_files(self, parameters=None, file_names=None, dest=None, db_name=None): """ Export models files from the DB to given directory. The files to be exported can be specified by file_name or retrieved from the database using the parameters dictionary. Parameters ---------- parameters: dict Dict of model parameters file_names: list, str List (or string) of file names to export dest: str or Path Location where to write the files to. Returns ------- file_id: dict of GridOut._id Dict of database IDs of files. """ db_name = self._get_db_name(db_name) if file_names: file_names = [file_names] if not isinstance(file_names, list) else file_names elif parameters: file_names = [ info["value"] for info in parameters.values() if info and info.get("file") and info["value"] is not None ] instance_ids = {} for file_name in file_names: if Path(dest).joinpath(file_name).exists(): instance_ids[file_name] = "file exists" else: file_path_instance = self._get_file_mongo_db(self._get_db_name(), file_name) self._write_file_from_mongo_to_disk(self._get_db_name(), dest, file_path_instance) instance_ids[file_name] = file_path_instance._id # pylint: disable=protected-access return instance_ids
def _get_query_from_parameter_version_table( self, parameter_version_table, array_element_name, site ): """Return query based on parameter version table.""" query_dict = { "$or": [ {"parameter": param, "parameter_version": version} for param, version in parameter_version_table.items() ], } # 'xSTX-design' is a placeholder to ignore 'instrument' field in query. if array_element_name and array_element_name != "xSTx-design": query_dict["instrument"] = array_element_name if site: query_dict["site"] = site return query_dict def _read_mongo_db(self, query, collection_name): """ Query MongoDB. Parameters ---------- query: dict Query to execute. collection_name: str Collection name. Returns ------- dict containing the parameters Raises ------ ValueError if query returned no results. """ db_name = self._get_db_name() collection = self.get_collection(db_name, collection_name) posts = list(collection.find(query)) if not posts: raise ValueError( f"The following query for {collection_name} returned zero results: {query} " ) parameters = {} for post in posts: par_now = post["parameter"] parameters[par_now] = post parameters[par_now]["entry_date"] = ObjectId(post["_id"]).generation_time return {k: parameters[k] for k in sorted(parameters)} def _read_production_table_from_mongo_db(self, collection_name, model_version): """ Read production table from MongoDB. Parameters ---------- collection_name: str Name of the collection. model_version: str Version of the model. Raises ------ ValueError if query returned no results. """ try: return DatabaseHandler.production_table_cached[ self._cache_key(None, None, model_version, collection_name) ] except KeyError: pass query = {"model_version": model_version, "collection": collection_name} collection = self.get_collection(self._get_db_name(), "production_tables") post = collection.find_one(query) if not post: raise ValueError(f"The following query returned zero results: {query}") return { "collection": post["collection"], "model_version": post["model_version"], "parameters": post["parameters"], "design_model": post.get("design_model", {}), "entry_date": ObjectId(post["_id"]).generation_time, }
[docs] def get_model_versions(self, collection_name="telescopes"): """ Get list of model versions from the DB. Parameters ---------- collection_name: str Name of the collection. Returns ------- list List of model versions """ collection = self.get_collection(self._get_db_name(), "production_tables") return sorted( {post["model_version"] for post in collection.find({"collection": collection_name})} )
[docs] def get_array_elements(self, model_version, collection="telescopes"): """ Get list array elements for a given model version and collection from the DB. Parameters ---------- model_version: str Version of the model. collection: str Which collection to get the array elements from: i.e. telescopes, calibration_devices. Returns ------- list Sorted list of all array elements found in collection """ production_table = self._read_production_table_from_mongo_db(collection, model_version) return sorted([entry for entry in production_table["parameters"] if "-design" not in entry])
[docs] def get_design_model(self, model_version, array_element_name, collection="telescopes"): """ Get the design model used for a given array element and a given model version. Parameters ---------- model_version: str Version of the model. array_element_name: str Name of the array element model (e.g. MSTN, SSTS). collection: str Which collection to get the array elements from: i.e. telescopes, calibration_devices. Returns ------- str Design model for a given array element. """ production_table = self._read_production_table_from_mongo_db(collection, model_version) try: return production_table["design_model"][array_element_name] except KeyError: # for eg. array_element_name == 'LSTN-design' returns 'LSTN-design' return array_element_name
[docs] def get_array_elements_of_type(self, array_element_type, model_version, collection): """ Get array elements of a certain type (e.g. 'LSTN') for a DB collection. Does not return 'design' models. Parameters ---------- array_element_type: str Type of the array element (e.g. LSTN, MSTS). model_version: str Version of the model. collection: str Which collection to get the array elements from: i.e. telescopes, calibration_devices. Returns ------- list Sorted list of all array element names found in collection """ production_table = self._read_production_table_from_mongo_db(collection, model_version) all_array_elements = production_table["parameters"] return sorted( [ entry for entry in all_array_elements if entry.startswith(array_element_type) and "-design" not in entry ] )
[docs] def get_simulation_configuration_parameters( self, simulation_software, site, array_element_name, model_version ): """ Get simulation configuration parameters from the DB. Parameters ---------- simulation_software: str Name of the simulation software. site: str Site name. array_element_name: str Name of the array element model (e.g. MSTN, SSTS). model_version: str Version of the model. Returns ------- dict containing the parameters Raises ------ ValueError if simulation_software is not valid. """ if simulation_software == "corsika": return self.get_model_parameters( None, None, model_version=model_version, collection="configuration_corsika", ) if simulation_software == "simtel": return ( self.get_model_parameters( site, array_element_name, model_version=model_version, collection="configuration_sim_telarray", ) if site and array_element_name else {} ) raise ValueError(f"Unknown simulation software: {simulation_software}")
@staticmethod def _get_file_mongo_db(db_name, file_name): """ Extract a file from MongoDB and return GridFS file instance. Parameters ---------- db_name: str the name of the DB with files of tabulated data file_name: str The name of the file requested Returns ------- GridOut A file instance returned by GridFS find_one Raises ------ FileNotFoundError If the desired file is not found. """ db = DatabaseHandler.db_client[db_name] file_system = gridfs.GridFS(db) if file_system.exists({"filename": file_name}): return file_system.find_one({"filename": file_name}) raise FileNotFoundError(f"The file {file_name} does not exist in the database {db_name}") @staticmethod def _write_file_from_mongo_to_disk(db_name, path, file): """ Extract a file from MongoDB and write it to disk. Parameters ---------- db_name: str the name of the DB with files of tabulated data path: str or Path The path to write the file to file: GridOut A file instance returned by GridFS find_one """ db = DatabaseHandler.db_client[db_name] fs_output = gridfs.GridFSBucket(db) with open(Path(path).joinpath(file.filename), "wb") as output_file: fs_output.download_to_stream_by_name(file.filename, output_file)
[docs] def add_production_table(self, db_name, production_table): """ Add a production table to the DB. Parameters ---------- db_name: str the name of the DB. production_table: dict The production table to add to the DB. """ db_name = self._get_db_name(db_name) collection = self.get_collection(db_name, "production_tables") self._logger.info(f"Adding production for {production_table.get('collection')} to to DB") collection.insert_one(production_table) DatabaseHandler.production_table_cached.clear()
[docs] def add_new_parameter( self, db_name, par_dict, collection_name="telescopes", file_prefix=None, ): """ Add a new parameter dictionary to the DB. A new document will be added to the DB, with all fields taken from the input parameters. Parameter dictionaries are validated before submission using the corresponding schema. Parameters ---------- db_name: str the name of the DB par_dict: dict dictionary with parameter data collection_name: str The name of the collection to add a parameter to. file_prefix: str or Path where to find files to upload to the DB """ par_dict = validate_data.DataValidator.validate_model_parameter(par_dict) db_name = self._get_db_name(db_name) collection = self.get_collection(db_name, collection_name) par_dict["value"], _base_unit, _ = value_conversion.get_value_unit_type( value=par_dict["value"], unit_str=par_dict.get("unit", None) ) par_dict["unit"] = _base_unit if _base_unit else None files_to_add_to_db = set() if par_dict["file"] and par_dict["value"]: if file_prefix is None: raise FileNotFoundError( "The location of the file to upload, " f"corresponding to the {par_dict['parameter']} parameter, must be provided." ) file_path = Path(file_prefix).joinpath(par_dict["value"]) files_to_add_to_db.add(f"{file_path}") self._logger.info( f"Adding a new entry to DB {db_name} and collection {collection_name}:\n{par_dict}" ) collection.insert_one(par_dict) for file_to_insert_now in files_to_add_to_db: self._logger.info(f"Will also add the file {file_to_insert_now} to the DB") self.insert_file_to_db(file_to_insert_now, db_name) self._reset_parameter_cache()
def _get_db_name(self, db_name=None): """ Return database name. If not provided, return the default database name. Parameters ---------- db_name: str Database name Returns ------- str Database name """ return self.mongo_db_config["db_simulation_model"] if db_name is None else db_name
[docs] def insert_file_to_db(self, file_name, db_name=None, **kwargs): """ Insert a file to the DB. Parameters ---------- file_name: str or Path The name of the file to insert (full path). db_name: str the name of the DB **kwargs (optional): keyword arguments for file creation. The full list of arguments can be found in, \ https://docs.mongodb.com/manual/core/gridfs/#the-files-collection mostly these are unnecessary though. Returns ------- file_iD: GridOut._id If the file exists, return its GridOut._id, otherwise insert the file and return its" "newly created DB GridOut._id. """ db_name = self._get_db_name(db_name) db = DatabaseHandler.db_client[db_name] file_system = gridfs.GridFS(db) kwargs.setdefault("content_type", "ascii/dat") kwargs.setdefault("filename", Path(file_name).name) if file_system.exists({"filename": kwargs["filename"]}): self._logger.warning( f"The file {kwargs['filename']} exists in the DB. Returning its ID" ) return file_system.find_one( # pylint: disable=protected-access {"filename": kwargs["filename"]} )._id self._logger.debug(f"Writing file to DB: {file_name}") with open(file_name, "rb") as data_file: return file_system.put(data_file, **kwargs)
def _cache_key(self, site=None, array_element_name=None, model_version=None, collection=None): """ Create a cache key for the parameter cache dictionaries. Parameters ---------- site: str Site name. array_element_name: str Array element name. model_version: str Model version. collection: str DB collection name. Returns ------- str Cache key. """ return "-".join( part for part in [model_version, collection, site, array_element_name] if part ) def _read_cache( self, cache_dict, site=None, array_element_name=None, model_version=None, collection=None ): """ Read parameters from cache. Parameters ---------- cache_dict: dict Cache dictionary. site: str Site name. array_element_name: str Array element name. model_version: str Model version. collection: str DB collection name. Returns ------- str Cache key. """ cache_key = self._cache_key(site, array_element_name, model_version, collection) try: return cache_key, cache_dict[cache_key] except KeyError: return cache_key, None def _reset_parameter_cache(self): """Reset the cache for the parameters.""" DatabaseHandler.model_parameters_cached.clear() def _get_array_element_list(self, array_element_name, site, production_table, collection): """ Return list of array elements for DB queries (add design model if needed). Design model is added if found in the production table. Parameters ---------- array_element_name: str Name of the array element. site: str Site name. production_table: dict Production table. collection: str collection of array element (e.g. telescopes, calibration_devices). Returns ------- list List of array elements """ if collection == "configuration_corsika": return ["xSTx-design"] # placeholder to ignore 'instrument' field in query. if collection == "sites": return [f"OBS-{site}"] if names.is_design_type(array_element_name): return [array_element_name] if collection == "configuration_sim_telarray": # get design model from 'telescope' or 'calibration_device' production tables production_table = self._read_production_table_from_mongo_db( names.get_collection_name_from_array_element_name(array_element_name), production_table["model_version"], ) try: return [ production_table["design_model"][array_element_name], array_element_name, ] except KeyError as exc: raise KeyError( f"Failed generated array element list for db query for {array_element_name}" ) from exc