Source code for caloutils.processing.batch_to_Exyz

import torch
from torch_geometric.data import Batch

from .. import calorimeter


def batch_to_Exyz(batch: Batch) -> Batch:
    """
    Transforms the cylindrical coordinate representation (z, alpha, r) in a
    batch of graphs to Cartesian coordinates (x, y, z), and calculates
    pseudo-rapidity (eta) and azimuthal angle (phi) for each node in the graph.

    Parameters
    ----------
    batch : Batch
        The batch of graphs where each node feature vector includes a
        cylindrical coordinate representation (Energy, z, alpha, r) where z is
        the long axis of the cylindrical coordinate system, alpha is the
        azimuthal angle, and r is the radial distance from the z-axis.

    Returns
    -------
    batch : Batch
        The modified batch object where an additional "xyz" attribute is added
        which includes the Cartesian coordinates for each node in the graphs.
        Additionally, "eta" (pseudo-rapidity) and "phi" (azimuthal angle)
        attributes are calculated and added to the batch object.

    Note
    ----
    Pseudo-rapidity (eta) and azimuthal angle (phi) are commonly used in
    particle physics to describe the direction of a particle relative to the
    beam axis (z-axis in this context).
    """
    zalphar: torch.Tensor = batch.x[..., [1, 2, 3]]
    num_alpha = calorimeter.num_alpha
    # num_z = 45
    # num_r = 9
    z, alpha, r = zalphar.T
    # shift idx by one to go from (min,max)
    # (0, 1-1/num) -> (1/num,1)
    # z = (z + 1) / num_z
    # r = (r + 1) / num_r
    alpha = (alpha + 1) / num_alpha * torch.pi * 2

    y = r * torch.cos(alpha)
    x = r * torch.sin(alpha)
    batch.xyz = torch.stack([x, y, z]).T

    #       r
    #       ^
    #     / |
    #    /  |
    #   /   |
    #  / θ) |
    # ------> z
    theta = torch.arctan(r / z)
    if (theta.abs() > torch.pi).any():
        raise RuntimeError("θ not in forward direction")
    if (theta < 0).any():
        raise RuntimeError("θ not in forward direction")
    batch.eta = -torch.log(torch.tan(theta / 2.0)).reshape(-1, 1)
    batch.phi = alpha.reshape(-1, 1)
    return batch
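
The per-node arithmetic in `batch_to_Exyz` can be checked by hand with a scalar sketch that uses only the standard library. The helper name `cylindrical_to_eta_phi` and the value `num_alpha=16` are assumptions made for illustration; the real bin count comes from `calorimeter.num_alpha`, and the real function operates on whole tensors rather than single hits.

```python
import math


def cylindrical_to_eta_phi(z, alpha_idx, r, num_alpha):
    """Scalar version of the per-node arithmetic (hypothetical helper)."""
    # Shift the 0-based bin index by one and scale it to an angle in (0, 2*pi].
    alpha = (alpha_idx + 1) / num_alpha * 2 * math.pi
    # Same axis convention as the function above: x uses sin, y uses cos.
    x = r * math.sin(alpha)
    y = r * math.cos(alpha)
    # Polar angle measured from the z-axis, then pseudo-rapidity.
    theta = math.atan(r / z)
    eta = -math.log(math.tan(theta / 2.0))
    return (x, y, z), eta, alpha


# Illustrative input: a hit at z = 1, r = 1 in alpha bin 3 of an assumed 16 bins.
xyz, eta, phi = cylindrical_to_eta_phi(z=1.0, alpha_idx=3, r=1.0, num_alpha=16)
print(xyz, eta, phi)
```

For positive `z` and `r`, the result agrees with the closed form `eta = asinh(z / r)`, which gives a quick sanity check on the tensor version as well.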