Source code for caloutils.variables.calculate_variables

from typing import Optional

from torch_geometric.data import Batch

from ..processing import shift_multi_hits, shift_sum_multi_hits, sum_multi_hits
from ..processing.batch_to_Exyz import batch_to_Exyz
from .analyze_layers import analyze_layers
from .energy_ratios import cyratio, sphereratio
from .first_principal_components import first_principal_components
from .response import response


def calc_vars(
    batch: Batch,
    sum_multihits: bool = True,
    shift_multihits: bool = False,
    vars: Optional[list] = None,  # name shadows the builtin; kept for API compatibility
):
    """
    Calculate the requested variables for a batch and store them on it.

    The batch is first preprocessed according to ``sum_multihits`` and
    ``shift_multihits`` and then passed through ``batch_to_Exyz``. Afterwards,
    every requested variable is computed and stored on the batch under its
    own name (e.g. ``batch["fpc"]``).

    Parameters
    ----------
    batch : Batch
        A Batch object from the PyTorch Geometric library that contains the
        point cloud representation of the events.
    sum_multihits : bool, optional
        If True, ``sum_multi_hits`` is applied to the batch (per the original
        documentation: aggregates the energies of hits in the same cell for
        each event). Defaults to True.
    shift_multihits : bool, optional
        If True, ``shift_multi_hits`` is applied to the batch (per the
        original documentation: tries to shift hits that occupy the same cell
        to neighboring, empty cells). Defaults to False. If both flags are
        True, ``shift_sum_multi_hits`` is applied instead of either single
        step.
    vars : list of str, optional
        Names of the variables to calculate. ``None`` (the default) selects
        ["voxel", "sphereratio", "cyratio", "fpc", "showershape", "response"].
        Only "sphereratio", "cyratio", "fpc", "showershape" and "response"
        trigger a calculation in this function; "voxel" and any other name
        are accepted but ignored here.

    Returns
    -------
    Batch
        The batch produced by the preprocessing steps, with each calculated
        variable stored under its name.
    """
    # Resolve the default inside the call so that no mutable object is shared
    # between invocations.
    if vars is None:
        vars = [
            "voxel",
            "sphereratio",
            "cyratio",
            "fpc",
            "showershape",
            "response",
        ]

    # Multi-hit preprocessing: at most one of the three variants is applied.
    if sum_multihits and shift_multihits:
        batch = shift_sum_multi_hits(batch)
    elif sum_multihits:
        batch = sum_multi_hits(batch)
    elif shift_multihits:
        batch = shift_multi_hits(batch)

    # NOTE(review): applied unconditionally, after the multi-hit handling;
    # confirm against the upstream file, whose indentation was lost.
    batch = batch_to_Exyz(batch)

    # Ordered (name, function) pairs; the order matches the original sequence
    # of calculations in case a later one reads an earlier result.
    calculators = (
        ("sphereratio", sphereratio),
        ("cyratio", cyratio),
        ("fpc", first_principal_components),
        ("showershape", analyze_layers),
        ("response", response),
    )
    for name, compute in calculators:
        if name in vars:
            batch[name] = compute(batch)

    return batch