From b92afc95de17d1d6c043e28b147b458942036488 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Wed, 23 Oct 2024 19:39:18 -0400 Subject: [PATCH] Adapt examples --- README.md | 4 +- ...ynb => Bayesian_Experimental_Design.ipynb} | 4 +- examples/TwoMoons_ConsistencyModel.ipynb | 704 ------------------ examples/TwoMoons_StarterNotebook.ipynb | 190 +++-- 4 files changed, 95 insertions(+), 807 deletions(-) rename examples/{michaelis_menten_BED_tutorial.ipynb => Bayesian_Experimental_Design.ipynb} (99%) delete mode 100644 examples/TwoMoons_ConsistencyModel.ipynb diff --git a/README.md b/README.md index 17433280..0ea14988 100644 --- a/README.md +++ b/README.md @@ -87,8 +87,8 @@ conda env create --file environment.yaml --name bayesflow Check out some of our walk-through notebooks below. We are actively working on porting all notebooks to the new interface so more will be available soon! -1. [Two moons toy example](examples/TwoMoons_FlowMatching.ipynb) -2. [Bayesian experimental design (BED)](examples/michaelis_menten_BED_tutorial.ipynb) +1. [Two moons starter toy example](examples/TwoMoons_StarterNotebook.ipynb) +2. [Bayesian experimental design (BED)](examples/Bayesian_Experimental_Design.ipynb) 3. Coming soon... 
## Documentation \& Help diff --git a/examples/michaelis_menten_BED_tutorial.ipynb b/examples/Bayesian_Experimental_Design.ipynb similarity index 99% rename from examples/michaelis_menten_BED_tutorial.ipynb rename to examples/Bayesian_Experimental_Design.ipynb index a2a13cd9..af875fbc 100644 --- a/examples/michaelis_menten_BED_tutorial.ipynb +++ b/examples/Bayesian_Experimental_Design.ipynb @@ -752,7 +752,7 @@ ], "metadata": { "kernelspec": { - "display_name": "bf-torch", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -766,7 +766,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/examples/TwoMoons_ConsistencyModel.ipynb b/examples/TwoMoons_ConsistencyModel.ipynb deleted file mode 100644 index 7b1efa7c..00000000 --- a/examples/TwoMoons_ConsistencyModel.ipynb +++ /dev/null @@ -1,704 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "009b6adf", - "metadata": {}, - "source": [ - "# Consistency Models for Posterior Estimation\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "d5f88a59", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.551814Z", - "start_time": "2024-09-23T14:39:46.032170Z" - } - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import seaborn as sns\n", - "\n", - "# ensure the backend is set\n", - "import os\n", - "if \"KERAS_BACKEND\" not in os.environ:\n", - " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", - " os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", - "\n", - "import keras\n", - "\n", - "# for bayesflow devs: this ensures that the latest dev version can be found\n", - "import sys\n", - "sys.path.append('../')\n", - "\n", - "import bayesflow as bf" - ] - }, - { - "cell_type": "markdown", - "id": "eadaf793-ab63-4f69-b962-178e343ca21b", - "metadata": {}, - "source": [ - "In this 
notebook, we use Consistency Models (CMs) as a plug-in replacement to obtain posterior samples with fewer sampling steps.\n", - "\n", - "CMs can be trained in two ways: First, they can be used to _distill_ an existing score-based diffusion model, thereby massively decreasing the sampling time at the expense of an additional training phase. Second, they can be trained from scratch using a procedure named _Consistency Training_. For now, we only support the latter.\n" - ] - }, - { - "cell_type": "markdown", - "id": "6286c800-460a-4881-87d8-c3aca7aeec70", - "metadata": {}, - "source": [ - "## Background\n" - ] - }, - { - "cell_type": "markdown", - "id": "fdff817a-6321-4af0-9d41-7ec80097f93b", - "metadata": {}, - "source": [ - "Consistency Models [1] leverage some nice properties of score-based diffusion to enable few-step sampling. Score-based diffusion initially relied on a stochastic differential equation (SDE) for sampling, but there is also a ordinary (non-stochastic) differential equation (ODE) has the same _marginal_ distribution at each time step $t$ [2]. This means that even though SDE and ODE produce different paths from the noise distribution to the target distribution, the resulting distributions when looking at many paths at time $t$ is the same. The ODE is also called Probability Flow ODE.\n" - ] - }, - { - "cell_type": "markdown", - "id": "4a2e996d-355a-4fab-8347-728e563c6014", - "metadata": {}, - "source": [ - "CMs now leverage the fact that there is no randomness in the ODE formulation. That means, if you start at a certain point in the latent space, you will always take the same path and always end up at the same point in the data space. The same is true for every point on the path: if you integrate to get to time $t=0$, you will end up at the same point as well. In short: for each path, there is exactly one corresponding point in latent space (at $t=T$) and one corresponding point in data space (at $t=0$). 
The goal of CMs is now the following: each point at a time point $t$ belongs to exactly one path, and we want to predict where this path will end up at $t=0$. The function that does this is called the _consistency function_ $f$. If we have the correct function for all $t\\in(0,T]$, we can just sample from the latent distribution ($t=T$) and use $f$ to directly map to the corresponding point at $t=0$, which is in the target distribution. So for sampling from the target distribution, we avoid any integration and only need one evaluation of the consistency function. In practice, the one-step sampling does not work very well. Instead, we leverage a multi-step sampling method where we call $f$ multiple times. Please check out the [1] for more background on this sampling procedure.\n" - ] - }, - { - "cell_type": "markdown", - "id": "25023294-3096-4ebc-83a6-a372208e0504", - "metadata": {}, - "source": [ - "When only reading the above you might wonder why we also learn the mapping to $t=0$ of all intermediate time steps $t\\in[0, T]$, and not only for $t=T$. The main answer is that for efficient training, we do not want to actually compute the two associated points explicitly. Doing so would require to do a precise integration at training time, which is often not feasible as it is too computationally costly. Learning all time steps opens up the possibility for a different training approach where we can avoid this.\n" - ] - }, - { - "cell_type": "markdown", - "id": "9e6eb121-f89f-4268-9a0d-6fa733c758ff", - "metadata": {}, - "source": [ - "The details of this become a bit more complicated, and we advise you to take a look at [1] if you are interested in a more thorough and mathematical discussion. 
Here we will give a rough description of the underlying concepts.\n" - ] - }, - { - "cell_type": "markdown", - "id": "9b201899-b946-4432-bcac-106b9d580d32", - "metadata": {}, - "source": [ - "First, we know that at $t=0$, it holds that $f(\\theta,t=0)=\\theta$, as $\\theta$ is part of the path that ends at $\\theta$. This _boundary condition_ serves as an \"anchor\" for our training, this is the information that the network knows at the start of the training procedure (we encode it with a time-dependent skip-connection, so the network is forced to be the identity function at $t=0$).\n" - ] - }, - { - "cell_type": "markdown", - "id": "a11f7ac7-9c11-49c1-b29c-3f0d44c4bf49", - "metadata": {}, - "source": [ - "For training, we now somehow have to propagate this information to the rest of the part. The basic idea for this is simple. We just take a point $\\theta_1$ closer to the data distribution (smaller time $t_1$) and integrate for a small time step $dt$ to a point $\\theta_2$ on the same path that is closer to the latent distribution (larger time $t_2=t_1+dt$). As we know that for $t=0$ our network provides the correct output for our path, we want to propagate the information from smaller times to larger times. Our training goal is to move the output of $f(\\theta_2, t=t_2)$ towards the output of $f(\\theta_1, t=t_1)$. How to choose $\\theta_1$, $t_1$ and $dt$ is an empirical question, see the [1] for some thoughts on what works well.\n" - ] - }, - { - "cell_type": "markdown", - "id": "73ff5ed5-dbd0-423b-b4f4-8d209947ed87", - "metadata": {}, - "source": [ - "In the case of _distillation_, we start with a trained score-based diffusion model. We can use it to integrate the Probability Flow ODE to get from $\\theta_1$ to $\\theta_2$. If we do not have such a model, it seems as if we were stuck. We do not know which points lie on the same path, so we do not know which outputs to make similar. 
Fortunately, it turns out that there is an _unbiased approximator_ that, if averaged over many samples (check out the paper for the exact description), will also give us the correct score. If we use this approximator instead of the score model, and use only a single Euler step to move along the path, we get an algorithm similar to the one described for distillation. It is called Consistency Training (CT) and allows us to train a consistency model using only _samples_ from the data distribution. The algorithm for this was improved a lot in [3], and we have incorporated those improvements into our implementation.\n" - ] - }, - { - "cell_type": "markdown", - "id": "283e8fea-9a36-4f8a-80f5-19e2fe98bd09", - "metadata": {}, - "source": [ - "We have made several approximations to get to a standalone Consistency Training algorithm. As a consequence, the introduced hyperparameters and their choice unfortunately becomes somewhat unintuitive. We have to rely on empirical observations and heuristics to see what works. This was done in [4], we encourage you to use the values provided there as starting points. If you happen to find hyperparameters that work significantly better, please let us know (e.g., by opening an issue or sending an email). This will help others to find the correct region in the hyperparameter space.\n" - ] - }, - { - "cell_type": "markdown", - "id": "8da3535f-0354-40a4-991f-845c33ff75b8", - "metadata": {}, - "source": [ - "To make this work for Bayesian inverse problems in simulation-based inference, we can make the whole process conditional on some quantity $x$, so we can produce conditional distributions as well. Below, you can see a conceptual visualization of posterior estimation with CMs.\n" - ] - }, - { - "cell_type": "markdown", - "id": "6c105be4-c490-4100-9502-60dd7405cfb4", - "metadata": {}, - "source": [ - "![Visualization of the way consistency models map from the path to the end point in the data distribution. 
Depicts the concepts described in the main text.](https://arxiv.org/html/2312.05440v2/extracted/5435837/figures/cmpe_main.png)\n" - ] - }, - { - "cell_type": "markdown", - "id": "baafb8fd-5b14-4ddf-a6dd-8272f706deaa", - "metadata": {}, - "source": [ - "### References\n", - "\n", - "[1] Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. _arXiv preprint_. [https://doi.org/10.48550/arXiv.2303.01469](https://doi.org/10.48550/arXiv.2303.01469)\n", - "\n", - "[2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. In _International Conference on Learning Representations_. [https://openreview.net/forum?id=PxTIG12RRHS](https://openreview.net/forum?id=PxTIG12RRHS)\n", - "\n", - "[3] Song, Y., & Dhariwal, P. (2023). Improved Techniques for Training Consistency Models. _arXiv preprint_. [https://doi.org/10.48550/arXiv.2310.14189](https://doi.org/10.48550/arXiv.2310.14189)\n", - "\n", - "[4] Schmitt, M., Pratz, V., Köthe, U., Bürkner, P.-C., & Radev, S. T. (2024). Consistency Models for Scalable and Fast Simulation-Based Inference. _arXiv preprint_. [https://doi.org/10.48550/arXiv.2312.05440](https://doi.org/10.48550/arXiv.2312.05440)\n" - ] - }, - { - "cell_type": "markdown", - "id": "c63b26ba", - "metadata": {}, - "source": [ - "## Simulator: Two Moons\n" - ] - }, - { - "cell_type": "markdown", - "id": "9525ffd7", - "metadata": {}, - "source": [ - "We will use the Concistency Model as a plug-in replacement for Flow Matching. 
Check out the tutorial \"Two moons toy example with flow matching\" for more details on the simulator and setting.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4b89c861527c13b8", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.747091Z", - "start_time": "2024-09-23T14:39:46.744830Z" - } - }, - "outputs": [], - "source": [ - "simulator = bf.benchmarks.simulators.TwoMoons()" - ] - }, - { - "cell_type": "markdown", - "id": "f6e1eb5777c59eba", - "metadata": {}, - "source": [ - "We generate some data to see what the simulator does:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "e6218e61d529e357", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.798575Z", - "start_time": "2024-09-23T14:39:46.790581Z" - } - }, - "outputs": [], - "source": [ - "# generate 64 random draws from the joint distribution p(r, alpha, theta, x)\n", - "sample_data = simulator.sample((64,))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "46174ccb0167026c", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.854911Z", - "start_time": "2024-09-23T14:39:46.852129Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Type of sample_data:\n", - "\t \n", - "Keys of sample_data:\n", - "\t dict_keys(['parameters', 'observables'])\n", - "Types of sample_data values:\n", - "\t {'parameters': , 'observables': }\n", - "Shapes of sample_data values:\n", - "\t {'parameters': (64, 2), 'observables': (64, 2)}\n" - ] - } - ], - "source": [ - "print(\"Type of sample_data:\\n\\t\", type(sample_data))\n", - "print(\"Keys of sample_data:\\n\\t\", sample_data.keys())\n", - "print(\"Types of sample_data values:\\n\\t\", {k: type(v) for k, v in sample_data.items()})\n", - "print(\"Shapes of sample_data values:\\n\\t\", {k: v.shape for k, v in sample_data.items()})" - ] - }, - { - "cell_type": "markdown", - "id": "fee88fcfd7a373b0", - "metadata": {}, - 
"source": [ - "## Data Adapter\n", - "\n", - "The next step is to tell BayesFlow how to deal with the simulated variables. You may also think of this as informing BayesFlow about the data flow, i.e., which variables go into which network.\n", - "\n", - "For this example, we want to learn the posterior distribution $p(\\theta\\,|\\,x)$, so we **infer** $\\theta$, **conditioning** on $x$. In the output from the last command, we see that the simulator provides $\\theta$ as `\"parameters\"` and $x$ as `\"observables\"`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "c9637c576d4ad4e5", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.905081Z", - "start_time": "2024-09-23T14:39:46.903091Z" - } - }, - "outputs": [], - "source": [ - "data_adapter = bf.ContinuousApproximator.build_data_adapter(\n", - " inference_variables=[\"parameters\"],\n", - " inference_conditions=[\"observables\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "254e287b2bccdad", - "metadata": {}, - "source": [ - "## Dataset\n", - "\n", - "For this example, we will sample our training data ahead of time and use offline training with a `bf.datasets.OfflineDataset`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "39cb5a1c9824246f", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:46.950573Z", - "start_time": "2024-09-23T14:39:46.948624Z" - } - }, - "outputs": [], - "source": [ - "batch_size = 64\n", - "num_training_batches = 512\n", - "num_validation_batches = 128" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "9dee7252ef99affa", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:53.268860Z", - "start_time": "2024-09-23T14:39:46.994697Z" - } - }, - "outputs": [], - "source": [ - "training_samples = simulator.sample((num_training_batches * batch_size,))\n", - "validation_samples = simulator.sample((num_validation_batches * batch_size,))" - ] - }, - { - "cell_type": 
"code", - "execution_count": 8, - "id": "51045bbed88cb5c2", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:53.281170Z", - "start_time": "2024-09-23T14:39:53.275921Z" - } - }, - "outputs": [], - "source": [ - "training_dataset = bf.datasets.OfflineDataset(training_samples, batch_size=batch_size, data_adapter=data_adapter)\n", - "validation_dataset = bf.datasets.OfflineDataset(validation_samples, batch_size=batch_size, data_adapter=data_adapter)" - ] - }, - { - "cell_type": "markdown", - "id": "2d4c6eb0", - "metadata": {}, - "source": [ - "## Training a neural network to approximate all posteriors\n", - "\n", - "The next step is to set up the neural network that will approximate the posterior $p(\\theta\\,|\\,x)$.\n", - "\n", - "Consistency models use _scheduling functions_ to adjust some of the hyperparameters, for example the time discretization during training. Consequently, we have to specify the total number of training steps (_gradient updates_) before the start of the training.\n", - "For offline training with a given number of epochs, we can calculate it as below:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "516ac3c4-b66f-4cf0-a443-b00705a6ace5", - "metadata": {}, - "outputs": [], - "source": [ - "epochs = 30\n", - "total_steps = epochs * num_training_batches" - ] - }, - { - "cell_type": "markdown", - "id": "2c7980ec-5623-43e3-847e-16a5b6eb8777", - "metadata": {}, - "source": [ - "Apart from the usual parameters like learning rate and batch size, CMs come with a number of different hyperparameters. Unfortunately, they can heavily interact, so they can be hard to tune. The main hyperparameters are:\n", - "\n", - "- Maximum time `max_time`: This also serves as the standard deviation of the latent distribution. You can experiment with this, values from 10-200 seem to work well. 
In any case, it should be larger than the standard deviation of the target distribution.\n", - "- Minimum/maximum number of discretization steps during training `s0`/`s1`: The effect of those is hard to grasp. 10 works well for `s0`. Intuitively, increasing `s1` along with the number of epochs should lead to better result, but in practice we sometimes observe a breakdown for high values of `s1`. This seems to be problem-dependent, so just try it out.\n", - "- `sigma2` modifies the time-dependency of the skip connection. Its effect on the training is unclear, we recommend leaving it at 1.0 or setting it to the approximate variance of the target distribution.\n", - "- Smallest time value `eps` ($t=\\epsilon$ is used instead of $t=0$ for numerical reasons): No large effect in our experiments, as long as it is kept small enough. Probably not worth tuning.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "09206e6f", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:53.339590Z", - "start_time": "2024-09-23T14:39:53.319852Z" - } - }, - "outputs": [], - "source": [ - "# Compute the empirical variance of the draws from the prior θ ~ p(θ)\n", - "sigma2 = keras.ops.var(training_samples[\"parameters\"], axis=0, keepdims=True)\n", - "\n", - "inference_network = bf.networks.ConsistencyModel(\n", - " subnet=\"mlp\",\n", - " subnet_kwargs=dict(\n", - " depth=6,\n", - " width=256,\n", - " ),\n", - " total_steps = total_steps,\n", - " max_time=10, # works well for this task\n", - " sigma2=sigma2, # pass the empirical variance to the network\n", - " # the remaining hyperparameters (s0, s1, eps) are the default values\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "851e522f", - "metadata": {}, - "source": [ - "This inference network is just a general consistency model architecture, not yet adapted to the specific inference task at hand (i.e., posterior appproximation). 
To achieve this adaptation, we combine the network with our data adapter, which together form an `approximator`. In this case, we need a `ContinuousApproximator` since the target we want to approximate is the posterior of the _continuous_ parameter vector $\\theta$.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "96ca6ffa", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:53.371691Z", - "start_time": "2024-09-23T14:39:53.369375Z" - } - }, - "outputs": [], - "source": [ - "approximator = bf.ContinuousApproximator(\n", - " inference_network=inference_network,\n", - " data_adapter=data_adapter,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "566264eadc76c2c", - "metadata": {}, - "source": [ - "### Optimizer and Learning Rate\n", - "\n", - "We use an Adam optimizer with [cosine decay](https://keras.io/api/optimizers/learning_rate_schedules/cosine_decay/) to decrease the learning rate towards zero over the training time.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e8d7e053", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:53.433012Z", - "start_time": "2024-09-23T14:39:53.415903Z" - } - }, - "outputs": [], - "source": [ - "initial_learning_rate = 5e-4\n", - "scheduled_lr = keras.optimizers.schedules.CosineDecay(\n", - " initial_learning_rate,\n", - " total_steps,\n", - ")\n", - "\n", - "optimizer = keras.optimizers.Adam(learning_rate=scheduled_lr)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "51808fcd560489ac", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:39:53.476089Z", - "start_time": "2024-09-23T14:39:53.466001Z" - } - }, - "outputs": [], - "source": [ - "approximator.compile(optimizer=optimizer)" - ] - }, - { - "cell_type": "markdown", - "id": "708b1303", - "metadata": {}, - "source": [ - "### Training\n", - "\n", - "We are ready to train our deep posterior approximator on the two moons example. 
We pass the dataset object to the `fit` method and watch as bayesflow trains. This notebook is being executed on a consumer-grade CPU and training is still reasonably fast. If you have a GPU available, training will be even faster, especially for larger networks.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "0f496bda", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:42:36.067393Z", - "start_time": "2024-09-23T14:39:53.513436Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:bayesflow:Fitting on dataset instance of OfflineDataset.\n", - "INFO:bayesflow:Building on a test batch.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 6ms/step - loss: 0.4117 - loss/inference_loss: 0.4117 - val_loss: 0.3387 - val_loss/inference_loss: 0.3387\n", - "Epoch 2/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3568 - loss/inference_loss: 0.3568 - val_loss: 0.3603 - val_loss/inference_loss: 0.3603\n", - "Epoch 3/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3503 - loss/inference_loss: 0.3503 - val_loss: 0.2898 - val_loss/inference_loss: 0.2898\n", - "Epoch 4/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3362 - loss/inference_loss: 0.3362 - val_loss: 0.4429 - val_loss/inference_loss: 0.4429\n", - "Epoch 5/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3343 - loss/inference_loss: 0.3343 - val_loss: 0.3929 - val_loss/inference_loss: 0.3929\n", - "Epoch 6/30\n", - "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3311 - loss/inference_loss: 0.3311 - val_loss: 0.2825 - val_loss/inference_loss: 0.2825\n", - "Epoch 7/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3264 - loss/inference_loss: 0.3264 - val_loss: 0.3029 - val_loss/inference_loss: 0.3029\n", - "Epoch 8/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3222 - loss/inference_loss: 0.3222 - val_loss: 0.3447 - val_loss/inference_loss: 0.3447\n", - "Epoch 9/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3222 - loss/inference_loss: 0.3222 - val_loss: 0.3209 - val_loss/inference_loss: 0.3209\n", - "Epoch 10/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3141 - loss/inference_loss: 0.3141 - val_loss: 0.2195 - val_loss/inference_loss: 0.2195\n", - "Epoch 11/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3123 - loss/inference_loss: 0.3123 - val_loss: 0.3043 - val_loss/inference_loss: 0.3043\n", - "Epoch 12/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3121 - loss/inference_loss: 0.3121 - val_loss: 0.3225 - val_loss/inference_loss: 0.3225\n", - "Epoch 13/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3089 - loss/inference_loss: 0.3089 - val_loss: 0.2082 - val_loss/inference_loss: 0.2082\n", - "Epoch 14/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - 
loss: 0.3028 - loss/inference_loss: 0.3028 - val_loss: 0.2394 - val_loss/inference_loss: 0.2394\n", - "Epoch 15/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2996 - loss/inference_loss: 0.2996 - val_loss: 0.3735 - val_loss/inference_loss: 0.3735\n", - "Epoch 16/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2949 - loss/inference_loss: 0.2949 - val_loss: 0.2624 - val_loss/inference_loss: 0.2624\n", - "Epoch 17/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2956 - loss/inference_loss: 0.2956 - val_loss: 0.3925 - val_loss/inference_loss: 0.3925\n", - "Epoch 18/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2904 - loss/inference_loss: 0.2904 - val_loss: 0.2991 - val_loss/inference_loss: 0.2991\n", - "Epoch 19/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2953 - loss/inference_loss: 0.2953 - val_loss: 0.2517 - val_loss/inference_loss: 0.2517\n", - "Epoch 20/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2867 - loss/inference_loss: 0.2867 - val_loss: 0.3187 - val_loss/inference_loss: 0.3187\n", - "Epoch 21/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2880 - loss/inference_loss: 0.2880 - val_loss: 0.3218 - val_loss/inference_loss: 0.3218\n", - "Epoch 22/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2819 - loss/inference_loss: 0.2819 - val_loss: 0.2689 - 
val_loss/inference_loss: 0.2689\n", - "Epoch 23/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2775 - loss/inference_loss: 0.2775 - val_loss: 0.2354 - val_loss/inference_loss: 0.2354\n", - "Epoch 24/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2848 - loss/inference_loss: 0.2848 - val_loss: 0.2992 - val_loss/inference_loss: 0.2992\n", - "Epoch 25/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2699 - loss/inference_loss: 0.2699 - val_loss: 0.1976 - val_loss/inference_loss: 0.1976\n", - "Epoch 26/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2760 - loss/inference_loss: 0.2760 - val_loss: 0.3003 - val_loss/inference_loss: 0.3003\n", - "Epoch 27/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2774 - loss/inference_loss: 0.2774 - val_loss: 0.3333 - val_loss/inference_loss: 0.3333\n", - "Epoch 28/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2745 - loss/inference_loss: 0.2745 - val_loss: 0.2938 - val_loss/inference_loss: 0.2938\n", - "Epoch 29/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2795 - loss/inference_loss: 0.2795 - val_loss: 0.2968 - val_loss/inference_loss: 0.2968\n", - "Epoch 30/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.2720 - loss/inference_loss: 0.2720 - val_loss: 0.2105 - val_loss/inference_loss: 0.2105\n" - ] - } - ], - "source": [ - "history = 
approximator.fit(\n", - " epochs=epochs,\n", - " dataset=training_dataset,\n", - " validation_data=validation_dataset,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b90a6062", - "metadata": {}, - "source": [ - "## Validation\n" - ] - }, - { - "cell_type": "markdown", - "id": "ca62b21d", - "metadata": {}, - "source": [ - "### Two Moons Posterior\n", - "\n", - "By design, the two moons posterior at point $x = (0, 0)$ should resemble two crescent moons, hence the name. Below, we plot the corresponding posterior samples.\n", - "\n", - "These results suggest that our consistency model posterior estimation setup can approximate the target posterior well. You can achieve an even better fit if you use online training, more epochs, or better hyperparameters. We won't do that here because this tutorial shall only illustrate the basic setup for consistency models in amortized inference with bayesflow.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "8562caeb", - "metadata": { - "ExecuteTime": { - "end_time": "2024-09-23T14:42:38.584554Z", - "start_time": "2024-09-23T14:42:36.076923Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(-0.4, 0.4)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Set the number of posterior draws you want to get\n", - "num_samples = 3000\n", - "\n", - "# Obtain samples from amortized posterior\n", - "conditions = {\"observables\": np.array([[0.0, 0.0]]).astype(\"float32\")}\n", - "samples_at_origin = approximator.sample(conditions=conditions, num_samples=num_samples)[\"parameters\"]\n", - "\n", - "# Prepare figure\n", - "f, axes = plt.subplots(1, figsize=(6, 6))\n", - "\n", - "# Plot samples\n", - "axes.scatter(samples_at_origin[0, :, 0], samples_at_origin[0, :, 1], color=\"#153c7a\", alpha=0.75, s=0.5)\n", - "sns.despine(ax=axes)\n", - "axes.set_title(r\"Posterior samples at origin $x=(0, 0)$\")\n", - "axes.grid(alpha=0.3)\n", - "axes.set_aspect(\"equal\", adjustable=\"box\")\n", - "axes.set_xlim([-0.4, 0.4])\n", - "axes.set_ylim([-0.4, 0.4])" - ] - }, - { - "cell_type": "markdown", - "id": "01821d24", - "metadata": {}, - "source": [ - "The posterior looks as we have expected in this case. However, in general, we do not know how the posterior is supposed to look like for any specific dataset. As such, we need diagnostics that validate the correctness of the inferred posterior. One such diagnostic is simulation-based calibration (SBC), which we can compute essentially for free due to amortization. For more details on SBC and diagnostic plots, see:\n", - "\n", - "1. Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. _arXiv preprint_.\n", - "2. Säilynoja, T., Bürkner, P. C., & Vehtari, A. (2022). Graphical test for discrete uniformity and its applications in goodness-of-fit evaluation and multiple sample comparison. 
_Statistics and Computing_.\n" - ] - }, - { - "cell_type": "markdown", - "id": "cb38d0c8", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "bayesflow", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": true, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": true, - "toc_position": { - "height": "calc(100% - 180px)", - "left": "10px", - "top": "150px", - "width": "165px" - }, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/TwoMoons_StarterNotebook.ipynb b/examples/TwoMoons_StarterNotebook.ipynb index 8151755e..299c3094 100644 --- a/examples/TwoMoons_StarterNotebook.ipynb +++ b/examples/TwoMoons_StarterNotebook.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "id": "d5f88a59", "metadata": { "ExecuteTime": { @@ -18,21 +18,13 @@ "start_time": "2024-09-23T14:39:46.032170Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" - ] - } - ], + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import seaborn as sns\n", "\n", - "# ensure the backend is set\n", + "# Ensure the backend is set\n", "import os\n", "if \"KERAS_BACKEND\" not in os.environ:\n", " # set this to \"torch\", \"tensorflow\", or \"jax\"\n", @@ -40,7 +32,7 @@ "\n", "import keras\n", "\n", - "# for BayesFlow devs: this ensures that the latest dev version can be found\n", + "# For BayesFlow devs: this ensures that the latest dev version can be found\n", "import sys\n", "sys.path.append('../')\n", "\n", @@ -90,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 16, "id": "f761b142a0e1da66", "metadata": { "ExecuteTime": { @@ -122,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 17, "id": "4b89c861527c13b8", "metadata": { "ExecuteTime": { @@ -145,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "id": "e6218e61d529e357", "metadata": { "ExecuteTime": { @@ -161,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "id": "46174ccb0167026c", "metadata": { "ExecuteTime": { @@ -211,12 +203,12 @@ "\n", "For this example, we want to learn the posterior distribution $p(\\theta | x)$, so we **infer** $\\theta$, **conditioning** on $x$.\n", "\n", - "First, we rename the raw simulator outputs so that trhe neural networks know how interpret them: the $\\theta$ vectors becomes the variables to be inferred (i.e., `inference_variables`) and the $\\x$ vector is designated as the variables to use as conditions (i.e., `inference_conditions`). " + "First, we rename the raw simulator outputs so that the neural networks know how to interpret them: the $\\theta$ vectors become the variables to be inferred (i.e., `inference_variables`) and the $x$ vector is designated as the variables to use as conditions (i.e., `inference_conditions`). 
" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 20, "id": "c9637c576d4ad4e5", "metadata": { "ExecuteTime": { @@ -248,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 21, "id": "39cb5a1c9824246f", "metadata": { "ExecuteTime": { @@ -267,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 22, "id": "9dee7252ef99affa", "metadata": { "ExecuteTime": { @@ -283,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 23, "id": "51045bbed88cb5c2", "metadata": { "ExecuteTime": { @@ -324,7 +316,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 24, "id": "09206e6f", "metadata": { "ExecuteTime": { @@ -336,7 +328,7 @@ "source": [ "inference_network = bf.networks.FlowMatching(\n", " subnet=\"mlp\", \n", - " subnet_kwargs={\"depth\": 6, \"width\": 256}\n", + " subnet_kwargs={\"widths\": (256,)*6} # use an inner network with 6 hidden layers of 256 units\n", ")" ] }, @@ -350,7 +342,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 25, "id": "96ca6ffa", "metadata": { "ExecuteTime": { @@ -377,7 +369,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 26, "id": "e8d7e053", "metadata": { "ExecuteTime": { @@ -399,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 27, "id": "51808fcd560489ac", "metadata": { "ExecuteTime": { @@ -424,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 28, "id": "0f496bda", "metadata": { "ExecuteTime": { @@ -446,67 +438,67 @@ "output_type": "stream", "text": [ "Epoch 1/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.4254 - loss/inference_loss: 0.4254 - val_loss: 0.4950 - val_loss/inference_loss: 0.4950\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 
4ms/step - loss: 0.4225 - loss/inference_loss: 0.4225 - val_loss: 0.3897 - val_loss/inference_loss: 0.3897\n", "Epoch 2/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - loss: 0.3796 - loss/inference_loss: 0.3796 - val_loss: 0.3187 - val_loss/inference_loss: 0.3187\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3764 - loss/inference_loss: 0.3764 - val_loss: 0.2469 - val_loss/inference_loss: 0.2469\n", "Epoch 3/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3694 - loss/inference_loss: 0.3694 - val_loss: 0.3467 - val_loss/inference_loss: 0.3467\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3680 - loss/inference_loss: 0.3680 - val_loss: 0.3456 - val_loss/inference_loss: 0.3456\n", "Epoch 4/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3581 - loss/inference_loss: 0.3581 - val_loss: 0.4014 - val_loss/inference_loss: 0.4014\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3596 - loss/inference_loss: 0.3596 - val_loss: 0.3564 - val_loss/inference_loss: 0.3564\n", "Epoch 5/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3489 - loss/inference_loss: 0.3489 - val_loss: 0.3101 - val_loss/inference_loss: 0.3101\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - loss: 0.3558 - loss/inference_loss: 0.3558 - val_loss: 0.3258 - val_loss/inference_loss: 0.3258\n", "Epoch 6/30\n", - "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3515 - loss/inference_loss: 0.3515 - val_loss: 0.3398 - val_loss/inference_loss: 0.3398\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3507 - loss/inference_loss: 0.3507 - val_loss: 0.2755 - val_loss/inference_loss: 0.2755\n", "Epoch 7/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3528 - loss/inference_loss: 0.3528 - val_loss: 0.3643 - val_loss/inference_loss: 0.3643\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3450 - loss/inference_loss: 0.3450 - val_loss: 0.3038 - val_loss/inference_loss: 0.3038\n", "Epoch 8/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3402 - loss/inference_loss: 0.3402 - val_loss: 0.2596 - val_loss/inference_loss: 0.2596\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3391 - loss/inference_loss: 0.3391 - val_loss: 0.2291 - val_loss/inference_loss: 0.2291\n", "Epoch 9/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3398 - loss/inference_loss: 0.3398 - val_loss: 0.4423 - val_loss/inference_loss: 0.4423\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3451 - loss/inference_loss: 0.3451 - val_loss: 0.3416 - val_loss/inference_loss: 0.3416\n", "Epoch 10/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3418 - loss/inference_loss: 0.3418 - val_loss: 0.3876 - 
val_loss/inference_loss: 0.3876\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3408 - loss/inference_loss: 0.3408 - val_loss: 0.2305 - val_loss/inference_loss: 0.2305\n", "Epoch 11/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3410 - loss/inference_loss: 0.3410 - val_loss: 0.2288 - val_loss/inference_loss: 0.2288\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3306 - loss/inference_loss: 0.3306 - val_loss: 0.3630 - val_loss/inference_loss: 0.3630\n", "Epoch 12/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3311 - loss/inference_loss: 0.3311 - val_loss: 0.2649 - val_loss/inference_loss: 0.2649\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3383 - loss/inference_loss: 0.3383 - val_loss: 0.4263 - val_loss/inference_loss: 0.4263\n", "Epoch 13/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3270 - loss/inference_loss: 0.3270 - val_loss: 0.3067 - val_loss/inference_loss: 0.3067\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3296 - loss/inference_loss: 0.3296 - val_loss: 0.3179 - val_loss/inference_loss: 0.3179\n", "Epoch 14/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3285 - loss/inference_loss: 0.3285 - val_loss: 0.4922 - val_loss/inference_loss: 0.4922\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3309 - 
loss/inference_loss: 0.3309 - val_loss: 0.6036 - val_loss/inference_loss: 0.6036\n", "Epoch 15/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 5ms/step - loss: 0.3292 - loss/inference_loss: 0.3292 - val_loss: 0.2729 - val_loss/inference_loss: 0.2729\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3280 - loss/inference_loss: 0.3280 - val_loss: 0.3043 - val_loss/inference_loss: 0.3043\n", "Epoch 16/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3288 - loss/inference_loss: 0.3288 - val_loss: 0.4003 - val_loss/inference_loss: 0.4003\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3322 - loss/inference_loss: 0.3322 - val_loss: 0.2144 - val_loss/inference_loss: 0.2144\n", "Epoch 17/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3322 - loss/inference_loss: 0.3322 - val_loss: 0.2052 - val_loss/inference_loss: 0.2052\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3170 - loss/inference_loss: 0.3170 - val_loss: 0.3984 - val_loss/inference_loss: 0.3984\n", "Epoch 18/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3218 - loss/inference_loss: 0.3218 - val_loss: 0.3458 - val_loss/inference_loss: 0.3458\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3236 - loss/inference_loss: 0.3236 - val_loss: 0.3907 - val_loss/inference_loss: 0.3907\n", "Epoch 19/30\n", - "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3223 - loss/inference_loss: 0.3223 - val_loss: 0.2928 - val_loss/inference_loss: 0.2928\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - loss: 0.3203 - loss/inference_loss: 0.3203 - val_loss: 0.4728 - val_loss/inference_loss: 0.4728\n", "Epoch 20/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3165 - loss/inference_loss: 0.3165 - val_loss: 0.3053 - val_loss/inference_loss: 0.3053\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3200 - loss/inference_loss: 0.3200 - val_loss: 0.2013 - val_loss/inference_loss: 0.2013\n", "Epoch 21/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3220 - loss/inference_loss: 0.3220 - val_loss: 0.3217 - val_loss/inference_loss: 0.3217\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3175 - loss/inference_loss: 0.3175 - val_loss: 0.3143 - val_loss/inference_loss: 0.3143\n", "Epoch 22/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3233 - loss/inference_loss: 0.3233 - val_loss: 0.3606 - val_loss/inference_loss: 0.3606\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3202 - loss/inference_loss: 0.3202 - val_loss: 0.2706 - val_loss/inference_loss: 0.2706\n", "Epoch 23/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3169 - loss/inference_loss: 0.3169 - val_loss: 0.2697 - 
val_loss/inference_loss: 0.2697\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3138 - loss/inference_loss: 0.3138 - val_loss: 0.3042 - val_loss/inference_loss: 0.3042\n", "Epoch 24/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3141 - loss/inference_loss: 0.3141 - val_loss: 0.2131 - val_loss/inference_loss: 0.2131\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3097 - loss/inference_loss: 0.3097 - val_loss: 0.2372 - val_loss/inference_loss: 0.2372\n", "Epoch 25/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3126 - loss/inference_loss: 0.3126 - val_loss: 0.3206 - val_loss/inference_loss: 0.3206\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3079 - loss/inference_loss: 0.3079 - val_loss: 0.3317 - val_loss/inference_loss: 0.3317\n", "Epoch 26/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3125 - loss/inference_loss: 0.3125 - val_loss: 0.3309 - val_loss/inference_loss: 0.3309\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3149 - loss/inference_loss: 0.3149 - val_loss: 0.2426 - val_loss/inference_loss: 0.2426\n", "Epoch 27/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3093 - loss/inference_loss: 0.3093 - val_loss: 0.2652 - val_loss/inference_loss: 0.2652\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - loss: 0.3126 - 
loss/inference_loss: 0.3126 - val_loss: 0.3095 - val_loss/inference_loss: 0.3095\n", "Epoch 28/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3095 - loss/inference_loss: 0.3095 - val_loss: 0.3211 - val_loss/inference_loss: 0.3211\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - loss: 0.3100 - loss/inference_loss: 0.3100 - val_loss: 0.3371 - val_loss/inference_loss: 0.3371\n", "Epoch 29/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3117 - loss/inference_loss: 0.3117 - val_loss: 0.2739 - val_loss/inference_loss: 0.2739\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3066 - loss/inference_loss: 0.3066 - val_loss: 0.2807 - val_loss/inference_loss: 0.2807\n", "Epoch 30/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 6ms/step - loss: 0.3107 - loss/inference_loss: 0.3107 - val_loss: 0.2902 - val_loss/inference_loss: 0.2902\n", - "CPU times: user 2min 37s, sys: 5.22 s, total: 2min 42s\n", - "Wall time: 1min 26s\n" + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 3ms/step - loss: 0.3071 - loss/inference_loss: 0.3071 - val_loss: 0.3870 - val_loss/inference_loss: 0.3870\n", + "CPU times: total: 8.27 s\n", + "Wall time: 55.5 s\n" ] } ], @@ -590,7 +582,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 29, "id": "d53a41b8", "metadata": {}, "outputs": [], @@ -600,7 +592,7 @@ "\n", "inference_network = bf.networks.ConsistencyModel(\n", " subnet=\"mlp\",\n", - " subnet_kwargs={\"depth\": 6, \"width\": 256},\n", + " subnet_kwargs={\"widths\": (256,)*6},\n", " total_steps=total_steps,\n", " 
max_time=10,\n", " sigma2=sigma2,\n", @@ -623,7 +615,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 30, "id": "d1bc228a", "metadata": {}, "outputs": [], @@ -640,7 +632,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 31, "id": "41c4599f", "metadata": {}, "outputs": [], @@ -658,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 32, "id": "c3c1a812", "metadata": {}, "outputs": [ @@ -675,67 +667,67 @@ "output_type": "stream", "text": [ "Epoch 1/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 9ms/step - loss: 0.3980 - loss/inference_loss: 0.3980 - val_loss: 0.4073 - val_loss/inference_loss: 0.4073\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 5ms/step - loss: 0.3919 - loss/inference_loss: 0.3919 - val_loss: 0.2735 - val_loss/inference_loss: 0.2735\n", "Epoch 2/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 8ms/step - loss: 0.3501 - loss/inference_loss: 0.3501 - val_loss: 0.2820 - val_loss/inference_loss: 0.2820\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3469 - loss/inference_loss: 0.3469 - val_loss: 0.3136 - val_loss/inference_loss: 0.3136\n", "Epoch 3/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3351 - loss/inference_loss: 0.3351 - val_loss: 0.2432 - val_loss/inference_loss: 0.2432\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3342 - loss/inference_loss: 0.3342 - val_loss: 0.3803 - val_loss/inference_loss: 0.3803\n", "Epoch 4/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m4s\u001b[0m 9ms/step - loss: 0.3318 - loss/inference_loss: 0.3318 - val_loss: 0.3066 - val_loss/inference_loss: 0.3066\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3250 - loss/inference_loss: 0.3250 - val_loss: 0.2638 - val_loss/inference_loss: 0.2638\n", "Epoch 5/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3227 - loss/inference_loss: 0.3227 - val_loss: 0.3283 - val_loss/inference_loss: 0.3283\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3201 - loss/inference_loss: 0.3201 - val_loss: 0.3935 - val_loss/inference_loss: 0.3935\n", "Epoch 6/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3180 - loss/inference_loss: 0.3180 - val_loss: 0.3263 - val_loss/inference_loss: 0.3263\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3218 - loss/inference_loss: 0.3218 - val_loss: 0.4023 - val_loss/inference_loss: 0.4023\n", "Epoch 7/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3177 - loss/inference_loss: 0.3177 - val_loss: 0.2487 - val_loss/inference_loss: 0.2487\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3107 - loss/inference_loss: 0.3107 - val_loss: 0.1864 - val_loss/inference_loss: 0.1864\n", "Epoch 8/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3114 - loss/inference_loss: 0.3114 - val_loss: 0.2381 - val_loss/inference_loss: 0.2381\n", + "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3154 - loss/inference_loss: 0.3154 - val_loss: 0.2744 - val_loss/inference_loss: 0.2744\n", "Epoch 9/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3089 - loss/inference_loss: 0.3089 - val_loss: 0.2677 - val_loss/inference_loss: 0.2677\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3116 - loss/inference_loss: 0.3116 - val_loss: 0.2823 - val_loss/inference_loss: 0.2823\n", "Epoch 10/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3089 - loss/inference_loss: 0.3089 - val_loss: 0.3307 - val_loss/inference_loss: 0.3307\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3075 - loss/inference_loss: 0.3075 - val_loss: 0.2069 - val_loss/inference_loss: 0.2069\n", "Epoch 11/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.3071 - loss/inference_loss: 0.3071 - val_loss: 0.2687 - val_loss/inference_loss: 0.2687\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3039 - loss/inference_loss: 0.3039 - val_loss: 0.2883 - val_loss/inference_loss: 0.2883\n", "Epoch 12/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.3040 - loss/inference_loss: 0.3040 - val_loss: 0.2827 - val_loss/inference_loss: 0.2827\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.3019 - loss/inference_loss: 0.3019 - val_loss: 0.1924 - 
val_loss/inference_loss: 0.1924\n", "Epoch 13/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.3012 - loss/inference_loss: 0.3012 - val_loss: 0.2679 - val_loss/inference_loss: 0.2679\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2984 - loss/inference_loss: 0.2984 - val_loss: 0.3208 - val_loss/inference_loss: 0.3208\n", "Epoch 14/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.3009 - loss/inference_loss: 0.3009 - val_loss: 0.2093 - val_loss/inference_loss: 0.2093\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2991 - loss/inference_loss: 0.2991 - val_loss: 0.2844 - val_loss/inference_loss: 0.2844\n", "Epoch 15/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.2975 - loss/inference_loss: 0.2975 - val_loss: 0.2663 - val_loss/inference_loss: 0.2663\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2912 - loss/inference_loss: 0.2912 - val_loss: 0.2385 - val_loss/inference_loss: 0.2385\n", "Epoch 16/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.2949 - loss/inference_loss: 0.2949 - val_loss: 0.2288 - val_loss/inference_loss: 0.2288\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2955 - loss/inference_loss: 0.2955 - val_loss: 0.1594 - val_loss/inference_loss: 0.1594\n", "Epoch 17/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step 
- loss: 0.2897 - loss/inference_loss: 0.2897 - val_loss: 0.2794 - val_loss/inference_loss: 0.2794\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2888 - loss/inference_loss: 0.2888 - val_loss: 0.2615 - val_loss/inference_loss: 0.2615\n", "Epoch 18/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.2862 - loss/inference_loss: 0.2862 - val_loss: 0.2403 - val_loss/inference_loss: 0.2403\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2893 - loss/inference_loss: 0.2893 - val_loss: 0.2004 - val_loss/inference_loss: 0.2004\n", "Epoch 19/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 9ms/step - loss: 0.2896 - loss/inference_loss: 0.2896 - val_loss: 0.3645 - val_loss/inference_loss: 0.3645\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2864 - loss/inference_loss: 0.2864 - val_loss: 0.2238 - val_loss/inference_loss: 0.2238\n", "Epoch 20/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2845 - loss/inference_loss: 0.2845 - val_loss: 0.2376 - val_loss/inference_loss: 0.2376\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2876 - loss/inference_loss: 0.2876 - val_loss: 0.2310 - val_loss/inference_loss: 0.2310\n", "Epoch 21/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2845 - loss/inference_loss: 0.2845 - val_loss: 0.1906 - val_loss/inference_loss: 0.1906\n", + "\u001b[1m512/512\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2834 - loss/inference_loss: 0.2834 - val_loss: 0.2057 - val_loss/inference_loss: 0.2057\n", "Epoch 22/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2882 - loss/inference_loss: 0.2882 - val_loss: 0.2277 - val_loss/inference_loss: 0.2277\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2812 - loss/inference_loss: 0.2812 - val_loss: 0.2432 - val_loss/inference_loss: 0.2432\n", "Epoch 23/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2857 - loss/inference_loss: 0.2857 - val_loss: 0.1923 - val_loss/inference_loss: 0.1923\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2823 - loss/inference_loss: 0.2823 - val_loss: 0.2854 - val_loss/inference_loss: 0.2854\n", "Epoch 24/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2836 - loss/inference_loss: 0.2836 - val_loss: 0.3887 - val_loss/inference_loss: 0.3887\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2795 - loss/inference_loss: 0.2795 - val_loss: 0.4341 - val_loss/inference_loss: 0.4341\n", "Epoch 25/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2822 - loss/inference_loss: 0.2822 - val_loss: 0.2481 - val_loss/inference_loss: 0.2481\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2859 - loss/inference_loss: 0.2859 - val_loss: 0.2324 - 
val_loss/inference_loss: 0.2324\n", "Epoch 26/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2812 - loss/inference_loss: 0.2812 - val_loss: 0.2629 - val_loss/inference_loss: 0.2629\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2804 - loss/inference_loss: 0.2804 - val_loss: 0.3024 - val_loss/inference_loss: 0.3024\n", "Epoch 27/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2806 - loss/inference_loss: 0.2806 - val_loss: 0.3617 - val_loss/inference_loss: 0.3617\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2836 - loss/inference_loss: 0.2836 - val_loss: 0.1902 - val_loss/inference_loss: 0.1902\n", "Epoch 28/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2853 - loss/inference_loss: 0.2853 - val_loss: 0.2874 - val_loss/inference_loss: 0.2874\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2814 - loss/inference_loss: 0.2814 - val_loss: 0.1490 - val_loss/inference_loss: 0.1490\n", "Epoch 29/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step - loss: 0.2754 - loss/inference_loss: 0.2754 - val_loss: 0.2088 - val_loss/inference_loss: 0.2088\n", + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2762 - loss/inference_loss: 0.2762 - val_loss: 0.2249 - val_loss/inference_loss: 0.2249\n", "Epoch 30/30\n", - "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 9ms/step 
- loss: 0.2849 - loss/inference_loss: 0.2849 - val_loss: 0.2632 - val_loss/inference_loss: 0.2632\n", - "CPU times: user 5min 50s, sys: 1min, total: 6min 50s\n", - "Wall time: 2min 15s\n" + "\u001b[1m512/512\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - loss: 0.2805 - loss/inference_loss: 0.2805 - val_loss: 0.1796 - val_loss/inference_loss: 0.1796\n", + "CPU times: total: 6.89 s\n", + "Wall time: 1min 7s\n" ] } ], @@ -770,13 +762,13 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 33, "id": "073bcd0b", "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -857,7 +849,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.5" }, "toc": { "base_numbering": 1,