"""General functions useful across different parts of the code."""
import copy
import json
import logging
import os
import tempfile
import time
import urllib.error
import urllib.request
from pathlib import Path
from urllib.parse import urlparse
import numpy as np
import yaml
# Names exported by ``from <this module> import *``.
# NOTE(review): several public helpers defined in this module (e.g. join_url_or_path,
# is_url, find_file) are not listed here -- confirm whether that is intended.
__all__ = [
    "InvalidConfigDataError",
    "change_dict_keys_case",
    "collect_data_from_file",
    "collect_final_lines",
    "collect_kwargs",
    "get_log_excerpt",
    "get_log_level_from_user",
    "remove_substring_recursively_from_dict",
    "set_default_kwargs",
    "sort_arrays",
]
# Module-level logger shared by the functions in this module.
_logger = logging.getLogger(__name__)
[docs]
class InvalidConfigDataError(Exception):
    """Raised when configuration data is found to be invalid."""
def join_url_or_path(url_or_path, *args):
    """
    Join a URL or path with additional subdirectories and a file name.

    This is the equivalent of ``Path.joinpath()``, extended to also work for URLs.
    Inputs whose string form contains ``"://"`` are treated as URLs and joined
    with ``"/"``; everything else is treated as a local path.

    Parameters
    ----------
    url_or_path: str or Path
        URL or path to be extended.
    *args: str or Path
        Additional components to be appended to the URL or path.

    Returns
    -------
    str or Path
        Extended URL (as str) or path (as Path).
    """
    if "://" in str(url_or_path):
        # Convert every component to str so Path objects can be appended to
        # URLs as well (str.join only accepts str items).
        return "/".join([str(url_or_path).rstrip("/"), *(str(arg) for arg in args)])
    return Path(url_or_path).joinpath(*args)
def is_url(url):
    """
    Check if a string is a valid URL.

    A URL is considered valid if it has both a scheme (e.g. ``https``) and a
    network location (e.g. ``example.org``).

    Parameters
    ----------
    url: str
        String to be checked.

    Returns
    -------
    bool
        True if url is a valid URL, False otherwise (including inputs that
        ``urlparse`` cannot handle, such as non-string objects or malformed
        IPv6 addresses).
    """
    try:
        result = urlparse(url)
    except (AttributeError, ValueError):
        # AttributeError: input is not str/bytes-like (e.g. a Path or int).
        # ValueError: urlparse rejects malformed netlocs such as "http://[::1".
        return False
    return all([result.scheme, result.netloc])
def collect_data_from_http(url):
    """
    Download a yaml, json, or list file from a URL and return its contents.

    The file is downloaded into a temporary file, which is deleted afterwards.
    The parser is chosen from the ending of the URL:

    - ``yml`` / ``yaml``: parsed with ``yaml.safe_load``; on a constructor error
      the astropy yaml loader is used instead,
    - ``json``: parsed with ``json.load``,
    - ``list``: one entry per line, with surrounding whitespace stripped.

    Parameters
    ----------
    url: str
        URL of the yaml/json/list file.

    Returns
    -------
    dict or list
        File content (a list of stripped lines for ``list`` files).

    Raises
    ------
    TypeError
        If a TypeError occurs while downloading/parsing (e.g. an invalid url),
        or if the file extension is not supported.
    FileNotFoundError
        If downloading the file fails with an HTTP error.
    """
    try:
        # Decode the downloaded bytes as UTF-8 explicitly rather than with the
        # platform's locale-dependent default encoding.
        with tempfile.NamedTemporaryFile(mode="w+t", encoding="utf-8") as tmp_file:
            urllib.request.urlretrieve(url, tmp_file.name)
            if url.endswith(("yml", "yaml")):
                try:
                    data = yaml.safe_load(tmp_file)
                except yaml.constructor.ConstructorError:
                    data = _load_yaml_using_astropy(tmp_file)
            elif url.endswith("json"):
                data = json.load(tmp_file)
            elif url.endswith("list"):
                data = [line.strip() for line in tmp_file.readlines()]
            else:
                msg = f"File extension of {url} not supported (should be json or yaml)"
                _logger.error(msg)
                raise TypeError(msg)
    except TypeError as exc:
        # This also catches the unsupported-extension TypeError raised above;
        # its message is preserved as the chained cause.
        msg = f"Invalid url {url}"
        _logger.error(msg)
        raise TypeError(msg) from exc
    except urllib.error.HTTPError as exc:
        msg = f"Failed to download file from {url}"
        _logger.error(msg)
        raise FileNotFoundError(msg) from exc
    _logger.debug(f"Downloaded file from {url}")
    return data
