hyperion.models.photon_arrival_time_nflow.net
Functions:

- eval_log_prob – Compute log p(samples | traf_params) under the flow.
- make_conditioner – Conditioner MLP factory (kept for training API compatibility).
- make_counts_net_fn – Build the counts-model MLP.
- make_shape_conditioner_fn – Build the shape-model conditioner (MLP).
- make_spl_flow – Convert a list of raw spline parameter arrays into knot-tuple lists.
- sample_shape_model – Sample from the shape model.
- traf_dist_builder – Return a callable that builds the transformed distribution.
- train_counts_model – Train the counts model using MLP regression.
- train_shape_model – Train the shape model using the provided data loaders.
eval_log_prob
eval_log_prob(dist_builder, traf_params, samples)
Compute log p(samples | traf_params) under the flow.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def eval_log_prob(dist_builder, traf_params, samples):
    """Compute log p(samples | traf_params) under the flow.

    Parameters
    ----------
    dist_builder : _TrafDistBuilder or callable
        Object returned by :func:`traf_dist_builder`.
    traf_params : jnp.ndarray
        Array with shape (batch, total_flow_params).
    samples : jnp.ndarray
        Array with shape (batch,) of sample values.

    Returns
    -------
    jnp.ndarray
        Log-probabilities with shape (batch,).
    """
    if isinstance(dist_builder, _TrafDistBuilder):
        return dist_builder.log_prob(traf_params, samples)
    # Fallback: call builder and use its .log_prob()
    return dist_builder(traf_params).log_prob(samples)
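A minimal usage sketch; the per-layer parameter count (3 * num_bins + 1, see make_spl_flow) is taken from this module, but the concrete shapes and values below are illustrative assumptions:

import jax.numpy as jnp
from hyperion.models.photon_arrival_time_nflow.net import eval_log_prob, traf_dist_builder

# Two spline layers with 10 bins each: 2 * (3 * 10 + 1) = 62 flow parameters per event.
builder = traf_dist_builder(flow_num_layers=2, flow_range=(0.0, 100.0))
traf_params = jnp.zeros((4, 62))             # stand-in for conditioner output
samples = jnp.array([1.0, 5.0, 20.0, 50.0])  # one sample value per batch entry
log_p = eval_log_prob(builder, traf_params, samples)  # shape (4,)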
make_conditioner
make_conditioner(hidden_sizes, out_params_activ, init_zero=True)
Conditioner MLP factory (kept for training API compatibility).
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_conditioner(hidden_sizes, out_params_activ, init_zero=True):
    """Conditioner MLP factory (kept for training API compatibility).

    Parameters
    ----------
    hidden_sizes : sequence
        Hidden layer sizes for the MLP.
    out_params_activ : callable or None
        Activation applied to output parameters (kept for API compatibility).
    init_zero : bool, optional
        If True, initialise output parameters to zero.

    Returns
    -------
    _ConditionerFn
        Callable conditioner object with an ``.apply(params, x)`` method.
    """
    return _ConditionerFn(n_hidden_layers=len(hidden_sizes))
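As the source shows, only the length of hidden_sizes is consumed by the current implementation; a minimal, illustrative call:

from hyperion.models.photon_arrival_time_nflow.net import make_conditioner

# Illustrative sizes; out_params_activ (and init_zero) are retained purely for
# compatibility with the training API and do not affect the returned object.
conditioner = make_conditioner(hidden_sizes=[128, 128], out_params_activ=None)
# Effectively _ConditionerFn(n_hidden_layers=2).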
make_counts_net_fn
make_counts_net_fn(config)
Build the counts-model MLP.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_counts_net_fn(config):
    """Build the counts-model MLP."""
    return _ConditionerFn(n_hidden_layers=config["mlp_num_layers"])
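Only the mlp_num_layers key is read here; the remaining keys in a typical config are consumed by train_counts_model below. A sketch with illustrative values:

from hyperion.models.photon_arrival_time_nflow.net import make_counts_net_fn

config = {"mlp_num_layers": 3, "mlp_hidden_size": 128, "lr": 1e-3, "steps": 10_000}
counts_net = make_counts_net_fn(config)  # _ConditionerFn with three hidden layers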
make_shape_conditioner_fn
make_shape_conditioner_fn(mlp_hidden_size, mlp_num_layers, flow_num_bins, flow_num_layers)
Build the shape-model conditioner (MLP).
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_shape_conditioner_fn(mlp_hidden_size, mlp_num_layers, flow_num_bins, flow_num_layers):
    """Build the shape-model conditioner (MLP).

    Parameters
    ----------
    mlp_hidden_size : int
        Hidden layer size for the MLP.
    mlp_num_layers : int
        Number of hidden layers.
    flow_num_bins : int
        Number of bins used in the flow (kept for API compatibility).
    flow_num_layers : int
        Number of flow layers (kept for API compatibility).

    Returns
    -------
    _ConditionerFn
        Conditioner callable.
    """
    return _ConditionerFn(n_hidden_layers=mlp_num_layers)
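A construction sketch; the output sizing noted in the comment follows the n_out computation in train_shape_model, and all concrete values are illustrative:

from hyperion.models.photon_arrival_time_nflow.net import make_shape_conditioner_fn

conditioner = make_shape_conditioner_fn(
    mlp_hidden_size=128, mlp_num_layers=3, flow_num_bins=10, flow_num_layers=2
)
# The matching MLP output size is (3 * flow_num_bins + 1) * flow_num_layers = 62
# flow parameters per event (see train_shape_model), so conditioner.apply(params, cond)
# maps (batch, 2) conditioning inputs to (batch, 62) flow parameters.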
make_spl_flow
make_spl_flow(spl_params_list, rmin, rmax)
Convert a list of raw spline parameter arrays into knot-tuple lists.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def make_spl_flow(spl_params_list, rmin, rmax):
    """Convert a list of raw spline parameter arrays into knot-tuple lists.

    Parameters
    ----------
    spl_params_list : sequence of jnp.ndarray
        List of spline parameter arrays; each element has shape (batch, 3*num_bins + 1).
    rmin : float
        Minimum range value for the spline.
    rmax : float
        Maximum range value for the spline.

    Returns
    -------
    list
        List of tuples ``(x_pos, y_pos, knot_slopes)`` each with shape
        (batch, num_bins + 1).
    """
    return [_build_rqs_knots_batched(sp, float(rmin), float(rmax)) for sp in spl_params_list]
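A shape-focused sketch, following the shapes stated in the docstring (zero arrays stand in for real conditioner outputs):

import jax.numpy as jnp
from hyperion.models.photon_arrival_time_nflow.net import make_spl_flow

# Two layers of 10-bin splines: each raw array carries 3 * 10 + 1 = 31 columns.
spl_params_list = [jnp.zeros((4, 31)), jnp.zeros((4, 31))]
knots = make_spl_flow(spl_params_list, rmin=0.0, rmax=100.0)
x_pos, y_pos, knot_slopes = knots[0]  # each of shape (batch, num_bins + 1) = (4, 11)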
sample_shape_model
sample_shape_model(dist_builder, traf_params, n_photons, seed)
Sample from the shape model.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def sample_shape_model(dist_builder, traf_params, n_photons, seed):
    """Sample from the shape model.

    Parameters
    ----------
    dist_builder : callable
        Builder returned by :func:`traf_dist_builder` with ``return_base=True``.
    traf_params : jnp.ndarray
        Array with shape (batch, total_flow_params).
    n_photons : int or tuple
        Number of base samples to draw or sample shape.
    seed : jax.random.PRNGKey
        JAX PRNG key.

    Returns
    -------
    jnp.ndarray
        Samples drawn from the transformed shape model.
    """
    base_dist, trafo = dist_builder(traf_params)
    base_samples = base_dist.sample(seed=seed, sample_shape=n_photons)
    return trafo.forward(base_samples)
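A minimal sampling sketch; the builder must be created with return_base=True so that calling it yields (base_dist, flow), and the shapes below are illustrative assumptions:

import jax
import jax.numpy as jnp
from hyperion.models.photon_arrival_time_nflow.net import (
    sample_shape_model,
    traf_dist_builder,
)

builder = traf_dist_builder(flow_num_layers=2, flow_range=(0.0, 100.0), return_base=True)
traf_params = jnp.zeros((4, 62))  # stand-in for conditioner output (shapes assumed)
samples = sample_shape_model(builder, traf_params, n_photons=100, seed=jax.random.PRNGKey(0))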
traf_dist_builder
traf_dist_builder(flow_num_layers, flow_range, return_base=False)
Return a callable that builds the transformed distribution.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def traf_dist_builder(flow_num_layers, flow_range, return_base=False):
    """Return a callable that builds the transformed distribution.

    Parameters
    ----------
    flow_num_layers : int
        Number of spline layers in the flow.
    flow_range : tuple
        ``(rmin, rmax)`` range for the spline.
    return_base : bool, optional
        If True, calling the returned object returns ``(base_dist, flow)``.
        If False, returns a dist-like object with ``.log_prob()``.

    Returns
    -------
    _TrafDistBuilder
        Builder callable for the transformed distribution.
    """
    return _TrafDistBuilder(
        flow_num_layers=flow_num_layers,
        rmin=flow_range[0],
        rmax=flow_range[1],
        return_base=return_base,
    )
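For illustration, the two builder modes side by side; the numeric arguments are placeholder values:

from hyperion.models.photon_arrival_time_nflow.net import traf_dist_builder

# Default mode: calling the builder yields a dist-like object with .log_prob(),
# as used by eval_log_prob.
log_prob_builder = traf_dist_builder(flow_num_layers=2, flow_range=(0.0, 100.0))

# Base mode: calling the builder yields (base_dist, flow), as consumed by
# sample_shape_model.
sampling_builder = traf_dist_builder(2, (0.0, 100.0), return_base=True)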
train_counts_model
train_counts_model(config, train_loader, test_loader, seed=1337, writer=None)
Train the counts model using MLP regression.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def train_counts_model(config, train_loader, test_loader, seed=1337, writer=None):
    """Train the counts model using MLP regression.

    Parameters
    ----------
    config : dict
        Training and model configuration.
    train_loader : iterable
        Training data loader.
    test_loader : iterable
        Test data loader.
    seed : int, optional
        Random seed.
    writer : SummaryWriter or None, optional
        Optional writer for logging metrics.

    Returns
    -------
    dict
        Trained parameter dictionary.
    """
    net_fn = make_counts_net_fn(config)

    @jax.jit
    def loss_fn(params, batch):
        """Mean-squared error loss for a batch.

        Parameters
        ----------
        params : dict
            Model parameters.
        batch : tuple
            Batch data (inputs, targets, ...).

        Returns
        -------
        jnp.ndarray
            Scalar loss value.
        """
        inp = jnp.concatenate(batch[:2]).T
        out = net_fn.apply(params, inp).squeeze()
        return 0.5 * jnp.average((out - batch[2]) ** 2)

    @jax.jit
    def update(params, opt_state, batch):
        """Single optimizer update for counts model.

        Returns updated parameters, optimizer state and loss.
        """
        lval, grads = jax.value_and_grad(loss_fn)(params, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), new_opt_state, lval

    scheduler = optax.cosine_decay_schedule(config["lr"], config["steps"], alpha=0.0)
    optimizer = optax.adam(learning_rate=scheduler)
    params = _init_mlp_params(2, config["mlp_hidden_size"], config["mlp_num_layers"], 1, seed)
    opt_state = optimizer.init(params)

    train_iter = iter(train_loader)
    for i in range(1, config["steps"] + 1):
        train = next(train_iter)
        params, opt_state, train_loss = update(params, opt_state, train)
        if i % 100 == 0:
            test_loss = sum(loss_fn(params, t) for t in test_loader) / test_loader._n_batches
            train_loss, test_loss = jax.device_get((train_loss, test_loss))
            if writer is not None:
                writer.add_scalar("Loss/train", train_loss, i)
                writer.add_scalar("Loss/test", test_loss, i)
                writer.flush()
            print(f"Epoch: {i} \t Train/Test: {train_loss:.3E} / {test_loss:.3E}")

    if writer is not None:
        test_loss = sum(loss_fn(params, t) for t in test_loader) / test_loader._n_batches
        test_loss = jax.device_get(test_loss)
        writer.add_hparams(dict(config), {"hparam/test_loss": test_loss})
        writer.flush()
        writer.close()
    return params
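An invocation sketch. From the source above, the train loader is iterated endlessly and yields batches of the form (input_a, input_b, targets), while the test loader is finitely iterable and exposes an _n_batches attribute; the config values are illustrative:

from hyperion.models.photon_arrival_time_nflow.net import train_counts_model

config = {"lr": 1e-3, "steps": 10_000, "mlp_hidden_size": 128, "mlp_num_layers": 3}
# train_loader / test_loader are assumed to be project data loaders built elsewhere.
params = train_counts_model(config, train_loader, test_loader, seed=1337)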
train_shape_model
train_shape_model(config, train_loader, test_loader, seed=1337, writer=None)
Train the shape model using the provided data loaders.
Source code in hyperion/models/photon_arrival_time_nflow/net.py
def train_shape_model(config, train_loader, test_loader, seed=1337, writer=None):
    """Train the shape model using the provided data loaders.

    Parameters
    ----------
    config : dict
        Training and model configuration dictionary.
    train_loader : iterable
        Training data loader.
    test_loader : iterable
        Test data loader.
    seed : int, optional
        Random seed for parameter initialisation (default is 1337).
    writer : SummaryWriter or None, optional
        Optional writer for logging metrics.

    Returns
    -------
    dict
        Trained parameter dictionary.
    """
    dist_builder = traf_dist_builder(
        config["flow_num_layers"],
        (config["flow_rmin"], config["flow_rmax"]),
    )
    shape_conditioner = make_shape_conditioner_fn(
        config["mlp_hidden_size"],
        config["mlp_num_layers"],
        config["flow_num_bins"],
        config["flow_num_layers"],
    )

    @jax.jit
    def ema_update(params, avg_params):
        """Exponential moving average update for parameters.

        Parameters
        ----------
        params : dict
            Current parameters.
        avg_params : dict
            Current EMA parameters.

        Returns
        -------
        dict
            Updated EMA parameters.
        """
        return optax.incremental_update(params, avg_params, step_size=0.001)

    @jax.jit
    def loss_fn(params, cond, samples):
        """Compute negative log-likelihood loss for a batch.

        Parameters
        ----------
        params : dict
            Model parameters for the conditioner MLP.
        cond : array-like
            Conditioning inputs.
        samples : array-like
            Observed sample values.

        Returns
        -------
        jnp.ndarray
            Scalar loss value.
        """
        traf_params = shape_conditioner.apply(params, cond)
        lprobs = eval_log_prob(dist_builder, traf_params, samples)
        return -jnp.mean(lprobs * jnp.isfinite(lprobs))

    @jax.jit
    def update(params, opt_state, cond, samples):
        """Perform a single optimization step.

        Parameters
        ----------
        params : dict
            Current model parameters.
        opt_state : optax.OptState
            Current optimizer state.
        cond : array-like
            Conditioning inputs for the batch.
        samples : array-like
            Observed samples for the batch.

        Returns
        -------
        tuple
            ``(new_params, new_opt_state, loss_value)``.
        """
        lval, grads = jax.value_and_grad(loss_fn)(params, cond, samples)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), new_opt_state, lval

    scheduler = optax.cosine_decay_schedule(config["lr"], config["steps"], alpha=0.0)
    optimizer = optax.adam(learning_rate=scheduler)
    n_out = (3 * config["flow_num_bins"] + 1) * config["flow_num_layers"]
    params = avg_params = _init_mlp_params(
        2, config["mlp_hidden_size"], config["mlp_num_layers"], n_out, seed
    )
    opt_state = optimizer.init(params)

    train_iter = iter(train_loader)
    for i in range(1, config["steps"] + 1):
        train = next(train_iter)
        cond = jnp.concatenate(train[:2]).T
        samples = jnp.squeeze(train[2])
        params, opt_state, train_loss = update(params, opt_state, cond, samples)
        avg_params = ema_update(params, avg_params)
        if i % 100 == 0:
            test_loss = (
                sum(
                    loss_fn(avg_params, jnp.concatenate(t[:2]).T, jnp.squeeze(t[2]))
                    for t in test_loader
                )
                / test_loader._n_batches
            )
            train_loss, test_loss = jax.device_get((train_loss, test_loss))
            if writer is not None:
                writer.add_scalar("Loss/train", train_loss, i)
                writer.add_scalar("Loss/test", test_loss, i)
                writer.flush()
            print(f"Epoch: {i} \t Train/Test: {train_loss:.3E} / {test_loss:.3E}")
    return params
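An invocation sketch mirroring train_counts_model: the loaders follow the same assumed protocol (endless train iterator, finite test loader with an _n_batches attribute), and the config keys are the ones read in the source above; all values are illustrative:

from hyperion.models.photon_arrival_time_nflow.net import train_shape_model

config = {
    "lr": 1e-3, "steps": 10_000,
    "mlp_hidden_size": 128, "mlp_num_layers": 3,
    "flow_num_bins": 10, "flow_num_layers": 2,
    "flow_rmin": 0.0, "flow_rmax": 100.0,
}
# train_loader / test_loader are assumed to be project data loaders built elsewhere,
# with per-batch targets holding the sample values fed to the flow.
params = train_shape_model(config, train_loader, test_loader, seed=1337)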