Skip to content

Commit

Permalink
add vmapped versions of get log likelihoods
Browse files Browse the repository at this point in the history
  • Loading branch information
gyoge0 committed Jul 16, 2024
1 parent f56b7bb commit b6c5693
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 85 deletions.
112 changes: 112 additions & 0 deletions sparkle_stats/likelihoods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import jax
import jax.numpy as jnp
from blinx.trace_model import get_trace_log_likelihood

from sparkle_stats.parameters_util import parameters_array_to_object


def vmap_get_trace_log_likelihoods(
traces,
y,
parameters,
hyper_parameters,
):
"""
Get the log_likelihood of N traces and parameters
Args:
traces (array of shape (N, T):
N sequences of T intensity observations
y (int):
the total number of fluorescent emitters
parameters (array of shape (N, 7):
N sets of parameters
hyper_parameters (:class:`HyperParameters`, optional):
The hyperparameters used for the maximum log_likelihood estimation
Returns:
log_likelihoods (array of shape (N,)):
log likelihood for each of the N traces and parameters
"""

mapped = jax.vmap(
_get_trace_log_likelihood_from_packed_params,
in_axes=(0, None, 0, None),
)
log_likelihoods = mapped(traces, y, parameters, hyper_parameters)

return log_likelihoods


def _get_trace_log_likelihood_from_packed_params(
trace,
y,
parameters,
hyper_parameters,
):
parameters_obj = parameters_array_to_object(parameters)
return get_trace_log_likelihood(trace, y, parameters_obj, hyper_parameters)


def get_y_log_likelihoods(
y_values,
traces,
parameters,
hyper_parameters,
):
"""
Get the log_likelihood of N traces and parameters for multiple y values
Args:
y_values (array of shape (Y,)):
array of y values to try for each set of traces and parameters
traces (array of shape (N, T):
N sequences of T intensity observations
parameters (array of shape (N, 7):
N sets of parameters
hyper_parameters (:class:`HyperParameters`, optional):
The hyperparameters used for the maximum log_likelihood estimation
Returns:
y_log_likelihoods (array of shape (Y, N)):
log likelihood of each Y for each of the N traces and parameters
"""

y_log_likelihoods = []
for y in y_values:
log_likelihoods = vmap_get_trace_log_likelihoods(
traces, y, parameters, hyper_parameters
)
y_log_likelihoods.append(log_likelihoods)
y_log_likelihoods = jnp.stack(y_log_likelihoods)
return y_log_likelihoods


def select_best_y_log_likelihoods(y_values, y_log_likelihoods):
"""
Select the most likely y from multiple log likelihoods per y
Args:
y_values (array of shape (Y,)):
array of y values for each set of traces and parameters
y_log_likelihoods (array of shape (Y, N)):
log likelihood of each y for each of the N traces and parameters
Returns:
max_y (array of shape (N,)):
the y with the highest likelihood for each set of traces and parameters
max_log_likelihoods (array of shape (N,)):
the highest likelihood for each set of traces and parameters
"""

max_log_likelihoods = jnp.max(y_log_likelihoods, axis=0)
max_y = y_values[jnp.argmax(y_log_likelihoods, axis=0)]
return max_y, max_log_likelihoods
63 changes: 0 additions & 63 deletions sparkle_stats/trace_model.py

This file was deleted.

57 changes: 57 additions & 0 deletions tests/test_likelihoods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import jax.numpy as jnp
from blinx import HyperParameters
from sparkle_stats.generate_dataset import generate_memory_dataset
from sparkle_stats.likelihoods import (
get_y_log_likelihoods,
vmap_get_trace_log_likelihoods,
)


def test_vmap_get_trace_log_likelihoods():
y_list = [6, 7]
traces_per_y = 10
num_frames = 4000
hyper_parameters = HyperParameters()
seed = 1

traces, parameters, all_ys = generate_memory_dataset(
y_list, traces_per_y, num_frames, hyper_parameters, seed=seed
)
traces = traces[:, :, 0]
hyper_parameters.max_x = traces.max()

log_likelihoods = vmap_get_trace_log_likelihoods(
traces,
6,
parameters,
hyper_parameters,
)

assert log_likelihoods.shape == (traces.shape[0],)
assert jnp.isfinite(log_likelihoods).all()


def test_get_y_log_likelihoods():
y_list = [6, 7]
traces_per_y = 10
num_frames = 4000
hyper_parameters = HyperParameters()
seed = 1

traces, parameters, all_ys = generate_memory_dataset(
y_list, traces_per_y, num_frames, hyper_parameters, seed=seed
)
traces = traces[:, :, 0]
hyper_parameters.max_x = traces.max()

y_values = jnp.array([5, 6, 7, 8]).reshape(-1)

y_log_likelihoods = get_y_log_likelihoods(
y_values,
traces,
parameters,
hyper_parameters,
)

assert y_log_likelihoods.shape == (y_values.shape[0], traces.shape[0])
assert jnp.isfinite(y_log_likelihoods).all()
22 changes: 0 additions & 22 deletions tests/test_vmap_get_trace_log_likelihood.py

This file was deleted.

0 comments on commit b6c5693

Please sign in to comment.