[docs]
def collect_data_from_file(file_name):
    """
    Read data from a local file or a URL, choosing the parser by file type.

    URLs are delegated to ``collect_data_from_http``. Local files are opened as
    UTF-8 text and parsed according to their (case-insensitive) suffix:
    ``.json`` with ``json.load``, ``.list`` as one stripped entry per line, and
    anything else as yaml (falling back to the astropy yaml loader on a
    constructor error).

    Parameters
    ----------
    file_name: str
        Name (or URL) of the yaml/json/ascii file.

    Returns
    -------
    data: dict or list
        Parsed file content.
    """
    if is_url(file_name):
        return collect_data_from_http(file_name)
    with open(file_name, encoding="utf-8") as file:
        suffix = Path(file_name).suffix.lower()
        if suffix == ".json":
            return json.load(file)
        if suffix == ".list":
            return [line.strip() for line in file.readlines()]
        try:
            return yaml.safe_load(file)
        except yaml.constructor.ConstructorError:
            return _load_yaml_using_astropy(file)
[docs]
def collect_kwargs(label, in_kwargs):
    """
    Collect kwargs of the type ``label_*`` and return them as a dict.

    Only keys that *start* with ``label + "_"`` are collected; that prefix is
    removed once from each collected key (e.g. label ``"hist"`` turns
    ``"hist_bins"`` into ``"bins"``).

    Parameters
    ----------
    label: str
        Label to be collected in kwargs.
    in_kwargs: dict
        kwargs.

    Returns
    -------
    dict
        Dictionary with the collected kwargs (prefix removed).
    """
    prefix = label + "_"
    # Prefix match (not substring match): "max_y" must not match label "x",
    # and only the leading prefix is stripped, not every occurrence.
    return {
        key[len(prefix) :]: value for key, value in in_kwargs.items() if key.startswith(prefix)
    }
[docs]
def set_default_kwargs(in_kwargs, **kwargs):
    """
    Add default values to a dict for every key it does not already contain.

    The input dict is modified in place; existing entries are never overwritten.

    Parameters
    ----------
    in_kwargs: dict
        Dict to be completed with the default kwargs.
    **kwargs:
        Default values, keyed by parameter name.

    Returns
    -------
    dict
        The same ``in_kwargs`` object, now containing all defaults.
    """
    for name, default in kwargs.items():
        if name in in_kwargs:
            continue
        in_kwargs[name] = default
    return in_kwargs
[docs]
def _open_binary_for_backward_scan(file):
    """
    Open ``file`` in binary mode, decompressing ``.gz`` files into memory first.

    Seeking backwards in a ``gzip.open`` stream rewinds and re-decompresses from
    the start of the file, which makes a byte-wise backward scan quadratic.
    The decompressed bytes are therefore held in an in-memory buffer instead.
    """
    if Path(file).suffix == ".gz":
        import gzip  # pylint: disable=import-outside-toplevel
        import io  # pylint: disable=import-outside-toplevel

        with gzip.open(file, "rb") as gz_file:
            return io.BytesIO(gz_file.read())
    return open(file, "rb")  # pylint: disable=consider-using-with


