-
Notifications
You must be signed in to change notification settings - Fork 2
/
goal_observation.jl
278 lines (215 loc) · 10.7 KB
/
goal_observation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
using DomainSets
using ForwardDiff: jacobian
using TupleTools: deleteat
using ReactiveMP: FunctionalDependencies, messagein, setmessage!, getlocalclusters, clusterindex, getmarginals
using Base.Broadcast: BroadcastFunction
import ReactiveMP: functional_dependencies
include("distributions.jl")
# Field-less tag type identifying the GoalObservation factor node; the rules and
# average-energy definitions in this file dispatch on it.
struct GoalObservation end
# Register the node as stochastic with the interfaces c, z and A (in this order).
@node GoalObservation Stochastic [c, z, A]
#----------
# Modifiers
#----------
# Metas
# Meta for the Bethe update rules. The type parameter `P` is the type of `x`, so
# that rules can be overloaded on it: `BetheMeta{Missing}` selects the unobserved
# rules and `BetheMeta{<:AbstractVector}` the observed ones.
struct BetheMeta{P}
    x::P # Pointmass value for observation
end

# Construct a meta without observation (`x` is `missing`).
function BetheMeta()
    return BetheMeta(missing)
end
# Meta for the generalized update rules. The type parameter `P` is the type of
# `x`, so that rules can be overloaded on it: `GeneralizedMeta{Missing}` selects
# the unobserved rules and `GeneralizedMeta{<:AbstractVector}` the observed ones.
struct GeneralizedMeta{P}
    x::P # Pointmass value for observation
    newton_iterations::Int64 # Number of Newton steps taken by the unobserved state rule
end

# Accept any integer iteration count. The default constructor only matches an
# `Int64` second argument, so without this method an `Int32` (including integer
# literals on 32-bit builds) raises a `MethodError`.
function GeneralizedMeta(point, newton_iterations::Integer)
    return GeneralizedMeta(point, Int64(newton_iterations))
end

# Absent observation, default number of Newton iterations
GeneralizedMeta() = GeneralizedMeta(missing, 20)

# Given observation, default number of Newton iterations
GeneralizedMeta(point) = GeneralizedMeta(point, 20)
# Pipelines
# Pipeline for the Bethe rules; its `functional_dependencies` method in this file
# requests no inbound messages and all node-local marginals.
struct BethePipeline <: FunctionalDependencies end
# Pipeline for the generalized rules. `init_message` is the message that
# `functional_dependencies` sets on the state interface before that interface is
# used as a message dependency.
struct GeneralizedPipeline <: FunctionalDependencies
    init_message::Union{Bernoulli, Categorical}

    # If state is clamped, then no initial message is required; the field is
    # left uninitialized in that case.
    function GeneralizedPipeline()
        return new()
    end

    function GeneralizedPipeline(init_message::Union{Bernoulli, Categorical})
        return new(init_message)
    end
end
# Dependencies for the Bethe pipeline: no inbound messages, and the marginals of
# all node-local clusters, irrespective of the interface being updated.
function functional_dependencies(::BethePipeline, factornode, interface, iindex)
    local_marginals = getmarginals(getlocalclusters(factornode)) # Include all node-local marginals
    return (), local_marginals
end
# Dependencies for the generalized pipeline.
#
# Messages: only the update towards the state (interface index 2) depends on a
# message, namely the inbound message on that same interface, which is first
# set to `pipeline.init_message`.
# Marginals: updates towards the state (index 2) or the parameter (index 3) use
# all node-local marginals; any other update skips the marginal of its own cluster.
#
# Throws an `ArgumentError` when the state message is requested but the pipeline
# was constructed without an initial message.
function functional_dependencies(pipeline::GeneralizedPipeline, factornode, interface, iindex)
    clusters = getlocalclusters(factornode)
    cindex = clusterindex(clusters, iindex) # Find the index of the cluster for the current interface

    towards_state = (iindex === 2)
    towards_parameter = (iindex === 3)

    # Message dependencies
    if towards_state
        # `GeneralizedPipeline()` leaves the field uninitialized; fail with an
        # explicit message instead of a bare `UndefRefError` on field access.
        isdefined(pipeline, :init_message) || throw(ArgumentError(
            "GeneralizedPipeline has no initial message, which is required for the " *
            "message towards the state; construct it as GeneralizedPipeline(init_message)"))
        output = messagein(interface)
        setmessage!(output, pipeline.init_message)
        message_dependencies = (interface, )
    else
        message_dependencies = ()
    end

    # Marginal dependencies
    if towards_state || towards_parameter
        marginal_dependencies = getmarginals(clusters) # Include all marginals
    else
        marginal_dependencies = skipindex(getmarginals(clusters), cindex) # Skip current cluster
    end

    return message_dependencies, marginal_dependencies
