Skip to content

Commit

Permalink
Merge pull request #325 from ReactiveBayes/dev-ef-projection
Browse files Browse the repository at this point in the history
Initial integration with ExponentialFamilyProjection
  • Loading branch information
bvdmitri authored Jul 19, 2024
2 parents 34afa89 + b3c1ef6 commit 839400a
Show file tree
Hide file tree
Showing 14 changed files with 1,022 additions and 43 deletions.
16 changes: 12 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RxInfer"
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
authors = ["Bagaev Dmitry <[email protected]> and contributors"]
version = "3.4.0"
version = "3.5.0"

[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
Expand All @@ -21,20 +21,27 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"

[extensions]
ProjectionExt = "ExponentialFamilyProjection"

[compat]
BayesBase = "1.1"
DataStructures = "0.18"
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
ExponentialFamily = "1.2"
ExponentialFamily = "1.5"
ExponentialFamilyProjection = "1.1"
FastCholesky = "1.3.0"
GraphPPL = "~4.3.0"
LinearAlgebra = "1.9"
MacroTools = "0.5.6"
Optim = "1.0.0"
ProgressMeter = "1.0.0"
Random = "1.9"
ReactiveMP = "~4.2.0"
ReactiveMP = "~4.3.0"
Reexport = "1.2.0"
Rocket = "1.8.0"
TupleTools = "1.2.0"
Expand All @@ -49,6 +56,7 @@ CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand All @@ -62,4 +70,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"

[targets]
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "CpuId", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"]
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "CpuId", "Dates", "Distributed", "Documenter", "ExponentialFamilyProjection", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ Our high-level project roadmap outlines the key milestones and focus areas for t
| Q1/Q2 2024 | Q3/Q4 2024 | 2025 |
|---------------------|---------------------------|--------------------|
| 🧩 **Nested models with [GraphPPL.jl](https://github.com/reactivebayes/GraphPPL.jl)**| 🌐 **Graph structure visualization** | 🔀 **Stochastic Processes** |
| 🔄 **Development of [ExponentialFamilyProjection.jl]()** | 🧠 **Automated inference with [ExponentialFamilyProjection.jl](https://github.com/reactivebayes/ExponentialFamilyProjection.jl)** | 🚀 **Robustness & Memory-efficiency** |
| 🔄 **Development of [ExponentialFamilyProjection.jl]()** | 🧠 **Automated inference with [ExponentialFamilyProjection.jl](https://github.com/reactivebayes/ExponentialFamilyProjection.jl)** | 🚀 **Robustness & Memory-efficiency** |

For a more granular view of our progress and ongoing tasks, check out our [project board](https://github.com/orgs/reactivebayes/projects/2/views/4) or join our 4-weekly [public meetings](https://dynalist.io/d/F4aA-Z2c8X-M1iWTn9hY_ndN).

Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
GraphPPL = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand All @@ -12,6 +13,7 @@ ReactiveMP = "a194aa59-28ba-4574-a09c-4a745416d6e3"
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
RxInfer = "86711068-29c9-4ff7-b620-ae75d7495b3d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

[compat]
Documenter = "1.0.0"
9 changes: 7 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ foreach(vcat(ExamplesOverviewPath, ExamplesCategoriesOverviewPaths)) do path
@warn "`$(path)` does not exist. Generating an empty overview. Use the `make examples` command to generate the overview and all examples."
mkpath(dirname(path))
open(path, "w") do f
write(f, "The overview is missing. Use the `make examples` command to generate the overview and all examples.")
write(f, """
$(isequal(path, ExamplesOverviewPath) ? "# [Examples overview](@id examples-overview)" : "")
The overview is missing. Use the `make examples` command to generate the overview and all examples.
""")
end
end
end
Expand Down Expand Up @@ -108,7 +111,9 @@ makedocs(;
"Streamline inference" => "manuals/inference/streamlined.md",
"Initialization" => "manuals/inference/initialization.md",
"Auto-updates" => "manuals/inference/autoupdates.md",
"Deterministic nodes" => "manuals/inference/delta-node.md"
"Deterministic nodes" => "manuals/inference/delta-node.md",
"Non-conjugate inference" => "manuals/inference/nonconjugate.md",
"Undefined message update rules" => "manuals/inference/undefinedrules.md"
],
"Inference customization" => [
"Defining a custom node and rules" => "manuals/customization/custom-node.md",
Expand Down
2 changes: 2 additions & 0 deletions docs/src/manuals/customization/custom-node.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

Welcome to the `RxInfer` documentation on creating custom factor graph nodes. In `RxInfer`, factor nodes represent functional relationships between variables, also known as factors. Together, these factors define your probabilistic model. Quite often these factors represent distributions, denoting how a certain parameter affects another. However, other factors are also possible, such as ones specifying linear or non-linear relationships. `RxInfer` already supports a lot of factor nodes, however, depending on the problem that you are trying to solve, you may need to create a custom node that better fits the specific requirements of your model. This tutorial will guide you through the process of defining a custom node in `RxInfer`, step by step. By the end of this tutorial, you will be able to create your own custom node and integrate it into your model.

In addition, read another section on a different way of running inference with custom stochastic nodes without explicit rule specification [here](@ref inference-undefinedrules).

---

To create a custom node in `RxInfer`, 4 steps are required:
Expand Down
2 changes: 1 addition & 1 deletion docs/src/manuals/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,4 @@ result.posteriors[:θ]

## Where to go next?

There are a set of [examples](@ref examples-overview) available in `RxInfer` repository that demonstrate the more advanced features of the package for various problems. Alternatively, you can head to the [Model specification](@ref user-guide-model-specification) which provides more detailed information of how to use `RxInfer` to specify probabilistic models. [Inference execution](@ref user-guide-inference-execution) section provides a documentation about `RxInfer` API for running reactive Bayesian inference. Also read the [Comparison](@ref comparison) to compare `RxInfer` with other probabilistic programming libraries.
There are a set of [examples](@ref examples-overview) available in `RxInfer` repository that demonstrate the more advanced features of the package for various problems. Alternatively, you can head to the [Model specification](@ref user-guide-model-specification) which provides more detailed information of how to use `RxInfer` to specify probabilistic models. [Inference execution](@ref user-guide-inference-execution) section provides a documentation about `RxInfer` API for running reactive Bayesian inference. Also read the [Comparison](@ref comparison) to compare `RxInfer` with other probabilistic programming libraries. For advances use cases refer to the [Non-conjugate inference](@ref inference-nonconjugate) tutorial and inference [without defining the message update rules explicitly](@ref inference-undefinedrules).
61 changes: 39 additions & 22 deletions docs/src/manuals/inference/delta-node.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@ RxInfer.jl offers a comprehensive set of stochastic nodes, primarily emphasizing

The delta node supports several approximation methods for probabilistic inference. The desired approximation method depends on the nodes connected to the delta node. We differentiate the following deterministic transformation scenarios:

1. **Gaussian Nodes**: For delta nodes linked to strictly multivariate or univariate Gaussian distributions, the recommended methods are Linearization or Unscented transforms.
2. **Exponential Family Nodes**: For the delta node connected to nodes from the exponential family, the CVI (Conjugate Variational Inference) is the method of choice.
3. **Stacking Delta Nodes**: For scenarios where delta nodes are stacked, either Linearization or Unscented transforms are suitable.
1. **Gaussian Nodes**: For delta nodes linked to strictly multivariate or univariate Gaussian distributions, the recommended methods are `Linearization` or `Unscented` transforms.
2. **Exponential Family Nodes**: For the delta node connected to nodes from the exponential family, the `CVIProjection` (Conjugate Variational Inference) is the method of choice.
3. **Stacking Delta Nodes**: For scenarios where delta nodes are stacked, either `Linearization`, `Unscented` or `CVIProjection` are suitable.
4. **Support for Inverse Functions**: For scenarious, where an inverse function is available

The table below summarizes the features of the delta node in RxInfer.jl, categorized by the approximation method:

| Methods | Gaussian Nodes | Exponential Family Nodes | Stacking Delta Nodes
|---------------|----------------|--------------------------|----------------------
| Linearization | ✓ | ✗ | ✓
| Unscented | ✓ | ✗ | ✓
| CVI | ✓ | ✓ | ✗
| Methods | Gaussian Nodes | Exponential Family Nodes | Stacking Delta Nodes | Inverse functions
|------------------|----------------|--------------------------|----------------------|----------------------
| Linearization | ✓ | ✗ | ✓ | ✓
| Unscented | ✓ | ✗ | ✓ | ✓
| CVI (deprecated) | ✓ | ✓ | ✗ | ✗
| CVI Projection | ✓ | ✓ | ✓ | ✗


## Gaussian Case

Expand All @@ -29,12 +32,15 @@ For clarity, consider the following example:
using RxInfer
@model function delta_node_example(z)
x ~ Normal(mean=0.0, var=1.0)
x ~ Normal(mean = 0.0, var = 1.0)
y := tanh(x)
z ~ Normal(mean=y, var=1.0)
z ~ Normal(mean = y, var = 1.0)
end
```

!!! note
While not strictly required, it is advised to use `:=` to define a deterministic relationship within the `@model` macro.

To perform inference on this model, designate the approximation method for the delta node (here, the `tanh` function) using the `@meta` specification:

```@example delta_node_example
Expand Down Expand Up @@ -62,21 +68,25 @@ end
To execute the inference procedure:

```@example delta_node_example
infer(model = delta_node_example(), meta=delta_meta, data = (z = 1.0,))
result = infer(
model = delta_node_example(),
meta = delta_meta,
data = (z = 1.0,)
)
```

This methodology is consistent even when the delta node is associated with multiple nodes. For instance:
This methodology is consistent even when the delta node is associated with multiple inputs. For instance:

```@example delta_node_example
f(x, g) = x*tanh(g)
```

```@example delta_node_example
@model function delta_node_example(z)
x ~ Normal(mean=1.0, var=1.0)
g ~ Normal(mean=1.0, var=1.0)
x ~ Normal(mean = 1.0, var = 1.0)
g ~ Normal(mean = 1.0, var = 1.0)
y := f(x, g)
z ~ Normal(mean=y, var=0.1)
z ~ Normal(mean = y, var = 0.1)
end
```

Expand Down Expand Up @@ -112,11 +122,14 @@ end

When the delta node is associated with nodes from the exponential family (excluding Gaussians), the `Linearization` and `Unscented` methods are not applicable. In such cases, the CVI (Conjugate Variational Inference) is available. Here's a modified example:

!!! note
The `CVIProjection` method is available only if `ExponentialFamilyProjection` package is installed in the current environment.

```@example delta_node_example_cvi
using RxInfer
using RxInfer, ExponentialFamilyProjection
@model function delta_node_example1(z)
x ~ Gamma(shape=1.0, rate=1.0)
x ~ Gamma(shape = 1.0, rate = 1.0)
y := tanh(x)
z ~ Bernoulli(y)
end
Expand All @@ -125,12 +138,16 @@ end
The corresponding meta specification can be represented as:

```@example delta_node_example_cvi
using StableRNGs
using Optimisers
delta_meta = @meta begin
tanh() -> DeltaMeta(method = CVI(StableRNG(42), 100, 100, Optimisers.Descent(0.01)))
tanh() -> CVIProjection()
end
```

Consult the `ProdCVI` docstrings for a detailed explanation of these parameters.
Consult the `CVIProjection` docstrings for a detailed explanation of its hyper-parameters. Additionally, read the [Non-conjugate Inference](@ref inference-nonconjugate) section.

!!! note
The `CVIProjection` method is an improved version of the now-deprecated `CVI` method. This new implementation features different hyperparameters, better accuracy, and improved stability.

## Fuse deterministic nodes with stochastic nodes

Read how to circumvent the need to define the meta structure and, instead, fuse the deterministic relation with a neighboring stochastic factor node in [this section](@ref inference-undefinedrules-fusedelta).
Loading

2 comments on commit 839400a

@bvdmitri
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/111354

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.5.0 -m "<description of version>" 839400a2f4758f56afbacb82a288ba86da2c9558
git push origin v3.5.0

Please sign in to comment.