def collect_final_lines(file, n_lines):
    """
    Collect the final lines of a (optionally gzip-compressed) file.

    The file is scanned backwards byte by byte and split at newline bytes.
    A trailing newline at the end of the file counts as an (empty) final line.

    NOTE(review): the collected lines are concatenated *without* a separator
    (the newline characters are not kept), e.g. the last two lines of
    "a\\nb\\nc" are returned as "bc". Kept as is for existing callers.

    Parameters
    ----------
    file: str or Path
        File to collect the lines from (``.gz`` files are decompressed).
    n_lines: int
        Number of lines to be collected.

    Returns
    -------
    str
        Final lines collected, in file order.
    """
    list_of_lines = []
    with _open_binary_for_backward_scan(file) as read_obj:
        # Move the cursor to the end of the file
        read_obj.seek(0, os.SEEK_END)
        # Buffer of the line currently being read (bytes in reverse order)
        buffer = bytearray()
        pointer_location = read_obj.tell()
        # Loop till the pointer reaches the top of the file
        while pointer_location >= 0:
            read_obj.seek(pointer_location)
            pointer_location -= 1
            new_byte = read_obj.read(1)
            if new_byte == b"\n":
                # Reverse the bytes *before* decoding: decoding the reversed
                # byte order breaks multi-byte UTF-8 characters.
                list_of_lines.append(buffer[::-1].decode())
                if len(list_of_lines) == n_lines:
                    return "".join(reversed(list_of_lines))
                buffer = bytearray()
            else:
                buffer.extend(new_byte)
    # File read completely: any remaining data in the buffer is the first line.
    if len(buffer) > 0:
        list_of_lines.append(buffer[::-1].decode())
    return "".join(reversed(list_of_lines))
[docs]
def get_log_level_from_user(log_level):
    """
    Translate a user-supplied log level name into a ``logging`` module level.

    The lookup is case-insensitive for strings; "warn" and "warning" are aliases.

    Parameters
    ----------
    log_level: str
        Log level from the user.

    Returns
    -------
    logging.LEVEL
        The requested logging level to be used as input to logging.setLevel().

    Raises
    ------
    ValueError
        If log_level is not one of the supported level names.
    """
    levels_by_name = {
        "info": logging.INFO,
        "debug": logging.DEBUG,
        "warn": logging.WARNING,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }
    try:
        requested = log_level.lower()
    except AttributeError:
        # Non-string input is looked up as given.
        requested = log_level
    if requested in levels_by_name:
        return levels_by_name[requested]
    raise ValueError(
        f"'{log_level}' is not a logging level, "
        f"only possible ones are {list(levels_by_name.keys())}"
    )
def copy_as_list(value):
    """
    Copy value into a new list, wrapping non-iterables in a single-entry list.

    Strings are treated as single values (not iterated character by character).

    Parameters
    ----------
    value
        Single variable of any type, or an iterable.

    Returns
    -------
    value: list
        A new list with the items of value if it is iterable, or [value] otherwise.
    """
    if not isinstance(value, str):
        try:
            return list(value)
        except TypeError:
            pass
    return [value]
def program_is_executable(program):
    """
    Check if a program exists and is executable.

    Follows https://stackoverflow.com/questions/377017/ : if ``program`` contains
    a directory component (e.g. ``"./tool"`` or ``"/usr/bin/tool"``) that path
    is checked directly; otherwise every directory listed in the ``PATH``
    environment variable is searched.

    Parameters
    ----------
    program: str or Path
        Program name or path.

    Returns
    -------
    Path or None
        Path to the executable if found, None otherwise (also when PATH is unset).
    """
    # Split the original argument: converting to Path first would normalize
    # "./tool" to "tool" and lose the explicit directory component.
    directory, _ = os.path.split(os.fspath(program))
    program = Path(program)

    def is_exe(fpath):
        return fpath.is_file() and os.access(fpath, os.X_OK)

    if directory:
        return program if is_exe(program) else None
    try:
        search_path = os.environ["PATH"]
    except KeyError:
        _logger.warning("PATH environment variable is not set.")
        return None
    for path in search_path.split(os.pathsep):
        exe_file = Path(path) / program
        if is_exe(exe_file):
            return exe_file
    return None
