from typing import Optional
import torch
from torch_geometric.data import Batch
from torch_scatter import scatter_add
from .. import calorimeter
from .shift_multi_hits import shift_multi_hits
from .utils import fix_slice_dict_nodeattr, ptr_from_batchidx, scatter_sort
def shift_sum_multi_hits(batch, forbid_dublicates=False):
    """
    Run `shift_multi_hits` followed by `sum_multi_hits` on one batch.

    The global cell index is computed once here and handed through both
    steps, so it does not have to be recomputed by `sum_multi_hits`.
    According to the names and the original notes, `shift_multi_hits` moves
    hits that share a cell to empty neighbouring cells, and `sum_multi_hits`
    then adds up the energy of the hits that still share a cell.

    Parameters
    ----------
    batch : Batch
        PyTorch Geometric batch holding the point cloud of the events.
        ``batch.x`` is read as ``[energy, coordinates...]`` per hit (column 0
        is left out when the cell index is built, the remaining columns are
        cast to integers). ``batch.batch`` maps every hit to its event and
        must be non-decreasing.
    forbid_dublicates : bool, optional
        Forwarded unchanged to `sum_multi_hits`, which asserts that no cell
        is occupied more than once when this is True. Defaults to False.

    Returns
    -------
    batch : Batch
        The batch returned by `sum_multi_hits`.

    Raises
    ------
    AssertionError
        If ``batch.batch`` is not sorted in non-decreasing order.
    """
    event_of_hit = batch.batch
    # Hits have to be grouped by event before any per-event indexing is done.
    assert (event_of_hit.diff() >= 0).all()

    cell_coords = batch.x[:, 1:].long()
    cell_idx = calorimeter.globalidx_from_pos(cell_coords, event_of_hit)

    # The shift step receives a copy of the index and returns the batch
    # together with the index that belongs to it afterwards.
    batch, cell_idx = shift_multi_hits(
        batch, cell_idx.clone(), return_globalidx=True
    )
    return sum_multi_hits(batch, cell_idx, forbid_dublicates)
def sum_multi_hits(
    batch: Batch,
    globalidx: Optional[torch.Tensor] = None,
    forbid_dublicates: bool = False,
):
    """
    Sum the energy of hits that occupy the same cell of the same event.

    The hits of every event are sorted by their global cell index, hits with
    an identical index are merged into a single hit carrying the summed
    energy, and the node attributes of the batch are overwritten with the
    reduced set. The batch is modified in place: ``batch.x``,
    ``batch.batch``, ``batch.ptr``, ``batch.n_pointsv`` and
    ``batch.n_multihit`` are (re)assigned.

    Parameters
    ----------
    batch : Batch
        PyTorch Geometric batch holding the point cloud of the events.
        Column 0 of ``batch.x`` is used as the hit energy, the remaining
        columns as integer cell coordinates. ``batch.batch`` maps every hit
        to its event.
    globalidx : torch.Tensor, optional
        One cell index per hit. If omitted it is computed with
        ``calorimeter.globalidx_from_pos`` from the coordinates and
        ``batch.batch``.
    forbid_dublicates : bool, optional
        If True, assert that no cell is occupied more than once, i.e. that
        the reduction leaves the (sorted) hits unchanged. This mode reads
        ``batch.n_pointsv``, so that attribute has to exist. Defaults to
        False.

    Returns
    -------
    batch : Batch
        The same batch object with merged hits, moved to the device that
        ``batch.x`` was on at entry.

    Raises
    ------
    AssertionError
        If the sort moves a hit to another event, or if
        ``forbid_dublicates`` is True and a duplicate is found.
    """
    dev = batch.x.device
    batchidx = batch.batch
    if globalidx is None:
        globalidx = calorimeter.globalidx_from_pos(
            batch.x[:, 1:].long(), batchidx
        )

    # Sort the hits by cell index; index_perm is the matching row permutation.
    # NOTE(review): the reduction below relies on the sorted index being
    # ascending over the whole batch (equal cells contiguous, events in
    # order) -- confirm against scatter_sort / globalidx_from_pos.
    globalidx, index_perm = scatter_sort(globalidx, batchidx)
    batch.x = batch.x[index_perm]
    hitE = batch.x[:, 0]
    pos = batch.x[:, 1:].long()
    # The permutation may only reorder hits inside their own event.
    assert (batchidx[index_perm] == batchidx).all()

    # unique_cells_idx: for every hit, the running number of its cell.
    # counts: how many hits fall into each distinct cell.
    _, unique_cells_idx, counts = torch.unique(
        globalidx, return_inverse=True, return_counts=True
    )
    if forbid_dublicates:
        assert (counts - 1 == 0).all()

    # Merge: add up the energies per cell and keep one row per cell. With the
    # hits sorted, cumsum(counts) - 1 is the row of the last hit of each cell.
    hitE_new = scatter_add(hitE, unique_cells_idx)
    sel_new_idx = counts.cumsum(-1) - 1
    if forbid_dublicates:
        # Create the reference range directly on the target device instead of
        # building it on the CPU and copying it over.
        assert (sel_new_idx == torch.arange(len(batch.x), device=dev)).all()

    batchidx_new = batchidx[sel_new_idx]
    pos_new = pos[sel_new_idx]

    # Per event: number of hits that were merged away.
    n_multihit = scatter_add(counts - 1, batchidx_new)
    if forbid_dublicates:
        assert (n_multihit == 0).all()

    # Hits per event after merging.
    new_counts = torch.unique_consecutive(batchidx_new, return_counts=True)[1]
    x_new = torch.hstack([hitE_new.reshape(-1, 1), pos_new])

    # Invariants that are expected to hold but are not checked at runtime:
    # the per-event energy sum is unchanged, and the per-event hit count
    # drops by exactly n_multihit.

    if forbid_dublicates:
        assert (batch.batch == batchidx_new).all()
        assert (batch.n_pointsv == new_counts).all()
        # batch.x has already been reordered by index_perm above, so it is
        # compared as is; applying the permutation a second time would make
        # this check fail for every input that was not sorted beforehand.
        assert torch.allclose(batch.x, x_new)

    batch.n_multihit = n_multihit
    batch.batch = batchidx_new
    batch.x = x_new
    batch.n_pointsv = new_counts
    # The event boundaries move by the number of removed hits.
    batch.ptr = ptr_from_batchidx(batchidx_new)
    # Presumably refreshes the batch's stored slice bookkeeping for "x" so it
    # matches the new row count -- see fix_slice_dict_nodeattr.
    fix_slice_dict_nodeattr(batch, "x")

    return batch.to(dev)