end
#------------------------------
# Unobserved Bethe Update Rules
#------------------------------
# Bethe rule towards `c` without observation: Dirichlet message whose parameters
# are the internal marginal plus one.
@rule GoalObservation(:c, Marginalisation) (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{Missing}
) = begin
    # Means of the elementwise logarithm under q_c and q_A
    elog_c = mean(BroadcastFunction(log), q_c)
    elog_A = mean(BroadcastFunction(log), q_A)
    state = probvec(q_z)

    # Compute internal marginal
    logits = elog_A*state + elog_c
    outcome = softmax(logits)

    return Dirichlet(outcome .+ 1)
end
# Bethe rule towards the state `z` without observation: Categorical message
# obtained by passing the internal marginal back through the transposed log-matrix.
@rule GoalObservation(:z, Marginalisation) (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{Missing}
) = begin
    # Means of the elementwise logarithm under q_c and q_A
    elog_c = mean(BroadcastFunction(log), q_c)
    elog_A = mean(BroadcastFunction(log), q_A)
    state = probvec(q_z)

    # Compute internal marginal
    logits = elog_A*state + elog_c
    outcome = softmax(logits)

    back_logits = elog_A'*outcome
    return Categorical(softmax(back_logits))
end
# Bethe rule towards the parameter `A` without observation: MatrixDirichlet
# message built from the outer product of the internal marginal and the state.
@rule GoalObservation(:A, Marginalisation) (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{Missing}
) = begin
    # Means of the elementwise logarithm under q_c and q_A
    elog_c = mean(BroadcastFunction(log), q_c)
    elog_A = mean(BroadcastFunction(log), q_A)
    state = probvec(q_z)

    # Compute internal marginal
    logits = elog_A*state + elog_c
    outcome = softmax(logits)

    counts = outcome*state' .+ 1
    return MatrixDirichlet(counts)
end
# Average energy for the Bethe node without observation, evaluated at the
# internal marginal computed from the current marginals.
@average_energy GoalObservation (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{Missing}
) = begin
    # Means of the elementwise logarithm under q_c and q_A
    elog_c = mean(BroadcastFunction(log), q_c)
    elog_A = mean(BroadcastFunction(log), q_A)
    state = probvec(q_z)

    # Compute internal marginal
    logits = elog_A*state + elog_c
    outcome = softmax(logits)

    # The same logits appear in the energy term, so they are reused here
    return -outcome'*(logits - safelog.(outcome))
end
#----------------------------
# Observed Bethe Update Rules
#----------------------------
# Bethe rule towards `c` with observation: Dirichlet message whose parameters
# are the observed vector `meta.x` plus one. None of the marginals are used.
@rule GoalObservation(:c, Marginalisation) (
    q_c::Union{Dirichlet, PointMass}, # Unused
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{<:AbstractVector}
) = begin
    counts = meta.x .+ 1
    return Dirichlet(counts)
end
# Bethe rule towards the state `z` with observation: Categorical message from
# the observed vector `meta.x` passed through the transposed log-matrix.
@rule GoalObservation(:z, Marginalisation) (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical}, # Unused
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{<:AbstractVector}
) = begin
    elog_A = mean(BroadcastFunction(log), q_A) # Mean of the elementwise logarithm under q_A
    back_logits = elog_A'*meta.x
    return Categorical(softmax(back_logits))
end
# Bethe rule towards the parameter `A` with observation: MatrixDirichlet message
# built from the outer product of the observed vector `meta.x` and the state.
@rule GoalObservation(:A, Marginalisation) (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass}, # Unused
    meta::BetheMeta{<:AbstractVector}
) = begin
    state = probvec(q_z)
    counts = meta.x*state' .+ 1
    return MatrixDirichlet(counts)
end
# Average energy for the Bethe node with observation `meta.x`.
@average_energy GoalObservation (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::BetheMeta{<:AbstractVector}
) = begin
    # Means of the elementwise logarithm under q_c and q_A
    elog_c = mean(BroadcastFunction(log), q_c)
    elog_A = mean(BroadcastFunction(log), q_A)
    state = probvec(q_z)

    logits = elog_A*state + elog_c
    return -meta.x'*logits
end
#------------------------------------
# Unobserved Generalized Update Rules
#------------------------------------
# Generalized rule towards `c` without observation: Dirichlet message whose
# parameters are the mean of `q_A` applied to the state, plus one.
@rule GoalObservation(:c, Marginalisation) (
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::GeneralizedMeta{Missing}
) = begin
    state = probvec(q_z)
    A_mean = mean(q_A)
    predicted = A_mean*state
    return Dirichlet(predicted .+ 1)