def _search_directory(directory, filename, rec=False):
if not Path(directory).exists():
_logger.debug(f"Directory {directory} does not exist")
return None
file = Path(directory).joinpath(filename)
if file.exists():
_logger.debug(f"File {filename} found in {directory}")
return file
if rec:
for subdir in Path(directory).iterdir():
if subdir.is_dir():
file = _search_directory(subdir, filename, True)
if file:
return file
return None
def find_file(name, loc):
    """
    Search for a file and return its full path.

    The current working directory is searched first (non-recursively), then
    each given location recursively.

    Parameters
    ----------
    name: str
        File name to be searched for.
    loc: Path or list of Path
        Location(s) of where to search for the file.

    Returns
    -------
    Path
        Full path of the file found.

    Raises
    ------
    FileNotFoundError
        If the desired file is not found.
    """
    locations = loc if isinstance(loc, list) else [loc]
    searches = [(".", False)] + [(location, True) for location in locations]
    for directory, recursive in searches:
        found = _search_directory(directory, name, recursive)
        if found:
            return found
    msg = f"File {name} could not be found in {locations}"
    _logger.error(msg)
    raise FileNotFoundError(msg)
[docs]
def get_log_excerpt(log_file, n_last_lines=30):
    """
    Build an excerpt with the last lines of a log file, framed by a header and footer.

    Parameters
    ----------
    log_file: str or Path
        Log file to get the excerpt from.
    n_last_lines: int
        Number of last lines of the file to include (passed to collect_final_lines).

    Returns
    -------
    str
        Excerpt from log file with header/footer.
    """
    final_lines = collect_final_lines(log_file, n_last_lines)
    sections = [
        "\n\nRuntime error - See below the relevant part of the log/err file.\n\n",
        f"{log_file}\n",
        "====================================================================\n\n",
        f"{final_lines}\n\n",
        "====================================================================\n",
    ]
    return "".join(sections)
def get_file_age(file_path):
    """
    Get the age of a file in minutes since its last modification.

    Parameters
    ----------
    file_path: str or Path
        Path of the file.

    Returns
    -------
    float
        Time elapsed since the last modification, in minutes.

    Raises
    ------
    FileNotFoundError
        If file_path does not exist or is not a regular file.
    """
    path = Path(file_path)
    if not path.is_file():
        raise FileNotFoundError(f"'{file_path}' does not exist or is not a file.")
    modification_time = path.stat().st_mtime
    current_time = time.time()
    # Convert seconds to minutes.
    return (current_time - modification_time) / 60
def _process_dict_keys(input_dict, case_func):
"""
Process dictionary keys recursively.
Parameters
----------
input_dict: dict
Dictionary to be processed.
case_func: function
Function to change case of keys (e.g., str.lower, str.upper).
Returns
-------
dict
Processed dictionary with keys changed.
"""
output_dict = {}
for key, value in input_dict.items():
processed_key = case_func(key)
if isinstance(value, dict):
output_dict[processed_key] = _process_dict_keys(value, case_func)
elif isinstance(value, list):
processed_list = [
_process_dict_keys(item, case_func) if isinstance(item, dict) else item
for item in value
]
output_dict[processed_key] = processed_list
else:
output_dict[processed_key] = value
return output_dict
[docs]
def change_dict_keys_case(data_dict, lower_case=True):
    """
    Convert all keys of a dictionary to lower or upper case, recursively.

    Parameters
    ----------
    data_dict: dict
        Dictionary to be converted.
    lower_case: bool
        Convert keys to lower case if True, to upper case if False.

    Returns
    -------
    dict
        New dictionary with keys converted.

    Raises
    ------
    AttributeError
        If the input is not a proper dictionary.
    """
    converter = str.lower if lower_case else str.upper
    try:
        converted = _process_dict_keys(data_dict, converter)
    except AttributeError as exc:
        _logger.error(f"Input is not a proper dictionary: {data_dict}")
        raise AttributeError from exc
    return converted
