hyperion.models.photon_arrival_time_nflow.net

Functions:

eval_log_prob

eval_log_prob(dist_builder, traf_params, samples)

Compute log p(samples | traf_params) under the flow.

Parameters:
  • dist_builder (_TrafDistBuilder or callable) –

    Object returned by traf_dist_builder.

  • traf_params (ndarray) –

    Array with shape (batch, total_flow_params).

  • samples (ndarray) –

    Array with shape (batch,) of sample values.

Returns:
  • ndarray

    Log-probabilities with shape (batch,).

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def eval_log_prob(dist_builder, traf_params, samples):
    """Compute log p(samples | traf_params) under the flow.

    Parameters
    ----------
    dist_builder : _TrafDistBuilder or callable
        Object returned by :func:`traf_dist_builder`.
    traf_params : jnp.ndarray
        Array with shape (batch, total_flow_params).
    samples : jnp.ndarray
        Array with shape (batch,) of sample values.

    Returns
    -------
    jnp.ndarray
        Log-probabilities with shape (batch,).
    """
    if isinstance(dist_builder, _TrafDistBuilder):
        return dist_builder.log_prob(traf_params, samples)
    # Fallback: call builder and use its .log_prob()
    return dist_builder(traf_params).log_prob(samples)
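A minimal usage sketch (hypothetical sizes; params, cond and samples are assumed to come from elsewhere, e.g. a training setup as in train_shape_model):

# Hypothetical sketch: all sizes below are illustrative.
builder = traf_dist_builder(flow_num_layers=2, flow_range=(0.0, 500.0))
conditioner = make_shape_conditioner_fn(128, 3, 8, 2)
traf_params = conditioner.apply(params, cond)         # (batch, total_flow_params)
log_p = eval_log_prob(builder, traf_params, samples)  # (batch,)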

make_conditioner

make_conditioner(hidden_sizes, out_params_activ, init_zero=True)

Conditioner MLP factory (kept for training API compatibility).

Parameters:
  • hidden_sizes (sequence) –

    Hidden layer sizes for the MLP.

  • out_params_activ (callable or None) –

    Activation applied to output parameters (kept for API compatibility).

  • init_zero (bool, default: True) –

    If True, initialise output parameters to zero.

Returns:
  • _ConditionerFn

    Callable conditioner object with an .apply(params, x) method.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_conditioner(hidden_sizes, out_params_activ, init_zero=True):
    """Conditioner MLP factory (kept for training API compatibility).

    Parameters
    ----------
    hidden_sizes : sequence
        Hidden layer sizes for the MLP.
    out_params_activ : callable or None
        Activation applied to output parameters (kept for API compatibility).
    init_zero : bool, optional
        If True, initialise output parameters to zero.

    Returns
    -------
    _ConditionerFn
        Callable conditioner object with an ``.apply(params, x)`` method.
    """
    return _ConditionerFn(n_hidden_layers=len(hidden_sizes))
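A hedged usage sketch (params and x are assumed to exist; per the docstring, the activation and init_zero arguments are kept only for API compatibility):

# Hypothetical sketch: only len(hidden_sizes) is consumed by this factory.
cond_fn = make_conditioner(hidden_sizes=[128, 128, 128], out_params_activ=None)
out = cond_fn.apply(params, x)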

make_counts_net_fn

make_counts_net_fn(config)

Build the counts-model MLP.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_counts_net_fn(config):
    """Build the counts-model MLP."""
    return _ConditionerFn(n_hidden_layers=config["mlp_num_layers"])
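A minimal sketch; as the source shows, only config["mlp_num_layers"] is read here:

# Hypothetical sketch: params and inputs are assumed to exist.
net_fn = make_counts_net_fn({"mlp_num_layers": 3})
pred = net_fn.apply(params, inputs)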

make_shape_conditioner_fn

make_shape_conditioner_fn(mlp_hidden_size, mlp_num_layers, flow_num_bins, flow_num_layers)

Build the shape-model conditioner (MLP).

Parameters:
  • mlp_hidden_size (int) –

    Hidden layer size for the MLP.

  • mlp_num_layers (int) –

    Number of hidden layers.

  • flow_num_bins (int) –

    Number of bins used in the flow (kept for API compatibility).

  • flow_num_layers (int) –

    Number of flow layers (kept for API compatibility).

Returns:
  • _ConditionerFn

    Conditioner callable.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_shape_conditioner_fn(mlp_hidden_size, mlp_num_layers, flow_num_bins, flow_num_layers):
    """Build the shape-model conditioner (MLP).

    Parameters
    ----------
    mlp_hidden_size : int
        Hidden layer size for the MLP.
    mlp_num_layers : int
        Number of hidden layers.
    flow_num_bins : int
        Number of bins used in the flow (kept for API compatibility).
    flow_num_layers : int
        Number of flow layers (kept for API compatibility).

    Returns
    -------
    _ConditionerFn
        Conditioner callable.
    """
    return _ConditionerFn(n_hidden_layers=mlp_num_layers)
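The conditioner's flat output width corresponds to the flow parameter count used in train_shape_model, (3 * flow_num_bins + 1) * flow_num_layers. A hedged sketch:

# Hypothetical sketch: with 8 bins and 2 flow layers, the flat output has
# (3*8 + 1) * 2 = 50 parameters per conditioning input.
conditioner = make_shape_conditioner_fn(
    mlp_hidden_size=128, mlp_num_layers=3, flow_num_bins=8, flow_num_layers=2
)
traf_params = conditioner.apply(params, cond)  # (batch, 50); params/cond assumed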

make_spl_flow

make_spl_flow(spl_params_list, rmin, rmax)

Convert a list of raw spline parameter arrays into knot-tuple lists.

Parameters:
  • spl_params_list (sequence of jnp.ndarray) –

    List of spline parameter arrays; each element has shape (batch, 3*num_bins + 1).

  • rmin (float) –

    Minimum range value for the spline.

  • rmax (float) –

    Maximum range value for the spline.

Returns:
  • list

    List of tuples (x_pos, y_pos, knot_slopes) each with shape (batch, num_bins + 1).

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_spl_flow(spl_params_list, rmin, rmax):
    """Convert a list of raw spline parameter arrays into knot-tuple lists.

    Parameters
    ----------
    spl_params_list : sequence of jnp.ndarray
        List of spline parameter arrays; each element has shape (batch, 3*num_bins + 1).
    rmin : float
        Minimum range value for the spline.
    rmax : float
        Maximum range value for the spline.

    Returns
    -------
    list
        List of tuples ``(x_pos, y_pos, knot_slopes)`` each with shape
        (batch, num_bins + 1).
    """
    return [_build_rqs_knots_batched(sp, float(rmin), float(rmax)) for sp in spl_params_list]
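A hedged shape check (jnp.zeros stands in for real spline parameters):

# Hypothetical sketch: only the documented shapes are exercised here.
import jax.numpy as jnp

num_bins, batch = 8, 4
raw = [jnp.zeros((batch, 3 * num_bins + 1)) for _ in range(2)]  # two flow layers
knots = make_spl_flow(raw, rmin=0.0, rmax=500.0)
x_pos, y_pos, knot_slopes = knots[0]  # each with shape (batch, num_bins + 1)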

sample_shape_model

sample_shape_model(dist_builder, traf_params, n_photons, seed)

Sample from the shape model.

Parameters:
  • dist_builder (callable) –

    Builder returned by traf_dist_builder with return_base=True.

  • traf_params (ndarray) –

    Array with shape (batch, total_flow_params).

  • n_photons (int or tuple) –

    Number of base samples to draw, or a tuple giving the sample shape.

  • seed (PRNGKey) –

    JAX PRNG key.

Returns:
  • ndarray

    Samples drawn from the transformed shape model.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def sample_shape_model(dist_builder, traf_params, n_photons, seed):
    """Sample from the shape model.

    Parameters
    ----------
    dist_builder : callable
        Builder returned by :func:`traf_dist_builder` with ``return_base=True``.
    traf_params : jnp.ndarray
        Array with shape (batch, total_flow_params).
    n_photons : int or tuple
        Number of base samples to draw, or a tuple giving the sample shape.
    seed : jax.random.PRNGKey
        JAX PRNG key.

    Returns
    -------
    jnp.ndarray
        Samples drawn from the transformed shape model.
    """
    base_dist, trafo = dist_builder(traf_params)
    base_samples = base_dist.sample(seed=seed, sample_shape=n_photons)
    return trafo.forward(base_samples)
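A minimal sketch, assuming traf_params comes from a trained conditioner:

import jax

# Hypothetical sizes; note return_base=True is required here.
builder = traf_dist_builder(flow_num_layers=2, flow_range=(0.0, 500.0), return_base=True)
key = jax.random.PRNGKey(0)
photon_times = sample_shape_model(builder, traf_params, n_photons=1000, seed=key)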

traf_dist_builder

traf_dist_builder(flow_num_layers, flow_range, return_base=False)

Return a callable that builds the transformed distribution.

Parameters:
  • flow_num_layers (int) –

    Number of spline layers in the flow.

  • flow_range (tuple) –

    (rmin, rmax) range for the spline.

  • return_base (bool, default: False) –

    If True, calling the returned object returns (base_dist, flow). If False, it returns a dist-like object with .log_prob().

Returns:
  • _TrafDistBuilder

    Builder callable for the transformed distribution.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def traf_dist_builder(flow_num_layers, flow_range, return_base=False):
    """Return a callable that builds the transformed distribution.

    Parameters
    ----------
    flow_num_layers : int
        Number of spline layers in the flow.
    flow_range : tuple
        ``(rmin, rmax)`` range for the spline.
    return_base : bool, optional
        If True, calling the returned object returns ``(base_dist, flow)``.
        If False, returns a dist-like object with ``.log_prob()``.

    Returns
    -------
    _TrafDistBuilder
        Builder callable for the transformed distribution.
    """
    return _TrafDistBuilder(
        flow_num_layers=flow_num_layers,
        rmin=flow_range[0],
        rmax=flow_range[1],
        return_base=return_base,
    )
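A hedged sketch of the two calling modes (traf_params and samples assumed):

# return_base=False (default): dist-like object exposing .log_prob()
dist = traf_dist_builder(2, (0.0, 500.0))(traf_params)
log_p = dist.log_prob(samples)

# return_base=True: (base_dist, flow), as consumed by sample_shape_model
base_dist, flow = traf_dist_builder(2, (0.0, 500.0), return_base=True)(traf_params)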

train_counts_model

train_counts_model(config, train_loader, test_loader, seed=1337, writer=None)

Train the counts model using MLP regression.

Parameters:
  • config (dict) –

    Training and model configuration.

  • train_loader (iterable) –

    Training data loader.

  • test_loader (iterable) –

    Test data loader.

  • seed (int, default: 1337) –

    Random seed.

  • writer (SummaryWriter or None, default: None) –

    Optional writer for logging metrics.

Returns:
  • dict

    Trained parameter dictionary.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def train_counts_model(config, train_loader, test_loader, seed=1337, writer=None):
    """Train the counts model using MLP regression.

    Parameters
    ----------
    config : dict
        Training and model configuration.
    train_loader : iterable
        Training data loader.
    test_loader : iterable
        Test data loader.
    seed : int, optional
        Random seed.
    writer : SummaryWriter or None, optional
        Optional writer for logging metrics.

    Returns
    -------
    dict
        Trained parameter dictionary.
    """

    net_fn = make_counts_net_fn(config)

    @jax.jit
    def loss_fn(params, batch):
        """Mean-squared error loss for a batch.

        Parameters
        ----------
        params : dict
            Model parameters.
        batch : tuple
            Batch data (inputs, targets, ...).

        Returns
        -------
        jnp.ndarray
            Scalar loss value.
        """
        inp = jnp.concatenate(batch[:2]).T
        out = net_fn.apply(params, inp).squeeze()
        return 0.5 * jnp.average((out - batch[2]) ** 2)

    @jax.jit
    def update(params, opt_state, batch):
        """Single optimizer update for counts model.

        Returns updated parameters, optimizer state and loss.
        """
        lval, grads = jax.value_and_grad(loss_fn)(params, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), new_opt_state, lval

    scheduler = optax.cosine_decay_schedule(config["lr"], config["steps"], alpha=0.0)
    optimizer = optax.adam(learning_rate=scheduler)

    params = _init_mlp_params(2, config["mlp_hidden_size"], config["mlp_num_layers"], 1, seed)
    opt_state = optimizer.init(params)

    train_iter = iter(train_loader)
    for i in range(1, config["steps"] + 1):
        train = next(train_iter)
        params, opt_state, train_loss = update(params, opt_state, train)

        if i % 100 == 0:
            test_loss = sum(loss_fn(params, t) for t in test_loader) / test_loader._n_batches
            train_loss, test_loss = jax.device_get((train_loss, test_loss))
            if writer is not None:
                writer.add_scalar("Loss/train", train_loss, i)
                writer.add_scalar("Loss/test", test_loss, i)
                writer.flush()
            print(f"Epoch: {i} \t Train/Test: {train_loss:.3E} / {test_loss:.3E}")

    if writer is not None:
        test_loss = sum(loss_fn(params, t) for t in test_loader) / test_loader._n_batches
        test_loss = jax.device_get(test_loss)
        writer.add_hparams(dict(config), {"hparam/test_loss": test_loss})
        writer.flush()
        writer.close()

    return params
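A hedged sketch of the config keys this function reads (values illustrative; the loaders are assumed to yield (inputs_a, inputs_b, targets) batch tuples, matching the loss above):

# Hypothetical configuration; all values are illustrative.
config = {
    "mlp_hidden_size": 128,  # hidden width of the counts MLP
    "mlp_num_layers": 3,     # number of hidden layers
    "lr": 1e-3,              # peak learning rate for the cosine schedule
    "steps": 10000,          # total optimisation steps
}
params = train_counts_model(config, train_loader, test_loader, seed=1337)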

train_shape_model

train_shape_model(config, train_loader, test_loader, seed=1337, writer=None)

Train the shape model using the provided data loaders.

Parameters:
  • config (dict) –

    Training and model configuration dictionary.

  • train_loader (iterable) –

    Training data loader.

  • test_loader (iterable) –

    Test data loader.

  • seed (int, default: 1337) –

    Random seed for parameter initialisation (default is 1337).

  • writer (SummaryWriter or None, default: None) –

    Optional writer for logging metrics.

Returns:
  • dict

    Trained parameter dictionary.

Source code in hyperion/models/photon_arrival_time_nflow/net.py
def train_shape_model(config, train_loader, test_loader, seed=1337, writer=None):
    """Train the shape model using the provided data loaders.

    Parameters
    ----------
    config : dict
        Training and model configuration dictionary.
    train_loader : iterable
        Training data loader.
    test_loader : iterable
        Test data loader.
    seed : int, optional
        Random seed for parameter initialisation (default is 1337).
    writer : SummaryWriter or None, optional
        Optional writer for logging metrics.

    Returns
    -------
    dict
        Trained parameter dictionary.
    """

    dist_builder = traf_dist_builder(
        config["flow_num_layers"],
        (config["flow_rmin"], config["flow_rmax"]),
    )
    shape_conditioner = make_shape_conditioner_fn(
        config["mlp_hidden_size"],
        config["mlp_num_layers"],
        config["flow_num_bins"],
        config["flow_num_layers"],
    )

    @jax.jit
    def ema_update(params, avg_params):
        """Exponential moving average update for parameters.

        Parameters
        ----------
        params : dict
            Current parameters.
        avg_params : dict
            Current EMA parameters.

        Returns
        -------
        dict
            Updated EMA parameters.
        """
        return optax.incremental_update(params, avg_params, step_size=0.001)

    @jax.jit
    def loss_fn(params, cond, samples):
        """Compute negative log-likelihood loss for a batch.

        Parameters
        ----------
        params : dict
            Model parameters for the conditioner MLP.
        cond : array-like
            Conditioning inputs.
        samples : array-like
            Observed sample values.

        Returns
        -------
        jnp.ndarray
            Scalar loss value.
        """
        traf_params = shape_conditioner.apply(params, cond)
        lprobs = eval_log_prob(dist_builder, traf_params, samples)
        # Zero out non-finite log-probs before averaging; multiplying by the
        # boolean mask alone would propagate NaN (NaN * 0.0 is NaN).
        lprobs = jnp.where(jnp.isfinite(lprobs), lprobs, 0.0)
        return -jnp.mean(lprobs)

    @jax.jit
    def update(params, opt_state, cond, samples):
        """Perform a single optimization step.

        Parameters
        ----------
        params : dict
            Current model parameters.
        opt_state : optax.OptState
            Current optimizer state.
        cond : array-like
            Conditioning inputs for the batch.
        samples : array-like
            Observed samples for the batch.

        Returns
        -------
        tuple
            ``(new_params, new_opt_state, loss_value)``.
        """
        lval, grads = jax.value_and_grad(loss_fn)(params, cond, samples)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), new_opt_state, lval

    scheduler = optax.cosine_decay_schedule(config["lr"], config["steps"], alpha=0.0)
    optimizer = optax.adam(learning_rate=scheduler)

    n_out = (3 * config["flow_num_bins"] + 1) * config["flow_num_layers"]
    params = avg_params = _init_mlp_params(
        2, config["mlp_hidden_size"], config["mlp_num_layers"], n_out, seed
    )
    opt_state = optimizer.init(params)

    train_iter = iter(train_loader)
    for i in range(1, config["steps"] + 1):
        train = next(train_iter)
        cond = jnp.concatenate(train[:2]).T
        samples = jnp.squeeze(train[2])
        params, opt_state, train_loss = update(params, opt_state, cond, samples)
        avg_params = ema_update(params, avg_params)

        if i % 100 == 0:
            test_loss = (
                sum(
                    loss_fn(avg_params, jnp.concatenate(t[:2]).T, jnp.squeeze(t[2]))
                    for t in test_loader
                )
                / test_loader._n_batches
            )
            train_loss, test_loss = jax.device_get((train_loss, test_loss))
            if writer is not None:
                writer.add_scalar("Loss/train", train_loss, i)
                writer.add_scalar("Loss/test", test_loss, i)
                writer.flush()
            print(f"Epoch: {i} \t Train/Test: {train_loss:.3E} / {test_loss:.3E}")

    # Note: the raw (non-EMA) parameters are returned; avg_params is used
    # only for the test-loss evaluation above.
    return params
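A hedged sketch of the config keys this function reads (values illustrative; loaders as in train_counts_model):

# Hypothetical configuration; all values are illustrative.
config = {
    "flow_num_layers": 2,    # spline layers in the flow
    "flow_num_bins": 8,      # bins per spline layer
    "flow_rmin": 0.0,        # spline range minimum
    "flow_rmax": 500.0,      # spline range maximum
    "mlp_hidden_size": 128,  # conditioner hidden width
    "mlp_num_layers": 3,     # conditioner hidden layers
    "lr": 1e-3,              # peak learning rate for the cosine schedule
    "steps": 10000,          # total optimisation steps
}
params = train_shape_model(config, train_loader, test_loader, seed=1337)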