olympus.event_generation.photon_propagation.norm_flow_photons

Functions:

make_generate_norm_flow_photons

make_generate_norm_flow_photons(shape_model_path, counts_model_path, c_medium)

Build a photon-generation function backed by normalizing flow models.

Loads the shape and counts normalizing flow models from disk and returns a compiled JAX function that samples detected photon arrival times for a given set of sources and modules.

Parameters:
  • shape_model_path (str) –

    Path to the pickled shape normalizing flow model.

  • counts_model_path (str) –

    Path to the pickled photon counts model.

  • c_medium (float) –

    Speed of light in the medium in meters per nanosecond.

Returns:
  • generate_norm_flow_photons (callable) –

    Function with signature (module_coords, module_efficiencies, source_pos, source_dir, source_time, source_nphotons, seed) -> awkward.Array that returns per-module detected photon arrival times.

Source code in olympus/event_generation/photon_propagation/norm_flow_photons.py
def make_generate_norm_flow_photons(shape_model_path, counts_model_path, c_medium):
    """Build a photon-generation function backed by normalizing flow models.

    Loads the shape and counts normalizing flow models from disk and returns a
    compiled JAX function that samples detected photon arrival times for a given
    set of sources and modules.

    Parameters
    ----------
    shape_model_path : str
        Path to the pickled shape normalizing flow model.
    counts_model_path : str
        Path to the pickled photon counts model.
    c_medium : float
        Speed of light in the medium in meters per nanosecond.

    Returns
    -------
    generate_norm_flow_photons : callable
        Function with signature
        ``(module_coords, module_efficiencies, source_pos, source_dir,
        source_time, source_nphotons, seed) -> awkward.Array``
        that returns per-module detected photon arrival times.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file — only
    # load trusted model files. Context managers avoid leaking file handles.
    with open(shape_model_path, "rb") as fobj:
        shape_config, shape_params = pickle.load(fobj)
    with open(counts_model_path, "rb") as fobj:
        counts_config, counts_params = pickle.load(fobj)

    shape_conditioner = make_shape_conditioner_fn(
        shape_config["mlp_hidden_size"],
        shape_config["mlp_num_layers"],
        shape_config["flow_num_bins"],
        shape_config["flow_num_layers"],
    )

    @jax.jit
    def apply_fn(params, x):
        # Jitted evaluation of the shape-conditioner network.
        return shape_conditioner.apply(params, x)

    dist_builder = traf_dist_builder(
        shape_config["flow_num_layers"],
        (shape_config["flow_rmin"], shape_config["flow_rmax"]),
        return_base=True,
    )

    counts_net = make_counts_net_fn(counts_config)

    @jax.jit
    def sample_model_inner(traf_params, key):
        # Jitted: draws one flow sample per row of traf_params.
        return sample_shape_model(dist_builder, traf_params, traf_params.shape[0], key)

    def sample_model(traf_params, key):
        """Sample the flow, padding the batch to the next power of four.

        Padding restricts the set of array shapes that reach the jitted
        sampler, bounding the number of JIT recompilations.
        """
        base = 4
        log_cnt = np.log(traf_params.shape[0]) / np.log(base)
        pad_len = int(np.power(base, np.ceil(log_cnt)))

        padded = jnp.pad(traf_params, ((0, pad_len - traf_params.shape[0]), (0, 0)))

        # Discard the samples produced for the padding rows.
        return sample_model_inner(padded, key)[: traf_params.shape[0]]

    def generate_norm_flow_photons(
        module_coords,
        module_efficiencies,
        source_pos,
        source_dir,
        source_time,
        source_nphotons,
        seed=31337,
    ):
        """Sample detected photon arrival times for all module/source pairs.

        Returns an ``awkward.Array`` with one list of arrival times per
        module. ``seed`` may be an int (converted to a PRNG key) or an
        existing JAX PRNG key.
        """
        # TODO: Reimplement using padding / bucket compile (jax.mask???)

        key = random.PRNGKey(seed) if isinstance(seed, int) else seed

        inp_pars, time_geo = sources_to_model_input(
            module_coords,
            source_pos,
            source_dir,
            source_time,
            c_medium,
        )

        inp_pars = jnp.swapaxes(inp_pars, 0, 1)
        time_geo = jnp.swapaxes(time_geo, 0, 1)

        n_pairs = source_pos.shape[0] * module_coords.shape[0]

        # Flatten: densely pack [modules, sources] into one axis.
        inp_pars = inp_pars.reshape((n_pairs, inp_pars.shape[-1]))
        time_geo = time_geo.reshape((n_pairs, time_geo.shape[-1]))
        source_photons = jnp.tile(source_nphotons, module_coords.shape[0]).T.ravel()
        mod_eff_factor = jnp.repeat(module_efficiencies, source_pos.shape[0])

        # Normalizing flows only built up to 300 (first column is log10-distance).
        # TODO: Check lower bound as well
        distance_mask = inp_pars[:, 0] < np.log10(300)

        inp_params_masked = inp_pars[distance_mask]
        time_geo_masked = time_geo[distance_mask]
        source_photons_masked = source_photons[distance_mask]
        mod_eff_factor_masked = mod_eff_factor[distance_mask]

        # Counts net predicts log10 of the photon survival fraction.
        ph_frac = jnp.power(
            10, counts_net.apply(counts_params, inp_params_masked)
        ).squeeze()

        # Expected number of detected photons per (module, source) pair.
        n_photons_masked = ph_frac * source_photons_masked * mod_eff_factor_masked

        key, subkey = random.split(key)
        n_photons_masked = (
            random.poisson(subkey, n_photons_masked, shape=n_photons_masked.shape)
            .squeeze()
            .astype(jnp.int32)
        )

        if jnp.all(n_photons_masked == 0):
            # No photons detected anywhere: return one empty hit list per
            # module, consistent with the normal return path.
            # (Bugfix: the previous `[] * n` always evaluated to `[]`.)
            return ak.Array([[] for _ in range(module_coords.shape[0])])

        # Obtain flow parameters and repeat them for each detected photon.
        traf_params = apply_fn(shape_params, inp_params_masked)
        traf_params_rep = jnp.repeat(traf_params, n_photons_masked, axis=0)
        # Also repeat the geometric time for each detected photon.
        time_geo_rep = jnp.repeat(time_geo_masked, n_photons_masked, axis=0).squeeze()

        # Photons per module: scatter the Poisson samples back into the dense
        # (module, source) grid via the distance mask, then sum over sources.
        n_photons = jnp.zeros(n_pairs, dtype=jnp.int32)
        n_photons = n_photons.at[distance_mask].set(n_photons_masked)
        n_photons = n_photons.reshape(module_coords.shape[0], source_pos.shape[0])
        n_ph_per_mod = np.sum(n_photons, axis=1)

        # Sample arrival-time residuals from the flow and add the geometric
        # propagation time.
        key, subkey = random.split(key)
        samples = sample_model(traf_params_rep, subkey)
        times = np.atleast_1d(np.asarray(samples.squeeze() + time_geo_rep))

        if len(times) == 1:
            # np.split cannot place a single hit; assign it to the one
            # module with a nonzero count and give every other module an
            # empty list.
            ix = np.argwhere(n_ph_per_mod).squeeze()
            times = [[] if i != ix else times for i in range(module_coords.shape[0])]
        else:
            # Split per module and convert to awkward.Array.
            times = np.split(times, np.cumsum(n_ph_per_mod)[:-1])

        return ak.Array(times)

    return generate_norm_flow_photons