#!/usr/bin/python3
"""Module for visualization."""
import copy
import logging
import re
from collections import OrderedDict
from pathlib import Path
import astropy.units as u
import matplotlib.pyplot as plt
from astropy.table import QTable
from cycler import cycler
from matplotlib import gridspec
__all__ = [
"get_colors",
"get_lines",
"get_markers",
"plot_1d",
"plot_hist_2d",
"plot_table",
"save_figure",
"set_style",
]
COLORS = {}
COLORS["classic"] = [
"#ba2c54",
"#5B90DC",
"#FFAB44",
"#0C9FB3",
"#57271B",
"#3B507D",
"#794D88",
"#FD6989",
"#8A978E",
"#3B507D",
"#D8153C",
"#cc9214",
]
COLORS["modified classic"] = [
"#D6088F",
"#424D9C",
"#178084",
"#AF99DA",
"#F58D46",
"#634B5B",
"#0C9FB3",
"#7C438A",
"#328cd6",
"#8D0F25",
"#8A978E",
"#ffcb3d",
]
COLORS["autumn"] = [
"#A9434D",
"#4E615D",
"#3C8DAB",
"#A4657A",
"#424D9C",
"#DC575A",
"#1D2D38",
"#634B5B",
"#56276D",
"#577580",
"#134663",
"#196096",
]
COLORS["purples"] = [
"#a57bb7",
"#343D80",
"#EA60BF",
"#B7308E",
"#E099C3",
"#7C438A",
"#AF99DA",
"#4D428E",
"#56276D",
"#CC4B93",
"#DC4E76",
"#5C4AE4",
]
COLORS["greens"] = [
"#268F92",
"#abc14d",
"#8A978E",
"#0C9FB3",
"#BDA962",
"#B0CB9E",
"#769168",
"#5E93A5",
"#178084",
"#B7BBAD",
"#163317",
"#76A63F",
]
COLORS["default"] = COLORS["classic"]
MARKERS = ["o", "s", "v", "^", "*", "P", "d", "X", "p", "<", ">", "h"]
LINES = [
(0, ()), # solid
(0, (1, 1)), # densely dotted
(0, (3, 1, 1, 1)), # densely dashdotted
(0, (5, 5)), # dashed
(0, (3, 1, 1, 1, 1, 1)), # densely dashdotdotted
(0, (5, 1)), # densely dashed
(0, (1, 5)), # dotted
(0, (3, 5, 1, 5)), # dashdotted
(0, (3, 5, 1, 5, 1, 5)), # dashdotdotted
(0, (5, 10)), # loosely dashed
(0, (1, 10)), # loosely dotted
(0, (3, 10, 1, 10)), # loosely dashdotted
]
_logger = logging.getLogger(__name__)
def _add_unit(title, array):
"""
Add a unit to "title" (presumably an axis title).
The unit is extracted from the unit field of the array, in case array is an astropy quantity.
If a unit is found, it is added to title in the form [unit]. If a unit already is present in
title (in the same form), a warning is printed and no unit is added. The function assumes
array not to be empty and returns the modified title.
Parameters
----------
title: str
array: astropy.Quantity
Returns
-------
str
Title with units.
"""
if title and "[" in title and "]" in title:
_logger.warning(
"Tried to add a unit from astropy.unit, "
"but axis already has an explicit unit. Left axis title as is."
)
return title
if not isinstance(array, u.Quantity) or not str(array[0].unit):
return title
unit = str(array[0].unit)
unit_str = f" [{unit}]"
if re.search(r"\d", unit_str):
unit_str = re.sub(r"(\d)", r"^\1", unit_str)
unit_str = unit_str.replace("[", r"[$").replace("]", r"$]")
return f"{title}{unit_str}"
[docs]
def set_style(palette="default", big_plot=False):
"""
Set the plotting style to homogenize plot style across the framework.
The function receives the colour palette name and whether it is a big plot or not.\
The latter sets the fonts and marker to be bigger in case it is a big plot. The available \
colour palettes are as follows:
- classic (default): A classic colorful palette with strong colors and contrast.
- modified classic: Similar to the classic, with slightly different colors.
- autumn: A slightly darker autumn style colour palette.
- purples: A pseudo sequential purple colour palette (not great for contrast).
- greens: A pseudo sequential green colour palette (not great for contrast).
To use the function, simply call it before plotting anything. The function is made public, so \
that it can be used outside the visualize module. However, it is highly recommended to create\
plots only through the visualize module.
Parameters
----------
palette: str
Colour palette.
big_plot: bool
Flag to set fonts and marker bigger. If True, it sets them bigger.
Raises
------
KeyError
if provided palette does not exist.
"""
if palette not in COLORS:
raise KeyError(f"palette must be one of {', '.join(COLORS)}")
fontsize = {"default": 17, "big_plot": 30}
markersize = {"default": 6, "big_plot": 18}
plot_size = "big_plot" if big_plot else "default"
plt.rc("lines", linewidth=2, markersize=markersize[plot_size])
plt.rc(
"axes",
prop_cycle=(
cycler(color=COLORS[palette]) + cycler(linestyle=LINES) + cycler(marker=MARKERS)
),
titlesize=fontsize[plot_size],
labelsize=fontsize[plot_size],
labelpad=5,
grid=True,
axisbelow=True,
)
plt.rc("xtick", labelsize=fontsize[plot_size])
plt.rc("ytick", labelsize=fontsize[plot_size])
plt.rc("legend", loc="best", shadow=False, fontsize="medium")
plt.rc("font", family="serif", size=fontsize[plot_size])
[docs]
def get_colors(palette="default"):
"""
Get the colour list of the palette requested.
If no palette is provided, the default is returned.
Parameters
----------
palette: str
Colour palette.
Returns
-------
list
Colour list.
Raises
------
KeyError
if provided palette does not exist.
"""
if palette not in COLORS:
raise KeyError(f"palette must be one of {', '.join(COLORS)}")
return COLORS[palette]
[docs]
def get_markers():
"""
Get the marker list used in this module.
Returns
-------
list
List with markers.
"""
return MARKERS
[docs]
def get_lines():
"""
Get the line style list used in this module.
Returns
-------
list
List with line styles.
"""
return LINES
def filter_plot_kwargs(kwargs):
"""Filter out kwargs that are not valid for plt.plot."""
valid_keys = {
"color",
"linestyle",
"linewidth",
"marker",
"markersize",
"markerfacecolor",
"markeredgecolor",
"markeredgewidth",
"alpha",
"label",
"zorder",
"dashes",
"gapcolor",
"solid_capstyle",
"solid_joinstyle",
"dash_capstyle",
"dash_joinstyle",
"antialiased",
}
return {k: v for k, v in kwargs.items() if k in valid_keys}
[docs]
def plot_1d(data, **kwargs):
"""
Produce a high contrast one dimensional plot from multiple data sets.
A ratio plot can be added at the bottom to allow easy comparison. Additional options,
such as plot title, plot legend, etc., are given in kwargs. Any option that can be
changed after plotting (e.g., axes limits, log scale, etc.) should be done using the
returned plt instance.
Parameters
----------
data: numpy structured array or a dictionary of structured arrays
Each structured array has two columns, the first is the x-axis and the second the y-axis.
The titles of the columns are set as the axes titles.
The labels of each dataset set are given in the dictionary keys
and will be used in the legend.
**kwargs:
* palette: string
Choose a colour palette (see set_style for additional information).
* title: string
Set a plot title.
* no_legend: bool
Do not print a legend for the plot.
* big_plot: bool
Increase marker and font sizes (like in a wide light curve).
* no_markers: bool
Do not print markers.
* empty_markers: bool
Print empty (hollow) markers
* plot_ratio: bool
Add a ratio plot at the bottom. The first entry in the data dictionary
is used as the reference for the ratio.
If data dictionary is not an OrderedDict, the reference will be random.
* plot_difference: bool
Add a difference plot at the bottom. The first entry in the data dictionary
is used as the reference for the difference.
If data dictionary is not an OrderedDict, the reference will be random.
* Any additional kwargs for plt.plot
Returns
-------
pyplot.figure
Instance of pyplot.figure in which the plot was produced
Raises
------
ValueError
if asked to plot a ratio or difference with just one set of data
"""
kwargs = handle_kwargs(kwargs)
data_dict, plot_ratio, plot_difference = prepare_data(data, kwargs)
fig, ax1, gs = setup_plot(kwargs, plot_ratio, plot_difference)
plot_args = filter_plot_kwargs(kwargs)
plot_main_data(data_dict, kwargs, plot_args)
if plot_ratio or plot_difference:
plot_ratio_difference(ax1, data_dict, plot_ratio, gs, plot_args)
if not (plot_ratio or plot_difference):
plt.tight_layout()
return fig
def handle_kwargs(kwargs):
"""Extract and handle keyword arguments."""
kwargs_defaults = {
"palette": "default",
"big_plot": False,
"title": "",
"no_legend": False,
"no_markers": False,
"empty_markers": False,
"plot_ratio": False,
"plot_difference": False,
"xscale": "linear",
"yscale": "linear",
"xlim": (None, None),
"ylim": (None, None),
"xtitle": None,
"ytitle": None,
}
for key, default in kwargs_defaults.items():
kwargs[key] = kwargs.pop(key, default)
if kwargs["no_markers"]:
kwargs["marker"] = "None"
kwargs["linewidth"] = 4
if kwargs["empty_markers"]:
kwargs["markerfacecolor"] = "None"
return kwargs
def prepare_data(data, kwargs):
"""Prepare data for plotting."""
if not isinstance(data, dict):
data_dict = {"_default": data}
else:
data_dict = data
if (kwargs["plot_ratio"] or kwargs["plot_difference"]) and len(data_dict) < 2:
raise ValueError("Asked to plot a ratio or difference with just one set of data")
return data_dict, kwargs["plot_ratio"], kwargs["plot_difference"]
def setup_plot(kwargs, plot_ratio, plot_difference):
"""Set up the plot style, figure, and gridspec."""
set_style(kwargs["palette"], kwargs["big_plot"])
if plot_ratio or plot_difference:
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
fig = plt.figure(figsize=(8, 8))
gs.update(hspace=0.02)
else:
gs = gridspec.GridSpec(1, 1)
fig = plt.figure(figsize=(8, 6))
plt.subplot(gs[0])
ax1 = plt.gca()
return fig, ax1, gs
def _plot_error_plots(kwargs, data_now, x_col, y_col, x_err_col, y_err_col, color):
"""Plot error plots."""
if kwargs.get("error_type") == "fill_between" and y_err_col:
plt.fill_between(
data_now[x_col],
data_now[y_col] - data_now[y_err_col],
data_now[y_col] + data_now[y_err_col],
alpha=0.2,
color=color,
)
if kwargs.get("error_type") == "errorbar" and (x_err_col or y_err_col):
plt.errorbar(
data_now[x_col],
data_now[y_col],
xerr=data_now[x_err_col] if x_err_col else None,
yerr=data_now[y_err_col] if y_err_col else None,
fmt=".",
color=color,
)
def _get_data_columns(data_now):
"""Return data columns depending on availability."""
columns = data_now.dtype.names
assert len(columns) >= 2, "Input array must have at least two columns with titles."
x_col, y_col = columns[:2]
if len(columns) == 3:
x_err_col = None
y_err_col = columns[2]
elif len(columns) == 4:
x_err_col = columns[2]
y_err_col = columns[3]
else:
x_err_col = None
y_err_col = None
return x_col, y_col, x_err_col, y_err_col
def plot_main_data(data_dict, kwargs, plot_args):
"""Plot the main data."""
for label, data_now in data_dict.items():
x_col, y_col, x_err_col, y_err_col = _get_data_columns(data_now)
x_title = kwargs.get("xtitle", x_col)
y_title = kwargs.get("ytitle", y_col)
x_title_unit = _add_unit(x_title, data_now[x_col])
y_title_unit = _add_unit(y_title, data_now[y_col])
(line,) = plt.plot(data_now[x_col], data_now[y_col], label=label, **plot_args)
color = line.get_color()
_plot_error_plots(kwargs, data_now, x_col, y_col, x_err_col, y_err_col, color)
plt.xscale(kwargs["xscale"])
plt.yscale(kwargs["yscale"])
plt.xlim(kwargs["xlim"])
plt.ylim(kwargs["ylim"])
plt.ylabel(y_title_unit)
if not (kwargs["plot_ratio"] or kwargs["plot_difference"]):
plt.xlabel(x_title_unit)
if kwargs["title"]:
plt.title(kwargs["title"], y=1.02)
if "_default" not in data_dict and not kwargs["no_legend"]:
plt.legend()
def plot_ratio_difference(ax1, data_dict, plot_ratio, gs, plot_args):
"""Plot the ratio or difference plot."""
plt.subplot(gs[1], sharex=ax1)
plt.plot([], []) # Advance cycler for consistent colors/styles
data_ref_name = next(iter(data_dict))
for label, data_now in data_dict.items():
if label == data_ref_name:
continue
x_title, y_title = data_now.dtype.names[0], data_now.dtype.names[1]
x_title_unit = _add_unit(x_title, data_now[x_title])
if plot_ratio:
y_values = data_now[y_title] / data_dict[data_ref_name][y_title]
y_title_ratio = f"Ratio to {data_ref_name}"
else:
y_values = data_now[y_title] - data_dict[data_ref_name][y_title]
y_title_ratio = f"Difference to {data_ref_name}"
plt.plot(data_now[x_title], y_values, **plot_args)
plt.xlabel(x_title_unit)
if len(y_title_ratio) > 20 and plot_ratio:
y_title_ratio = "Ratio"
plt.ylabel(y_title_ratio)
ylim = plt.gca().get_ylim()
nbins = min(int((ylim[1] - ylim[0]) / 0.05 + 1), 6)
plt.locator_params(axis="y", nbins=nbins)
[docs]
def plot_table(table, y_title, **kwargs):
"""
Produce a high contrast one dimensional plot from the data in an astropy.Table.
A ratio plot can be added at the bottom to allow easy comparison. Additional options, such
as plot title, plot legend, etc., are given in kwargs. Any option that can be changed after
plotting (e.g., axes limits, log scale, etc.) should be done using the returned plt instance.
Parameters
----------
table: astropy.Table or astropy.QTable
The first column of the table is the x-axis and the second column is the y-axis. Any \
additional columns will be treated as additional data to plot. The column titles are \
used in the legend (except for the first column).
y_title: str
The y-axis title.
**kwargs:
* palette: choose a colour palette (see set_style for additional information).
* title: set a plot title.
* no_legend: do not print a legend for the plot.
* big_plot: increase marker and font sizes (like in a wide light curve).
* no_markers: do not print markers.
* empty_markers: print empty (hollow) markers
* plot_ratio: bool
Add a ratio plot at the bottom. The first entry in the data dictionary is used as the \
reference for the ratio. If data dictionary is not an OrderedDict, the reference will be\
random.
* plot_difference: bool
Add a difference plot at the bottom. The first entry in the data dictionary is used as \
the reference for the difference. If data dictionary is not an OrderedDict, the reference\
will be random.
* Any additional kwargs for plt.plot
Returns
-------
pyplot.fig
Instance of pyplot.fig.
Raises
------
ValueError
if table has less than two columns.
"""
if len(table.keys()) < 2:
raise ValueError("Table has to have at least two columns")
x_axis = table.keys()[0]
data_dict = OrderedDict()
for column in table.keys()[1:]:
data_dict[column] = QTable([table[x_axis], table[column]], names=[x_axis, y_title])
return plot_1d(data_dict, **kwargs)
[docs]
def plot_hist_2d(data, **kwargs):
"""
Produce a two dimensional histogram plot.
Any option that can be changed after plotting (e.g., axes limits, log scale, etc.)
should be done using the returned plt instance.
Parameters
----------
data: numpy structured array
The columns of the structured array are used as the x-axis and y-axis titles.
**kwargs:
* title: set a plot title.
* Any additional kwargs for plt.hist2d
Returns
-------
pyplot.figure
Instance of pyplot.figure in which the plot was produced.
"""
cmap = plt.cm.gist_heat_r
if "title" in kwargs:
title = kwargs["title"]
kwargs.pop("title", None)
else:
title = ""
# Set default style since the usual options do not affect 2d plots (for now).
set_style()
gs = gridspec.GridSpec(1, 1)
fig = plt.figure(figsize=(8, 6))
##########################################################################################
# Plot the data
##########################################################################################
plt.subplot(gs[0])
assert len(data.dtype.names) == 2, "Input array must have two columns with titles."
x_title, y_title = data.dtype.names[0], data.dtype.names[1]
x_title_unit = _add_unit(x_title, data[x_title])
y_title_unit = _add_unit(y_title, data[y_title])
plt.hist2d(data[x_title], data[y_title], cmap=cmap, **kwargs)
plt.xlabel(x_title_unit)
plt.ylabel(y_title_unit)
plt.gca().set_aspect("equal", adjustable="datalim")
if len(title) > 0:
plt.title(title, y=1.02)
fig.tight_layout()
return fig
def plot_simtel_ctapipe(filename, cleaning_args, distance, return_cleaned=False):
"""
Read in a sim_telarray file and plots reconstructed photoelectrons via ctapipe.
Parameters
----------
filename : str
Path to the sim_telarray file.
cleaning_args : tuple, optional
Cleaning parameters as (boundary_thresh, picture_thresh, min_number_picture_neighbors).
distance : astropy Quantity, optional
Distance to the target.
return_cleaned : bool, optional
If True, apply cleaning to the image.
Returns
-------
fig : matplotlib.figure.Figure
The matplotlib figure containing the plot.
"""
# pylint:disable=import-outside-toplevel
import numpy as np
from ctapipe.calib import CameraCalibrator
from ctapipe.image import tailcuts_clean
from ctapipe.io import EventSource
from ctapipe.visualization import CameraDisplay
default_cleaning_levels = {
"CHEC": (2, 4, 2),
"LSTCam": (3.5, 7, 2),
"FlashCam": (3.5, 7, 2),
"NectarCam": (4, 8, 2),
}
source = EventSource(filename, max_events=1)
event = None
events = [copy.deepcopy(event) for event in source]
if len(events) > 1:
event = events[-1]
else:
try:
event = events[0]
except IndexError:
event = events
tel_id = sorted(event.r1.tel.keys())[0]
calib = CameraCalibrator(subarray=source.subarray)
calib(event)
geometry = source.subarray.tel[1].camera.geometry
image = event.dl1.tel[tel_id].image
cleaned = image.copy()
if return_cleaned:
if cleaning_args is None:
boundary, picture, min_neighbors = default_cleaning_levels[geometry.name]
else:
boundary, picture, min_neighbors = cleaning_args
mask = tailcuts_clean(
geometry,
image,
picture_thresh=picture,
boundary_thresh=boundary,
min_number_picture_neighbors=min_neighbors,
)
cleaned[~mask] = 0
fig, ax = plt.subplots(dpi=300)
title = f"CT{tel_id}, run {event.index.obs_id} event {event.index.event_id}"
disp = CameraDisplay(geometry, image=cleaned, norm="symlog", ax=ax)
disp.cmap = "RdBu_r"
disp.add_colorbar(fraction=0.02, pad=-0.1)
disp.set_limits_percent(100)
ax.set_title(title, pad=20)
ax.annotate(
f"tel type: {source.subarray.tel[1].type.name}\n"
f"optics: {source.subarray.tel[1].optics.name}\n"
f"camera: {source.subarray.tel[1].camera_name}\n"
f"distance: {distance.to(u.m)}",
xy=(0, 0),
xytext=(0.1, 1),
xycoords="axes fraction",
va="top",
size=7,
)
ax.annotate(
f"dl1 image,\ntotal $p.e._{{reco}}$: {np.round(np.sum(image))}\n",
xy=(0, 0),
xytext=(0.75, 1),
xycoords="axes fraction",
va="top",
ha="left",
size=7,
)
ax.set_axis_off()
fig.tight_layout()
return fig