from math import prod
import torch
from torch_geometric.data import Batch
from torch_scatter import scatter_add
from .. import calorimeter
from .utils import _construct_global_cellevent_idx
def pc_to_voxel(batch: Batch) -> torch.Tensor:
    """
    Converts a pytorch geometric batch of point clouds into a torch batch of 3D voxel grids.

    Parameters
    ----------
    batch : Batch
        A Batch object from the PyTorch Geometric library that contains the point cloud
        representation of the events. ``batch.x`` holds one row per hit: column 0 is
        the hit energy and the next ``len(calorimeter.dims)`` columns are the cell
        coordinates. Coordinates are truncated to integers, and any further columns
        are ignored. ``batch.batch`` assigns every hit to its event.

    Returns
    -------
    torch.Tensor
        A tensor of shape (batch_size, num_z, num_alpha, num_r), where each element
        is the summed energy of all hits falling into the corresponding voxel.
        ``batch_size`` is ``batch.num_graphs``, so events without any hits
        (including trailing ones) still get an all-zero grid.
    """
    dims = calorimeter.dims
    # num_graphs also counts events that contain no hits. Deriving the count
    # from batch.batch[-1] would drop trailing empty events and would raise
    # IndexError on a batch with no hits at all.
    batch_size = int(batch.num_graphs)
    x = batch.x

    energy = x[:, 0]
    # Truncating cast to integer cell indices (same truncation as the former
    # .int()); int64 is the canonical dtype for index tensors.
    cell_coords = x[:, 1 : 1 + len(dims)].long()

    # Lookup table from (event, *cell) to a flat output position.
    # Presumably shape (batch_size, *dims) with values in
    # [0, batch_size * prod(dims)) -- TODO confirm against
    # _construct_global_cellevent_idx.
    full_event_cell_idx = _construct_global_cellevent_idx(batch_size).to(x.device)
    scatter_index = full_event_cell_idx[(batch.batch, *cell_coords.unbind(dim=1))]

    # Sum the energies of hits that share a voxel. Positions no hit maps to
    # stay zero.
    vox = scatter_add(
        src=energy,
        index=scatter_index,
        dim_size=prod(dims) * batch_size,
    )
    return vox.reshape(batch_size, *dims)