[docs]
def remove_substring_recursively_from_dict(data_dict, substring="\n"):
    """
    Remove a substring from all strings in a dictionary, recursively.

    Crawls through nested dictionaries and through list values (strings and
    dictionaries that are direct items of those lists are processed too). This
    allows, e.g., removing all newline characters from a dictionary. The
    dictionary is modified in place and also returned.

    Parameters
    ----------
    data_dict: dict
        Dictionary to be converted.
    substring: str
        Substring to be removed.

    Returns
    -------
    dict
        The same (modified) dictionary. If the input is not a dictionary
        (accessing ``items()`` raises AttributeError), this is logged at debug
        level and the input is returned unchanged; no exception is raised.
    """

    def _clean(item):
        if isinstance(item, str):
            return item.replace(substring, "")
        if isinstance(item, dict):
            return remove_substring_recursively_from_dict(item, substring)
        return item

    try:
        for key, value in data_dict.items():
            if isinstance(value, list):
                data_dict[key] = [_clean(item) for item in value]
            elif isinstance(value, (str, dict)):
                data_dict[key] = _clean(value)
    except AttributeError:
        _logger.debug(f"Input is not a dictionary: {data_dict}")
    return data_dict
[docs]
def sort_arrays(*args):
    """
    Sort several arrays by the values of the first one.

    All arrays are reordered with the same permutation, namely the one that
    sorts ``args[0]`` in ascending order. Ties in ``args[0]`` keep their original
    relative order (stable sort), so elements at the same position in different
    arrays stay paired.

    Parameters
    ----------
    *args
        Sequences (e.g. lists or numpy arrays) to be sorted; all are expected to
        have the same length as the first one.

    Returns
    -------
    list
        List of sorted arrays (each converted to a list).
    """
    if len(args) == 0:
        return []
    order_array = args[0]
    # One shared permutation for all arrays: sorting (key, value) pairs per array
    # would break ties differently for each array and mix up the pairing.
    permutation = sorted(range(len(order_array)), key=lambda index: order_array[index])
    return [[arg[index] for index in permutation] for arg in args]
def user_confirm():
    """
    Ask the user on the command line to confirm with y or n (case-insensitive).

    Returns
    -------
    bool:
        True if the answer is Y/y; False for any other answer or if the input
        stream is closed (EOF).
    """
    try:
        reply = input("Is this OK? [y/n]")
    except EOFError:
        return False
    return reply.lower() == "y"
def _get_value_dtype(value):
"""
Get the data type of the given value.
Parameters
----------
Value to determine the data type.
Returns
-------
type:
Data type of the value.
"""
if isinstance(value, (list | np.ndarray)):
value = np.array(value)
return value.dtype
return type(value)
def validate_data_type(reference_dtype, value=None, dtype=None, allow_subtypes=True):
    """
    Validate the data type of a value or type object against a reference data type.

    Either the exact data type is required (``allow_subtypes=False``), or subtypes
    are accepted (e.g. uint is accepted for int, and integers are accepted for
    floats). Labels used in the model parameter database that are not numpy
    types ("string", "str", "file", "boolean", "bool") are taken into account;
    such labels never match numeric types.

    Parameters
    ----------
    reference_dtype: str
        Reference data type to be checked against. In strict mode it is passed
        to ``np.issubdtype`` unchanged, so it must be a type numpy understands.
    value: any, optional
        Value to be checked (if dtype is None).
    dtype: type, optional
        Type object to be checked (if value is None).
    allow_subtypes: bool, optional
        If True, allow subtypes to be accepted.

    Returns
    -------
    bool:
        True if the data type is valid.

    Raises
    ------
    ValueError
        If neither value nor dtype is given.
    """
    if value is None and dtype is None:
        raise ValueError("Either value or dtype must be given.")
    if value is not None and dtype is None:
        dtype = _get_value_dtype(value)

    # Strict comparison
    if not allow_subtypes:
        return np.issubdtype(dtype, reference_dtype)

    if (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, "object")) and reference_dtype in (
        "string",
        "str",
        "file",
    ):
        return True
    if np.issubdtype(dtype, np.bool_) and reference_dtype in ("boolean", "bool"):
        return True

    def _reference_is(numpy_type):
        # reference_dtype may be a non-numpy label such as "file" or "boolean";
        # np.issubdtype raises TypeError for those, which simply means "no match".
        try:
            return np.issubdtype(reference_dtype, numpy_type)
        except TypeError:
            return False

    # Allow any sub-type of integer or float for success
    if np.issubdtype(dtype, np.integer) and (
        _reference_is(np.integer) or _reference_is(np.floating)
    ):
        return True
    if np.issubdtype(dtype, np.floating) and _reference_is(np.floating):
        return True
    return False
