olympus.optimization.fisher_information

Functions:

  • calc_fisher_info_cascades

    Estimate the Fisher information matrix for cascades by averaging likelihood-gradient outer products over simulated events.

  • pad_array_log_bucket

    Pad a 1-D array with numpy.inf to the smallest power of base that is at least its length.

  • pad_event

    Pad every module of an event to a common length: the largest per-module hit count, rounded up to a multiple of 256.

calc_fisher_info_cascades

calc_fisher_info_cascades(det, event_data, key, converter, ph_prop, lh_func, c_medium, n_ev=20, pad_base=4)

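Estimate the Fisher information matrix for cascades from simulated events. For each of n_ev events, the gradients of lh_func with respect to the seven event parameters are summed over all modules, and the outer products of these per-event gradients are averaged.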

Parameters:
  • det (Detector) –

    Detector instance.

  • event_data (dict) –

    Event parameters used for cascade generation. Must contain at least the keys "pos", "theta", "phi", "time", "energy" and "particle_id".

  • key (PRNGKey) –

    JAX random key.

  • converter (callable) –

    Function converting event parameters to photon source descriptions.

  • ph_prop (callable) –

    Photon propagation function.

  • lh_func (callable) –

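    Per-module likelihood function, evaluated on each module's padded hit times. Its gradient with respect to the event parameters is what enters the Fisher matrix.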

  • c_medium (float) –

    Speed of light in the medium.

  • n_ev (int, default: 20) –

    Number of events to average over.

  • pad_base (int, default: 4) –

    Base used for log-bucket padding of hit arrays.

Returns:
  • fisher (ndarray) –

    Estimated Fisher information matrix of shape (7, 7), with parameters ordered as (x, y, z, theta, phi, t, log10 energy).
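The averaging step can be illustrated in isolation. The NumPy sketch below assumes the per-module gradients are already available as an array of shape (n_ev, n_modules, 7), a hypothetical layout standing in for the stacked output of eval_jacobian, and fills it with random numbers; fisher_from_module_scores is an illustrative name, not part of olympus. It sums over modules, forms one outer product per event, and averages, mirroring the structure of the loop in the source.

```python
import numpy as np


def fisher_from_module_scores(module_scores):
    """Average outer product of per-event gradients (illustrative sketch).

    module_scores: array of shape (n_ev, n_modules, n_params) holding the
    gradient of each module's likelihood term (hypothetical input layout).
    """
    # Sum the per-module gradients into one gradient per event
    # (the role of ``jacsum`` in the source).
    event_scores = module_scores.sum(axis=1)  # (n_ev, n_params)
    # One outer product per event, then average over events.
    outer = event_scores[:, :, None] * event_scores[:, None, :]
    return outer.mean(axis=0)  # (n_params, n_params)


rng = np.random.default_rng(0)
scores = rng.normal(size=(20, 5, 7))  # 20 events, 5 modules, 7 parameters
fisher = fisher_from_module_scores(scores)
print(fisher.shape)  # (7, 7)
```

Because each gradient enters its outer product twice, the result is symmetric positive semidefinite, and flipping the sign of the likelihood (log-likelihood vs. negative log-likelihood) leaves it unchanged. Interpreting the matrix as a Fisher information assumes lh_func returns a log-likelihood up to sign.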

Source code in olympus/optimization/fisher_information.py
def calc_fisher_info_cascades(
    det, event_data, key, converter, ph_prop, lh_func, c_medium, n_ev=20, pad_base=4
):
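    """Estimate the Fisher information matrix for cascades from simulated events.

    For each of ``n_ev`` simulated events, the gradients of ``lh_func`` with
    respect to the seven event parameters are summed over all modules, and the
    outer products of these per-event gradients are averaged.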

    Parameters
    ----------
    det : Detector
        Detector instance.
    event_data : dict
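        Event parameters used for cascade generation. Must contain at least the
        keys ``pos``, ``theta``, ``phi``, ``time``, ``energy`` and ``particle_id``.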
    key : jax.random.PRNGKey
        JAX random key.
    converter : callable
        Function converting event parameters to photon source descriptions.
    ph_prop : callable
        Photon propagation function.
    lh_func : callable
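        Per-module likelihood function, evaluated on each module's padded hit
        times. Its gradient with respect to the event parameters enters the
        Fisher matrix.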
    c_medium : float
        Speed of light in the medium.
    n_ev : int, optional
        Number of events to average over.
    pad_base : int, optional
        Base used for log-bucket padding of hit arrays.

    Returns
    -------
    fisher : np.ndarray
        Estimated Fisher information matrix of shape ``(7, 7)``, with parameters
        ordered as ``(x, y, z, theta, phi, t, log10(energy))``.
    """

    def eval_for_mod(x, y, z, theta, phi, t, log10e, times, mod_coords, noise_rate, key):

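        # Under jax.jit this runs only while tracing, i.e. once per new input
        # shape (such as each new padded length of ``times``), not on every call.
        print("Retracing")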

        pos = jnp.asarray([x, y, z])
        direction = sph_to_cart_jnp(theta, phi)

        sources = converter(
            pos, t, direction, 10**log10e, particle_id=event_data["particle_id"], key=key
        )

        return lh_func(
            times,
            jnp.sum(jnp.isfinite(times)),
            mod_coords,
            sources[0],
            sources[1],
            sources[2],
            sources[3],
            c_medium,
            noise_rate,
        )

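    # Differentiate with respect to the seven event parameters
    # (x, y, z, theta, phi, t, log10e); the result is a tuple with one entry per
    # parameter, stacked into a vector below.
    eval_jacobian = jax.jit(jax.jacobian(eval_for_mod, argnums=(0, 1, 2, 3, 4, 5, 6)))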

    matrices = []
    for _ in range(n_ev):
        key, k1, k2 = random.split(key, 3)
        event, _ = generate_cascade(
            det,
            event_data,
            pprop_func=ph_prop,
            seed=k1,
            converter_func=converter,
        )

        event, _ = simulate_noise(det, event)

        # Sum the per-module gradients into one 7-component gradient per event.
        jacsum = jnp.zeros(7)
        for j in range(len(event)):
            padded = pad_array_log_bucket(event[j], pad_base)
            res = jnp.stack(
                eval_jacobian(
                    event_data["pos"][0],
                    event_data["pos"][1],
                    event_data["pos"][2],
                    event_data["theta"],
                    event_data["phi"],
                    event_data["time"],
                    np.log10(event_data["energy"]),
                    padded,
                    det.module_coords[j],
                    det.module_noise_rates[j],
                    k2,
                )
            )
            jacsum += res
        matrices.append(np.asarray(jacsum[:, np.newaxis] * jacsum[np.newaxis, :]))

    fisher = np.average(np.stack(matrices), axis=0)
    return fisher
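To see why averaging outer products of likelihood gradients estimates the Fisher information, consider a toy problem with a known answer (an assumption for illustration, unrelated to the detector model): n independent Gaussian observations with unknown mean mu and known sigma. The gradient of the log-likelihood with respect to mu is sum(x - mu) / sigma**2, and the exact Fisher information is n / sigma**2. The sketch uses NumPy with an analytic gradient instead of JAX autodiff; score_gaussian_mean and mc_fisher are hypothetical helper names.

```python
import numpy as np


def score_gaussian_mean(x, mu, sigma):
    # d/dmu of sum_i log N(x_i | mu, sigma^2)
    return np.sum(x - mu) / sigma**2


def mc_fisher(mu, sigma, n_obs, n_ev, rng):
    # Simulate n_ev independent datasets at the true parameter value and
    # evaluate the score of each one there.
    scores = np.array([
        score_gaussian_mean(rng.normal(mu, sigma, size=n_obs), mu, sigma)
        for _ in range(n_ev)
    ])
    # With a single parameter the outer product reduces to a square.
    return np.mean(scores**2)


rng = np.random.default_rng(42)
estimate = mc_fisher(mu=1.0, sigma=2.0, n_obs=10, n_ev=20000, rng=rng)
exact = 10 / 2.0**2  # n / sigma^2
print(estimate, exact)
```

In this toy problem the relative standard error of the estimate is sqrt(2 / n_ev), about 32% for n_ev=20, which gives a rough sense of how much a small number of simulated events can scatter.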

pad_array_log_bucket

pad_array_log_bucket(array, base)

Pad a 1-D array with numpy.inf to the smallest power of base that is at least its length.

Parameters:
  • array (Array) –

    Array to pad.

  • base (int) –

    Logarithmic bucket base used to determine the target length.

Returns:
  • ev_np (ndarray) –

    Padded array with numpy.inf fill values, or an empty float array if array is empty.
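The target length is the smallest power of base that is at least the number of hits. Under jax.jit, each new padded length triggers a fresh trace and compilation of the jitted Jacobian in calc_fisher_info_cascades, so rounding hit counts up to a few bucket sizes bounds how often that happens. The helper below, log_bucket_length, is a hypothetical integer-only variant of the length computation; it avoids the floating-point ratio of logarithms used in the source.

```python
def log_bucket_length(n_hits, base):
    """Smallest power of ``base`` that is >= ``n_hits`` (hypothetical helper)."""
    if n_hits <= 0:
        return 0
    length = 1
    # Multiply up by ``base`` until the bucket can hold every hit.
    while length < n_hits:
        length *= base
    return length


# Hit counts 1..1000 collapse onto a handful of padded lengths for base 4.
buckets = sorted({log_bucket_length(n, 4) for n in range(1, 1001)})
print(buckets)  # [1, 4, 16, 64, 256, 1024]
```

A larger base means fewer distinct shapes, and hence fewer compilations, at the cost of more inf padding: a module whose hit count is one more than a power of base is padded to nearly base times its length.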

Source code in olympus/optimization/fisher_information.py
def pad_array_log_bucket(array, base):
    """Pad a 1-D array with ``numpy.inf`` to the smallest power of ``base`` that fits it.

    Parameters
    ----------
    array : ak.Array
        Array to pad.
    base : int
        Logarithmic bucket base used to determine the target length.

    Returns
    -------
    ev_np : np.ndarray
        Padded array with ``numpy.inf`` fill values, or an empty float array if
        ``array`` is empty.
    """
    n_hits = ak.count(array)
    if n_hits == 0:
        return np.array([], dtype=np.float64)

    log_cnt = np.log(n_hits) / np.log(base)
    pad_len = int(np.power(base, np.ceil(log_cnt)))
    if n_hits > pad_len:
        raise RuntimeError(f"{n_hits} hits exceed the padded length {pad_len}")

    padded = ak.pad_none(array, target=pad_len, clip=True, axis=0)
    ev_np = np.asarray(ak.fill_none(padded, np.inf))
    return ev_np
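For readers without Awkward Array at hand, the same padding can be reproduced on a plain NumPy array. This is a hypothetical equivalent for illustration (pad_log_bucket_np is not part of olympus); the library function itself takes an ak.Array and relies on ak.pad_none and ak.fill_none.

```python
import numpy as np


def pad_log_bucket_np(times, base):
    """Pad a 1-D array with +inf to the next power of ``base`` (sketch)."""
    times = np.asarray(times, dtype=np.float64)
    n = times.size
    if n == 0:
        return np.array([], dtype=np.float64)
    # Smallest power of ``base`` that holds all n entries.
    pad_len = 1
    while pad_len < n:
        pad_len *= base
    # Constant +inf padding plays the role of ak.fill_none(..., np.inf).
    return np.pad(times, (0, pad_len - n), constant_values=np.inf)


padded = pad_log_bucket_np([3.0, 1.5, 7.2, 0.4, 2.2], base=4)
print(padded.shape)  # (16,)
```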

pad_event

pad_event(event)

Pad every module of an event to a common length: the largest per-module hit count, rounded up to a multiple of 256.

Parameters:
  • event (Array) –

    Per-module hit-time arrays to pad.

Returns:
  • ev_np (ndarray) –

    Padded array with numpy.inf fill values.
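Because every module is padded to the same length, the result is a rectangular 2-D array. A NumPy-only sketch, assuming the event is given as a list of 1-D hit-time arrays (a hypothetical stand-in for the ak.Array input; pad_event_np is an illustrative name):

```python
import numpy as np


def pad_event_np(event, block=256):
    """Stack ragged per-module hit times into a +inf-padded 2-D array."""
    max_hits = max((len(m) for m in event), default=0)
    # Round the longest module up to the next multiple of ``block`` (0 stays 0).
    pad_len = -(-max_hits // block) * block
    out = np.full((len(event), pad_len), np.inf)
    for i, hits in enumerate(event):
        out[i, : len(hits)] = hits
    return out


event = [np.arange(3.0), np.arange(300.0), np.array([])]
padded = pad_event_np(event)
print(padded.shape)  # (3, 512)
```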

Source code in olympus/optimization/fisher_information.py
def pad_event(event):
    """Pad every module of an event to a common length.

    The length is the largest per-module hit count, rounded up to the next
    multiple of 256.

    Parameters
    ----------
    event : ak.Array
        Per-module hit-time arrays to pad.

    Returns
    -------
    ev_np : np.ndarray
        Padded array with ``numpy.inf`` fill values.
    """
    max_hits = ak.max(ak.count(event, axis=1))
    # Longest module's hit count, rounded up to the next multiple of 256.
    pad_len = np.int32(np.ceil(max_hits / 256) * 256)

    if max_hits > pad_len:
        raise RuntimeError(f"{max_hits} hits exceed the padded length {pad_len}")
    padded = ak.pad_none(event, target=pad_len, clip=True, axis=1)
    ev_np = np.asarray(ak.fill_none(padded, np.inf))
    return ev_np
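Because padding uses inf rather than a finite sentinel, the number of real hits can be recovered from a padded array by counting finite entries, which is what eval_for_mod does when it passes jnp.sum(jnp.isfinite(times)) to lh_func. A small NumPy illustration with a made-up array, assuming real hit times are always finite:

```python
import numpy as np

padded = np.array([
    [0.5, 1.2, np.inf, np.inf],       # module with two hits
    [np.inf, np.inf, np.inf, np.inf],  # module with no hits
])
# Padding cells are +inf and real hit times are finite, so counting
# finite entries per row recovers the per-module hit count.
n_hits = np.isfinite(padded).sum(axis=1)
print(n_hits)  # [2 0]
```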