From b6c569346a4dc9ba89985990d6c7ea4075fcd8dc Mon Sep 17 00:00:00 2001
From: Yogesh Thambidurai
Date: Tue, 16 Jul 2024 14:50:39 -0400
Subject: [PATCH] add vmapped versions of get log likelihoods

---
 sparkle_stats/likelihoods.py                | 113 ++++++++++++++++++++
 sparkle_stats/trace_model.py                |  63 -----------
 tests/test_likelihoods.py                   |  57 ++++++++++
 tests/test_vmap_get_trace_log_likelihood.py |  22 ----
 4 files changed, 170 insertions(+), 85 deletions(-)
 create mode 100644 sparkle_stats/likelihoods.py
 delete mode 100644 sparkle_stats/trace_model.py
 create mode 100644 tests/test_likelihoods.py
 delete mode 100644 tests/test_vmap_get_trace_log_likelihood.py

diff --git a/sparkle_stats/likelihoods.py b/sparkle_stats/likelihoods.py
new file mode 100644
index 0000000..fdaa830
--- /dev/null
+++ b/sparkle_stats/likelihoods.py
@@ -0,0 +1,113 @@
+import jax
+import jax.numpy as jnp
+from blinx.trace_model import get_trace_log_likelihood
+
+from sparkle_stats.parameters_util import parameters_array_to_object
+
+
+def vmap_get_trace_log_likelihoods(
+    traces,
+    y,
+    parameters,
+    hyper_parameters,
+):
+    """
+    Get the log_likelihood of N traces and parameters
+
+    Args:
+        traces (array of shape (N, T)):
+            N sequences of T intensity observations
+
+        y (int):
+            the total number of fluorescent emitters
+
+        parameters (array of shape (N, 7)):
+            N sets of parameters
+
+        hyper_parameters (:class:`HyperParameters`):
+            The hyperparameters used for the maximum log_likelihood estimation
+
+    Returns:
+        log_likelihoods (array of shape (N,)):
+            log likelihood for each of the N traces and parameters
+    """
+
+    mapped = jax.vmap(
+        _get_trace_log_likelihood_from_packed_params,
+        in_axes=(0, None, 0, None),
+    )
+    log_likelihoods = mapped(traces, y, parameters, hyper_parameters)
+
+    return log_likelihoods
+
+
+def _get_trace_log_likelihood_from_packed_params(
+    trace,
+    y,
+    parameters,
+    hyper_parameters,
+):
+    """Convert one packed parameter row to an object and score a single trace."""
+    parameters_obj = parameters_array_to_object(parameters)
+    return get_trace_log_likelihood(trace, y, parameters_obj, hyper_parameters)
+
+
+def get_y_log_likelihoods(
+    y_values,
+    traces,
+    parameters,
+    hyper_parameters,
+):
+    """
+    Get the log_likelihood of N traces and parameters for multiple y values
+
+    Args:
+        y_values (array of shape (Y,)):
+            array of y values to try for each set of traces and parameters
+
+        traces (array of shape (N, T)):
+            N sequences of T intensity observations
+
+        parameters (array of shape (N, 7)):
+            N sets of parameters
+
+        hyper_parameters (:class:`HyperParameters`):
+            The hyperparameters used for the maximum log_likelihood estimation
+
+    Returns:
+        y_log_likelihoods (array of shape (Y, N)):
+            log likelihood of each Y for each of the N traces and parameters
+    """
+
+    y_log_likelihoods = []
+    for y in y_values:
+        log_likelihoods = vmap_get_trace_log_likelihoods(
+            traces, y, parameters, hyper_parameters
+        )
+        y_log_likelihoods.append(log_likelihoods)
+    y_log_likelihoods = jnp.stack(y_log_likelihoods)
+    return y_log_likelihoods
+
+
+def select_best_y_log_likelihoods(y_values, y_log_likelihoods):
+    """
+    Select the most likely y from multiple log likelihoods per y
+
+    Args:
+        y_values (array-like of shape (Y,)):
+            y values for each set of traces and parameters; a list is accepted
+
+        y_log_likelihoods (array of shape (Y, N)):
+            log likelihood of each y for each of the N traces and parameters
+
+    Returns:
+        max_y (array of shape (N,)):
+            the y with the highest likelihood for each set of traces and parameters
+
+        max_log_likelihoods (array of shape (N,)):
+            the highest likelihood for each set of traces and parameters
+    """
+
+    max_log_likelihoods = jnp.max(y_log_likelihoods, axis=0)
+    max_y = jnp.asarray(y_values)[jnp.argmax(y_log_likelihoods, axis=0)]
+    return max_y, max_log_likelihoods
diff --git a/sparkle_stats/trace_model.py b/sparkle_stats/trace_model.py
deleted file mode 100644
index 67d8843..0000000
--- a/sparkle_stats/trace_model.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import jax
-from blinx.trace_model import get_trace_log_likelihood
-
-__all__ = [
-    "vmap_get_trace_log_likelihood",
-]
-
-from sparkle_stats.parameters_util import parameters_array_to_object
-
-
-def vmap_get_trace_log_likelihood(traces, y, parameters, hyper_parameters):
-    """
-    Get the log_likelihood of a sets of parameters for multiple traces.
-
-    Args:
-        traces (tensor of shape n x t:
-
-            several sequences of intensity observations
-
-        y (int):
-
-            the total number of fluorescent emitters to test for
-
-        parameters (array):
-
-            - n x PARAMETER_COUNT array of parameters with each row corresponding to the parameters to try
-
-        hyper_parameters (:class:`HyperParameters`, optional):
-
-            The hyper-parameters used for the maximum likelihood estimation.
-
-    Returns:
-
-        log_likelihood (array of shape (n)):
-
-            log_likelihood of observing the traces given the parameters
-
-    """
-    mapped = jax.vmap(
-        _get_trace_log_likelihood_from_packed_params,
-        in_axes=(0, None, 0, None),
-    )
-    output = mapped(traces, y, parameters, hyper_parameters)
-    return output.reshape(-1, 1)
-
-
-def _get_trace_log_likelihood_from_packed_params(
-    traces, y, parameters, hypter_parameters
-):
-    """Gets a trace's log likelihood from parameters packed into an array.
-
-    Args:
-        parameters (array):
-            1 x PARAMETER_COUNT array of parameters passed to the Parameters constructor.
-    """
-    print(parameters.shape)
-    parameters_obj = parameters_array_to_object(parameters)
-    return get_trace_log_likelihood(
-        traces,
-        y,
-        parameters_obj,
-        hypter_parameters,
-    )
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
new file mode 100644
index 0000000..cb51883
--- /dev/null
+++ b/tests/test_likelihoods.py
@@ -0,0 +1,57 @@
+import jax.numpy as jnp
+from blinx import HyperParameters
+from sparkle_stats.generate_dataset import generate_memory_dataset
+from sparkle_stats.likelihoods import (
+    get_y_log_likelihoods,
+    vmap_get_trace_log_likelihoods,
+)
+
+
+def test_vmap_get_trace_log_likelihoods():
+    y_list = [6, 7]
+    traces_per_y = 10
+    num_frames = 4000
+    hyper_parameters = HyperParameters()
+    seed = 1
+
+    traces, parameters, _ = generate_memory_dataset(
+        y_list, traces_per_y, num_frames, hyper_parameters, seed=seed
+    )
+    traces = traces[:, :, 0]
+    hyper_parameters.max_x = traces.max()
+
+    log_likelihoods = vmap_get_trace_log_likelihoods(
+        traces,
+        6,
+        parameters,
+        hyper_parameters,
+    )
+
+    assert log_likelihoods.shape == (traces.shape[0],)
+    assert jnp.isfinite(log_likelihoods).all()
+
+
+def test_get_y_log_likelihoods():
+    y_list = [6, 7]
+    traces_per_y = 10
+    num_frames = 4000
+    hyper_parameters = HyperParameters()
+    seed = 1
+
+    traces, parameters, _ = generate_memory_dataset(
+        y_list, traces_per_y, num_frames, hyper_parameters, seed=seed
+    )
+    traces = traces[:, :, 0]
+    hyper_parameters.max_x = traces.max()
+
+    y_values = jnp.array([5, 6, 7, 8]).reshape(-1)
+
+    y_log_likelihoods = get_y_log_likelihoods(
+        y_values,
+        traces,
+        parameters,
+        hyper_parameters,
+    )
+
+    assert y_log_likelihoods.shape == (y_values.shape[0], traces.shape[0])
+    assert jnp.isfinite(y_log_likelihoods).all()
diff --git a/tests/test_vmap_get_trace_log_likelihood.py b/tests/test_vmap_get_trace_log_likelihood.py
deleted file mode 100644
index a5017ce..0000000
--- a/tests/test_vmap_get_trace_log_likelihood.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from blinx import HyperParameters
-from sparkle_stats.generate_dataset import generate_memory_dataset
-from sparkle_stats.trace_model import vmap_get_trace_log_likelihood
-
-
-def test_vmap_get_trace_log_likelihood():
-    y_list = [6, 7]
-    traces_per_y = 10
-    num_frames = 10
-    hyper_parameters = HyperParameters(
-        max_x=10,
-        num_x_bins=10,
-        num_outliers=2,
-    )
-    seed = 1
-
-    traces, parameters, _ = generate_memory_dataset(
-        y_list, traces_per_y, num_frames, hyper_parameters, seed
-    )
-
-    output = vmap_get_trace_log_likelihood(traces, 6, parameters, hyper_parameters)
-    assert output.shape == (len(y_list) * traces_per_y, 1)