olympus.event_generation.photon_propagation.norm_flow_photons
Functions:
- make_generate_norm_flow_photons – Build a photon-generation function backed by normalizing flow models.
make_generate_norm_flow_photons
make_generate_norm_flow_photons(shape_model_path, counts_model_path, c_medium)
Build a photon-generation function backed by normalizing flow models.
Loads the shape and counts normalizing flow models from disk and returns a function (built on JIT-compiled JAX helpers) that samples detected photon arrival times for a given set of sources and modules.
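Parameters:

| Name | Type | Description |
|---|---|---|
| shape_model_path | str | Path to the pickled shape normalizing flow model. |
| counts_model_path | str | Path to the pickled photon counts model. |
| c_medium | float | Speed of light in the medium in meters per nanosecond. |

Returns:

| Name | Type | Description |
|---|---|---|
| generate_norm_flow_photons | callable | Function with signature `(module_coords, module_efficiencies, source_pos, source_dir, source_time, source_nphotons, seed) -> awkward.Array` that returns per-module detected photon arrival times. |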
Source code in olympus/event_generation/photon_propagation/norm_flow_photons.py
def _load_pickled_model(path):
    """Load and return the pickled ``(config, params)`` pair stored at ``path``.

    The file handle is closed even if unpickling fails.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load model
    files from trusted sources.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _next_bucket_size(n, base=4):
    """Return the smallest power of ``base`` that is >= ``n`` (1 for ``n <= 1``).

    Integer arithmetic is used on purpose: ``ceil(log(n) / log(base))`` can
    round an exact power of ``base`` up into the next (larger) bucket.
    """
    size = 1
    while size < n:
        size *= base
    return size


def make_generate_norm_flow_photons(shape_model_path, counts_model_path, c_medium):
    """Build a photon-generation function backed by normalizing flow models.

    Loads the shape and counts normalizing flow models from disk and returns a
    function that samples detected photon arrival times for a given set of
    sources and modules. The network evaluation and flow sampling inside it
    are JIT-compiled with JAX; the returned function itself is plain Python.

    Parameters
    ----------
    shape_model_path : str
        Path to the pickled shape normalizing flow model, stored as a
        ``(config, params)`` tuple.
    counts_model_path : str
        Path to the pickled photon counts model, stored as a
        ``(config, params)`` tuple.
    c_medium : float
        Speed of light in the medium in meters per nanosecond.

    Returns
    -------
    generate_norm_flow_photons : callable
        Function with signature
        ``(module_coords, module_efficiencies, source_pos, source_dir,
        source_time, source_nphotons, seed) -> awkward.Array``
        that returns per-module detected photon arrival times.

    Raises
    ------
    OSError
        If either model file cannot be opened.
    """
    shape_config, shape_params = _load_pickled_model(shape_model_path)
    counts_config, counts_params = _load_pickled_model(counts_model_path)

    shape_conditioner = make_shape_conditioner_fn(
        shape_config["mlp_hidden_size"],
        shape_config["mlp_num_layers"],
        shape_config["flow_num_bins"],
        shape_config["flow_num_layers"],
    )

    @jax.jit
    def apply_fn(params, x):
        # Evaluate the conditioner network to obtain per-row flow parameters.
        return shape_conditioner.apply(params, x)

    dist_builder = traf_dist_builder(
        shape_config["flow_num_layers"],
        (shape_config["flow_rmin"], shape_config["flow_rmax"]),
        return_base=True,
    )
    counts_net = make_counts_net_fn(counts_config)

    @jax.jit
    def sample_model_inner(traf_params, key):
        return sample_shape_model(dist_builder, traf_params, traf_params.shape[0], key)

    def sample_model(traf_params, key):
        # Pad the row count up to a power of 4 so the jitted sampler only sees a
        # small set of distinct input shapes (jax.jit retraces per shape). The
        # padded rows are sampled but discarded.
        n_rows = traf_params.shape[0]
        pad_len = _next_bucket_size(n_rows, base=4)
        padded = jnp.pad(traf_params, ((0, pad_len - n_rows), (0, 0)))
        return sample_model_inner(padded, key)[:n_rows]

    def generate_norm_flow_photons(
        module_coords,
        module_efficiencies,
        source_pos,
        source_dir,
        source_time,
        source_nphotons,
        seed=31337,
    ):
        """Sample detected photon arrival times for every module.

        Parameters
        ----------
        module_coords : array
            Module positions; ``shape[0]`` is the number of modules.
        module_efficiencies : array
            Efficiency factor per module (repeated once per source below).
        source_pos, source_dir, source_time : array
            Source properties; ``source_pos.shape[0]`` is the number of sources.
        source_nphotons : array
            Number of emitted photons per source.
        seed : int or jax PRNG key
            Integer seed (Python or numpy integer), or an existing PRNG key.

        Returns
        -------
        awkward.Array
            One variable-length list of arrival times per module.
        """
        n_modules = module_coords.shape[0]
        n_sources = source_pos.shape[0]

        # TODO: Reimplement using padding / bucket compile (jax.mask???)
        if isinstance(seed, (int, np.integer)):
            key = random.PRNGKey(seed)
        else:
            key = seed

        inp_pars, time_geo = sources_to_model_input(
            module_coords,
            source_pos,
            source_dir,
            source_time,
            c_medium,
        )
        inp_pars = jnp.swapaxes(inp_pars, 0, 1)
        time_geo = jnp.swapaxes(time_geo, 0, 1)

        # Flatten so that rows are densely packed module-major:
        # row = module * n_sources + source.
        inp_pars = inp_pars.reshape((n_sources * n_modules, inp_pars.shape[-1]))
        time_geo = time_geo.reshape((n_sources * n_modules, time_geo.shape[-1]))
        source_photons = jnp.tile(source_nphotons, n_modules).T.ravel()
        mod_eff_factor = jnp.repeat(module_efficiencies, n_sources)

        # Normalizing flows were only built up to 300 (first model input is
        # compared in log10); pairs beyond that are dropped.
        # TODO: Check lower bound as well
        distance_mask = inp_pars[:, 0] < np.log10(300)
        inp_params_masked = inp_pars[distance_mask]
        time_geo_masked = time_geo[distance_mask]
        source_photons_masked = source_photons[distance_mask]
        mod_eff_factor_masked = mod_eff_factor[distance_mask]

        # The counts net yields log10 of the photon survival fraction.
        ph_frac = jnp.power(10, counts_net.apply(counts_params, inp_params_masked)).squeeze()

        # Poisson-sample the number of detected photons per (module, source) pair.
        expected_photons = ph_frac * source_photons_masked * mod_eff_factor_masked
        key, subkey = random.split(key)
        n_photons_masked = (
            random.poisson(subkey, expected_photons, shape=expected_photons.shape)
            .squeeze()
            .astype(jnp.int32)
        )

        if jnp.all(n_photons_masked == 0):
            # Nothing detected: still return one (empty) list per module so the
            # result has the same length as the non-empty case.
            return ak.Array([[] for _ in range(n_modules)])

        # Flow parameters and geometric times, repeated once per detected photon.
        traf_params = apply_fn(shape_params, inp_params_masked)
        traf_params_rep = jnp.repeat(traf_params, n_photons_masked, axis=0)
        time_geo_rep = jnp.repeat(time_geo_masked, n_photons_masked, axis=0).squeeze()

        # Scatter the sampled counts back into the full (module, source) grid
        # (masked-out pairs stay zero) and total them per module.
        n_photons = jnp.zeros(n_sources * n_modules, dtype=jnp.int32)
        n_photons = n_photons.at[distance_mask].set(n_photons_masked)
        n_ph_per_mod = np.sum(n_photons.reshape(n_modules, n_sources), axis=1)

        # Sample times from the flow and shift by the geometric time.
        key, subkey = random.split(key)
        samples = sample_model(traf_params_rep, subkey)
        times = np.atleast_1d(np.asarray(samples.squeeze() + time_geo_rep))

        # Photons are ordered module-major, so splitting at the cumulative
        # per-module counts yields one array per module (empty where none).
        return ak.Array(np.split(times, np.cumsum(n_ph_per_mod)[:-1]))

    return generate_norm_flow_photons