def convert_list_to_string(data, comma_separated=False, shorten_list=False, collapse_list=False):
"""
Convert arrays to string (if required).
Parameters
----------
data: object
Object of data to convert (e.g., double or list)
comma_separated: bool
If True, returns elements as a comma-separated string (default is space-separated).
shorten_list: bool
If True and all elements in the list are identical, returns a summary string
like "all: value". This is useful to make the configuration files more readable.
collapse_list: bool
If True and all elements in the list are identical, returns a single value
instead of the entire list.
Returns
-------
object or str:
Converted data as string (if required)
"""
if data is None or not isinstance(data, list | np.ndarray):
return data
if shorten_list and len(data) > 10 and all(np.isclose(item, data[0]) for item in data):
return f"all: {data[0]}"
if collapse_list and len(sorted(set(data))) == 1:
data = [data[0]]
if comma_separated:
return ", ".join(str(item) for item in data)
return " ".join(str(item) for item in data)
def convert_string_to_list(data_string, is_float=True):
    """
    Convert a string (as used e.g. in sim_telarray) to a list.

    First tries to convert every whitespace-separated token to a number. If that
    fails, comma-separated strings are split at commas (items stripped), and
    space-separated strings are split at whitespace.

    Parameters
    ----------
    data_string: object
        String to be converted.
    is_float: bool
        Convert numeric tokens to float if True, to int otherwise.

    Returns
    -------
    list, str
        Converted data from string (if required).
        Return data_string if conversion fails.
    """
    converter = float if is_float else int
    try:
        return [converter(token) for token in data_string.split()]
    except ValueError:
        pass
    if "," in data_string:
        return [part.strip() for part in data_string.split(",")]
    if " " in data_string:
        return data_string.split()
    return data_string
def _load_yaml_using_astropy(file):
    """
    Parse an already opened yaml file with astropy's yaml loader.

    The file is rewound to its start before parsing. Callers in this module
    use this as a fallback after ``yaml.safe_load`` has failed on the same
    file object.

    Parameters
    ----------
    file: file
        Open file object to be loaded.

    Returns
    -------
    dict
        Dictionary containing the file content.
    """
    import astropy.io.misc.yaml as astropy_yaml  # pylint: disable=import-outside-toplevel

    file.seek(0)
    loaded = astropy_yaml.load(file)
    return loaded
def read_file_encoded_in_utf_or_latin(file_name):
    """
    Read a file encoded in UTF-8 or Latin-1.

    UTF-8 is tried first; if decoding fails, the file is read again as Latin-1.

    Parameters
    ----------
    file_name: str
        Name of the file to be read.

    Returns
    -------
    list
        List of lines read from the file.

    Raises
    ------
    UnicodeDecodeError
        If the file cannot be decoded using UTF-8 or Latin-1.
    """
    try:
        with open(file_name, encoding="utf-8") as file:
            return file.readlines()
    except UnicodeDecodeError:
        # Use the module logger: the module-level logging.debug() would
        # implicitly configure the root logger if it has no handlers.
        _logger.debug("Unable to decode file using UTF-8. Trying Latin-1.")
    try:
        with open(file_name, encoding="latin-1") as file:
            return file.readlines()
    except UnicodeDecodeError as exc:
        # Latin-1 assigns a character to every byte value, so this is not
        # expected to trigger. UnicodeDecodeError needs all five constructor
        # arguments; a single message argument would itself raise TypeError.
        raise UnicodeDecodeError(
            exc.encoding,
            exc.object,
            exc.start,
            exc.end,
            "Unable to decode file using UTF-8 or Latin-1.",
        ) from exc