DAG construction with local variables #253

Closed
cscherrer opened this issue Apr 9, 2021 · 2 comments · Fixed by #260

@cscherrer
Owner

I had brought this up in #245, but I think it's really a different issue, so I'm splitting it.

The problem is in models like

m = @model begin
    a ~ For(3) do x Normal(μ=x) end
    x ~ Normal(μ=sum(a))
end

Soss thinks the graph is

julia> digraph(m).N
Dict{Symbol, Set{Symbol}} with 2 entries:
  :a => Set([:x])
  :x => Set([:a])

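But this is wrong! The x referred to in a ~ ... is the local argument of the do block, not the model variable x. The only real dependency is that x depends on a; there is no cycle.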

I think JuliaVariables.jl ought to be able to help with this. The idea is:

  1. Get the rhs of each statement
  2. Solve each rhs, which adds @local and @global annotations
  3. Extract the set of @global names, which should give the true dependencies
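The three steps can be prototyped without JuliaVariables to check what the right answer should be. The sketch below is hypothetical plain Base Julia (free_symbols and collect_free! are illustrative names, not Soss or JuliaVariables API). It treats only lambda and do-block arguments as local bindings, and it assumes those arguments are plain symbols. Function names such as For, Normal and sum also come out as free ("global") names, so the result is intersected with the model's own variable names.

```julia
# Hypothetical sketch, not Soss or JuliaVariables API: find the symbols an
# expression reads from its enclosing scope. Only lambda / do-block
# arguments count as local bindings; assignments inside a body are ignored,
# and arguments are assumed to be plain Symbols.
function free_symbols(ex)
    out = Set{Symbol}()
    collect_free!(out, ex, Set{Symbol}())
    return out
end

# A bare Symbol is free unless an enclosing lambda binds it.
collect_free!(out, ex::Symbol, bound) = (ex in bound || push!(out, ex); out)

# Literals, LineNumberNodes, QuoteNodes, etc. reference no variables.
collect_free!(out, ex, bound) = out

function collect_free!(out, ex::Expr, bound)
    if ex.head === :->
        # `y -> ...` stores a Symbol; `do x ... end` stores Expr(:tuple, :x)
        params = ex.args[1] isa Symbol ? [ex.args[1]] : ex.args[1].args
        collect_free!(out, ex.args[2], union(bound, Set{Symbol}(params)))
    elseif ex.head === :kw
        # In `Normal(μ = x)`, μ is a keyword name, not a variable reference
        collect_free!(out, ex.args[2], bound)
    else
        # :call, :do, :block, ...: recurse into every argument
        for arg in ex.args
            collect_free!(out, arg, bound)
        end
    end
    return out
end

# Right-hand sides of the model from this issue, parsed from strings
# (standing in for m.dists.a and m.dists.x).
rhs = Dict(
    :a => Meta.parse("For(3) do x\n    Normal(μ = x)\nend"),
    :x => Meta.parse("Normal(μ = sum(a))"),
)

# Keep only free names that are themselves model variables.
modelvars = Set(keys(rhs))
deps = Dict(v => intersect(free_symbols(ex), modelvars) for (v, ex) in rhs)

deps[:a]  # empty: the x inside the do block is local
deps[:x]  # contains only :a
```

The intersection step matters because step 3 on its own would also report For, Normal and sum as dependencies.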

Here's my first attempt:

julia> using JuliaVariables

julia> using MacroTools: prettify

julia> m.dists.a |> solve_from_local |> prettify |> unwrap_scoped
:((@global For)(3) do @global x
      @global Normal
      $(Expr(:kw, :μ, @global x))
  end)

julia> m.dists.x |> solve_from_local |>  prettify |> unwrap_scoped
:((@global Normal)(μ = (@global sum)(@global a)))

I'm not sure what's going on with that @global x, since x is clearly a local variable. My guess is that the do notation is throwing off the solver and needs to be rewritten first.

@thautwarm
Collaborator

Call simplify_ex once before calling solve_xxx. JuliaVariables does not handle Expr(:(->), ...), but it does handle Expr(:function, ...).
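For context, a do block parses to Expr(:do, call, Expr(:->, params, body)), so the solver sees an arrow lambda. The sketch below is a hypothetical Base-only rewrite in the same spirit (desugar_do is an illustrative name; this is not how simplify_ex is actually implemented). It moves the do block into the call as its first positional argument, written with Expr(:function). It assumes the call has no ;-separated keyword section.

```julia
# Hypothetical sketch (not JuliaVariables' simplify_ex): rewrite every
#     f(args...) do x ... end
# into
#     f(function (x,) ... end, args...)
# Assumes the call has no `;`-separated keyword section.
desugar_do(ex) = ex   # symbols, literals and line numbers pass through

function desugar_do(ex::Expr)
    args = map(desugar_do, ex.args)   # rewrite nested do blocks first
    if ex.head === :do
        call, lambda = args           # Expr(:call, f, ...), Expr(:->, params, body)
        fn = Expr(:function, lambda.args[1], lambda.args[2])
        return Expr(:call, call.args[1], fn, call.args[2:end]...)
    end
    return Expr(ex.head, args...)
end

ex = Meta.parse("For(3) do x\n    Normal(μ = x)\nend")
@show ex.head                       # :do
@show ex.args[2].head               # the do block is an arrow lambda
@show desugar_do(ex).args[2].head   # :function
```

Only do blocks are rewritten here; a free-standing y -> ... lambda would still need its own conversion.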

@cscherrer
Owner Author

Got it, thank you @thautwarm!

julia> unwrap_scoped(ex) =
                  @match ex begin
                      Expr(:scoped, _, a) => unwrap_scoped(a)
                      Expr(head, args...) => Expr(head, map(unwrap_scoped, args)...)
                      a => a
                  end
unwrap_scoped (generic function with 1 method)

julia> m = @model begin
           a ~ For(3) do x Normal(μ=x) end
           x ~ Normal(μ=sum(a))
       end;

julia> m.dists.a |> MacroTools.prettify |> simplify_ex |> solve_from_local |> unwrap_scoped
:((@global For)(function (x,)
          (@global Normal)(μ = @local x)
      end, 3))
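The unwrap_scoped above uses @match with MLStyle-style patterns and assumes that solved expressions wrap names in two-argument Expr(:scoped, _, a) nodes. For readers without MLStyle, a hypothetical Base-only equivalent (unwrap_scoped_base is an illustrative name) can be written with plain head and arity checks. The :info placeholder below stands in for whatever scope data the solver actually stores.

```julia
# Base-only equivalent of the MLStyle-based unwrap_scoped: replace each
# two-argument Expr(:scoped, _, a) by `a`, recursing through every other Expr.
function unwrap_scoped_base(ex::Expr)
    if ex.head === :scoped && length(ex.args) == 2
        return unwrap_scoped_base(ex.args[2])
    end
    return Expr(ex.head, map(unwrap_scoped_base, ex.args)...)
end
unwrap_scoped_base(a) = a   # anything that is not an Expr is kept as-is

# :info is a placeholder for the solver's scope data (an assumption).
wrapped = Expr(:call, Expr(:scoped, :info, :Normal),
               Expr(:kw, :μ, Expr(:scoped, :info, :x)))
unwrap_scoped_base(wrapped)   # :(Normal(μ = x))
```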