end
# Generalized rule towards the state `z` without observation.
#
# The marginal statistics are the root of `g`, found with
# `meta.newton_iterations` Newton steps started from the current marginal `q_z`.
# The outbound message is then the ratio (computed in the log-domain) of that
# marginal and the inbound message `m_z`.
@rule GoalObservation(:z, Marginalisation) (m_z::Union{Bernoulli, Categorical},
                                            q_c::Union{Dirichlet, PointMass},
                                            q_z::Union{Bernoulli, Categorical},
                                            q_A::Union{SampleList, MatrixDirichlet, PointMass},
                                            meta::GeneralizedMeta{Missing}) = begin
    d = probvec(m_z)
    log_c = mean(BroadcastFunction(log), q_c)
    z_0 = probvec(q_z)
    (A, h_A) = mean_h(q_A)

    # Part of the softmax argument that does not depend on z; computed once
    # instead of on every evaluation of g (evaluation order is left unchanged)
    bias = -h_A + A'*log_c
    log_d = safelog.(d)

    # Root-finding problem for marginal statistics
    g(z) = z - softmax(bias - A'*safelog.(A*z) + log_d)

    # z_k is only ever rebound, never mutated in place, so no copy of z_0 is needed
    z_k = z_0
    for _ in 1:meta.newton_iterations
        # Newton step for multivariate root finding; solve the linear system
        # rather than forming the explicit inverse of the Jacobian
        z_k = z_k - (jacobian(g, z_k) \ g(z_k))
    end

    # Compute outbound message statistics
    # NOTE(review): the inbound message is offset by 1e-6 here while `safelog`
    # is used everywhere else in this rule — confirm this is intended.
    rho = softmax(safelog.(z_k) - log.(d .+ 1e-6))

    return Categorical(rho)
end
# Generalized rule towards the parameter `A` without observation: the message is
# returned as an unnormalized log-pdf over matrices with the shape of the mean of `q_A`.
@rule GoalObservation(:A, Marginalisation) (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::GeneralizedMeta{Missing}
) = begin
    elog_c = mean(BroadcastFunction(log), q_c) # Mean of the elementwise logarithm under q_c
    state = probvec(q_z)
    A_mean = mean(q_A)
    (n_rows, n_cols) = size(A_mean)

    # Log-message as a function of the matrix argument; A_mean is held fixed
    # inside the logarithm
    function log_mu(A)
        return (A*state)'*(elog_c - safelog.(A_mean*state)) - state'*h(A)
    end

    domain = (RealNumbers()^n_rows, RealNumbers()^n_cols)
    return ContinuousMatrixvariateLogPdf(domain, log_mu)
end
# Average energy for the generalized node without observation.
@average_energy GoalObservation (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::GeneralizedMeta{Missing}
) = begin
    elog_c = mean(BroadcastFunction(log), q_c) # Mean of the elementwise logarithm under q_c
    state = probvec(q_z)
    (A_mean, h_A) = mean_h(q_A)

    # The product A_mean*state occurs twice in the expression; compute it once
    predicted = A_mean*state
    return state'*h_A - predicted'*(elog_c - safelog.(predicted))
end
#----------------------------------
# Observed Generalized Update Rules
#----------------------------------
# Generalized rule towards `c` with observation: Dirichlet message whose
# parameters are the observed vector `meta.x` plus one. No marginal is used.
@rule GoalObservation(:c, Marginalisation) (
    q_z::Union{Bernoulli, Categorical, PointMass}, # Unused
    q_A::Union{SampleList, MatrixDirichlet, PointMass}, # Unused
    meta::GeneralizedMeta{<:AbstractVector}
) = begin
    counts = meta.x .+ 1
    return Dirichlet(counts)
end
# Generalized rule towards the state `z` with observation: Categorical message
# from the observed vector `meta.x` passed through the transposed log-matrix,
# whose entries are clamped to [-12, 12].
@rule GoalObservation(:z, Marginalisation) (
    m_z::Union{Bernoulli, Categorical}, # Unused
    q_c::Union{Dirichlet, PointMass}, # Unused
    q_z::Union{Bernoulli, Categorical}, # Unused
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::GeneralizedMeta{<:AbstractVector}
) = begin
    elog_A = mean(BroadcastFunction(log), q_A) # Mean of the elementwise logarithm under q_A
    bounded_log_A = clamp.(elog_A, -12, 12)
    back_logits = bounded_log_A'*meta.x
    return Categorical(softmax(back_logits))
end
# Generalized rule towards the parameter `A` with observation: MatrixDirichlet
# message built from the outer product of the observed vector `meta.x` and the state.
@rule GoalObservation(:A, Marginalisation) (
    q_c::Union{Dirichlet, PointMass}, # Unused
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass}, # Unused
    meta::GeneralizedMeta{<:AbstractVector}
) = begin
    state = probvec(q_z)
    counts = meta.x*state' .+ 1
    return MatrixDirichlet(counts)
end
# Average energy for the generalized node with observation `meta.x`; the entries
# of the log-matrix are clamped to [-12, 12].
@average_energy GoalObservation (
    q_c::Union{Dirichlet, PointMass},
    q_z::Union{Bernoulli, Categorical, PointMass},
    q_A::Union{SampleList, MatrixDirichlet, PointMass},
    meta::GeneralizedMeta{<:AbstractVector}
) = begin
    elog_c = mean(BroadcastFunction(log), q_c) # Mean of the elementwise logarithm under q_c
    state = probvec(q_z)
    elog_A = mean(BroadcastFunction(log), q_A) # Mean of the elementwise logarithm under q_A
    bounded_log_A = clamp.(elog_A, -12, 12)

    logits = bounded_log_A*state + elog_c
    return -meta.x'*logits
end