Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add shard execution workflow #1557

Open
wants to merge 11 commits into
base: pv-feat/experimental-sharding-backend
Choose a base branch
from
4 changes: 4 additions & 0 deletions nx/config/config.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ import Config
# true inside Nx.
config :nx, :verify_grad, true
config :nx, :verify_binary_size, true

# If set to true, shards and sharding stages will be
# inspected with their debug ids alongside their unique ref ids
config :nx, :debug_shards, true
1 change: 1 addition & 0 deletions nx/lib/nx/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Nx.Application do

def start(_type, _args) do
children = [
Nx.Defn.ShardingCompiler.ShardRegistry,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we want to have this here, actually.
In fact, I think we might actually want to go with gen_stage for the execution, since the whole "chain of processes producing data to one another" smells a lot like gen_stage.

Thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should talk about it but I doubt GenStage will be helpful here. One of the biggest pitfalls in GenStage is that people move data around too much, when they should not. It is cheaper to move computations than to move data.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Plus the whole demand approach is unnecessary here. Each result here is either pending or done (like a promise), no?

%{id: Nx.Serving.PG, start: {:pg, :start_link, [Nx.Serving.PG]}},
{Nx.HiddenServing, Nx.Serving.PG}
]
Expand Down
19 changes: 9 additions & 10 deletions nx/lib/nx/defn/sharding_compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ defmodule Nx.Defn.ShardingCompiler do

[args] = args

%T{
shape: shape,
type: type,
data: %ShardPropagation{
shards: output_shards,
parameter_ids_to_index: parameter_ids_to_index
}
} =
{%T{
type: type,
data: %ShardPropagation{
shards: output_shards
}
}, parameter_ids_to_index,
shape} =
propagate_shards(vars, fun, opts[:sharding_config] || [])

data_sections =
Expand Down Expand Up @@ -152,9 +151,9 @@ defmodule Nx.Defn.ShardingCompiler do
|> Enum.with_index(fn x, idx -> {idx, x} end)
|> Map.new()

{container, _cache, _state} = ShardPropagation.traverse(expr, tensor_shardings)
{container, _cache, state} = ShardPropagation.traverse(expr, tensor_shardings)

container
{container, state.parameter_ids_to_index, expr.shape}
end

@impl true
Expand Down
149 changes: 109 additions & 40 deletions nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
alias Nx.Tensor, as: T
alias Nx.Defn.Expr
alias Nx.Defn.ShardingCompiler.Shard
alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage

@gather_ops [:dot]
@reduction_ops [:sum]

def traverse(expr, expr_shards \\ %{}) do
@ops_to_split Map.merge(
Map.new(@gather_ops, &{&1, :gather}),
Map.new(@reduction_ops, &{&1, :reduce})
)

def traverse(expr, expr_shards \\ %{}, ops_to_split \\ @ops_to_split) do
# expression_chain is going to be a reverse-accumulation of {category, subexpr}
# that we can then compile and chain-execute elsewhere. category is either :gather, :reduce or :none
state = %{
expression_chain: [],
nodes_to_replace: %{},
ops_to_split: ops_to_split,
# contains the sharding configuration for each node by id
shards: expr_shards,
# args is a map of id -> {stage_id, output_container_position}
Expand Down Expand Up @@ -54,62 +61,64 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
{id, {expr, nil}}, idx ->
{id, put_in(expr.data.args, [idx])}

{id, {expr, shard_propagation}}, idx ->
{id, {expr, _shard_propagation}}, idx ->
expr = put_in(expr.data.args, [idx])
expr = Expr.metadata(expr, %{shards: shard_propagation.shards})
{id, expr}
end)
|> Map.new()

{expr, _} =
composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})

expr =
Composite.traverse(expr, fn
%T{data: %Expr{id: id}} = t ->
if shard_propagation = state.shards[id] do
Expr.metadata(t, %{shards: shard_propagation.shards})
else
t
end

other ->
other
# Traverse the expression to remap all shapes according to the sharding given
expr = set_shard_metadata(expr, state.shards)

arguments =
Map.new(arg_remapping, fn {_id, arg_expr} ->
{arg_expr.data.id, set_shard_metadata(arg_expr, state.shards)}
end)

argument_sources = Map.take(state.args, Map.keys(arg_remapping))
argument_sources =
state.args
|> Map.take(Map.keys(arg_remapping))
|> Map.new(fn {remap_id, v} ->
{arg_remapping[remap_id].data.id, v}
end)

[{id, category, expr, argument_sources} | acc]
[
%Stage{
id: id,
category: category,
expr: expr,
arguments: arguments,
argument_sources: argument_sources
}
| acc
]
end
)

{expr_chain, Map.delete(state, :expression_chain), cache}
{expr_chain, cache, Map.delete(state, :expression_chain)}
end

defp composite_eval(expr, state, cache) do
Composite.traverse(expr, {cache, state}, &eval/2)
end

defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do
case {cache, state.nodes_to_replace} do
{_, %{^id => res}} ->
case {cache, state.nodes_to_replace, state.ops_to_split} do
{_, %{^id => res}, _} ->
# Replace the node with the corresponding parameter
{res, {Map.put(cache, id, res), state}}

{%{^id => res}, _} ->
{%{^id => res}, _, _} ->
{res, {cache, state}}

{_, _} ->
cond do
op in @gather_ops ->
rewrite_args(ans, :gather, {cache, state})

op in @reduction_ops ->
rewrite_args(ans, :reduce, {cache, state})
{_, _, %{^op => category}} ->
rewrite_args(ans, category, {cache, state})

true ->
eval_apply(op, ans, {cache, state})
end
_ ->
eval_apply(op, ans, {cache, state})
end
end

Expand Down Expand Up @@ -203,8 +212,8 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
{new_expr, {cache, state}}
end

defp eval_apply(:parameter, %T{data: %Expr{id: id}} = expr, {cache, state}) do
state = put_in(state.args[id], nil)
defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do
state = put_in(state.args[id], {nil, idx})
{expr, {Map.put(cache, id, expr), state}}
end

Expand All @@ -220,19 +229,26 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
{ans, {Map.put(cache, id, ans), state}}
end

defp composite_rewrite_subtree(args, state, acc \\ %{used_args: %{}})
defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}})

defp composite_rewrite_subtree(args, state, acc) when is_list(args) do
Enum.map_reduce(args, acc, fn
defp composite_rewrite_subtree(container, state, acc) when is_list(container) do
Enum.map_reduce(container, acc, fn
%T{} = arg, acc ->
composite_rewrite_subtree(arg, state, acc)

arg, acc when is_list(arg) ->
composite_rewrite_subtree(arg, state, acc)

arg, acc ->
{arg, acc}
end)
end

defp composite_rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
defp composite_rewrite_subtree(container, state, acc) do
Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2))
end

defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
case state.nodes_to_replace do
%{^id => res} ->
{res, put_in(acc.used_args[id], {res, state.shards[id]})}
Expand All @@ -242,22 +258,75 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
end
end

defp composite_rewrite_subtree(arg, state, acc) do
Composite.traverse(arg, acc, &rewrite_subtree(&1, state, &2))
# Clause for `:optional` nodes, whose args are `[call, subexpr, fun]`.
#
# If the node id is listed in `state.nodes_to_replace`, the node is swapped
# for its replacement and the swap is recorded in `acc.used_args` together
# with the shard entry stored for that id. Otherwise only `call` is rewritten.
defp rewrite_subtree(
       %T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr,
       state,
       acc
     ) do
  case state.nodes_to_replace do
    %{^id => replacement} ->
      usage = {replacement, state.shards[id]}
      {replacement, put_in(acc.used_args[id], usage)}

    _ ->
      {rewritten_call, next_acc} = rewrite_subtree(call, state, acc)

      # `subexpr` is a hermetic, self-contained scope: its arguments always
      # come from `call`, so it is carried over untouched (as is `fun`).
      new_args = [rewritten_call, subexpr, fun]
      {put_in(expr.data.args, new_args), next_acc}
  end
end

defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
case state.nodes_to_replace do
%{^id => res} ->
# nodes_to_replace always contains a param
{res, put_in(acc.used_args[id], res)}
{res, put_in(acc.used_args[id], {res, state.shards[id]})}

_ ->
{args, acc} = composite_rewrite_subtree(args, state, acc)

{put_in(expr.data.args, args), acc}
end
end

defp rewrite_subtree(other, _, acc), do: {other, acc}

# Traverses `expr` with `Composite.traverse/2` and rewrites expression nodes
# according to `shards` (looked up by node id).
#
# For a node with an entry in `shards`, the shape is replaced by a tuple built
# from the `length` of the first shard of each axis (entries sorted first),
# its args are rewritten via `do_set_shard_metadata/2`, and the result is
# passed to `Expr.metadata/2` with `%{shards: ...}`. Nodes without an entry
# only get their args rewritten. Anything that is not an `Expr` tensor is
# returned unchanged.
defp set_shard_metadata(expr, shards) do
  Composite.traverse(expr, fn
    %T{data: %Expr{id: id}} = tensor ->
      case shards[id] do
        absent when absent in [nil, false] ->
          do_set_shard_metadata(tensor, shards)

        propagation ->
          sorted_axes = Enum.sort(propagation.shards)

          sharded_shape =
            List.to_tuple(
              Enum.map(sorted_axes, fn {_axis, [%Shard{length: len} | _]} -> len end)
            )

          reshaped = do_set_shard_metadata(%{tensor | shape: sharded_shape}, shards)
          Expr.metadata(reshaped, %{shards: propagation.shards})
      end

    passthrough ->
      passthrough
  end)
end

# Rewrites the args of an `Expr` tensor: tensor args go through
# `set_shard_metadata/2`, list args are mapped element-wise through
# `do_set_shard_metadata/2`, and every other arg is kept as is.
# Values that are not `Expr` tensors are returned unchanged.
#
# NOTE(review): tensors nested inside a list arg only get their own args
# rewritten here; unlike direct tensor args they are not passed through
# `set_shard_metadata/2` themselves. Confirm this asymmetry is intended.
defp do_set_shard_metadata(%T{data: %Expr{args: args}} = expr, shards) do
  rewritten_args =
    for arg <- args do
      cond do
        is_struct(arg, T) ->
          set_shard_metadata(arg, shards)

        is_list(arg) ->
          Enum.map(arg, fn item -> do_set_shard_metadata(item, shards) end)

        true ->
          arg
      end
    end

  put_in(expr.data.args, rewritten_args)
end

defp do_set_shard_metadata(non_expr, _shards), do: non_expr
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage do
  @moduledoc """
  A single stage of the expression chain produced by
  `Nx.Defn.ShardingCompiler.Passes.GraphSplitter`.

  ## Fields

    * `:id` - identifier of the stage.

    * `:category` - how the stage was split off; the graph splitter uses
      `:gather`, `:reduce` or `:none`.

    * `:expr` - the expression to be evaluated by this stage.

    * `:arguments` - map from the id of each parameter expression used by
      the stage to that parameter expression.

    * `:argument_sources` - map keyed by the same parameter ids as
      `:arguments`, pointing to where each argument comes from
      (a `{stage_id, output_container_position}` tuple).

  All fields default to `nil`.
  """

  @type t :: %__MODULE__{
          id: term(),
          category: :gather | :reduce | :none | nil,
          expr: term(),
          arguments: map() | nil,
          argument_sources: map() | nil
        }

  defstruct [:id, :category, :expr, :arguments, :argument_sources]
end
10 changes: 4 additions & 6 deletions nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do

alias Nx.Defn.ShardingCompiler.Shard

defstruct [:id, :shards, :input_tensor_shardings, :parameter_ids_to_index, :expr]
defstruct [:id, :shards, :expr]

def traverse(expr, tensor_shardings) do
{container, {cache, state}} =
Expand All @@ -19,9 +19,6 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
%{}
)

container = put_in(container.data.input_tensor_shardings, tensor_shardings)
container = put_in(container.data.parameter_ids_to_index, state.parameter_ids_to_index)

{container, cache, state}
end

Expand Down Expand Up @@ -53,7 +50,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
t
|> Nx.axes()
|> Map.new(fn axis ->
{axis, [0..(elem(t.shape, axis) - 1)]}
{axis, elem(t.shape, axis)}
end)

expr = shard_from_config(t, config)
Expand All @@ -62,7 +59,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
end

defp eval(%T{data: %Expr{op: :constant, args: [_constant]}} = ans, {cache, state}) do
expr = shard_from_config(ans, %{0 => [0..0]})
expr = shard_from_config(ans, %{})
state = put_in(state.expr_shards[expr.data.id], expr.data)
{expr, {cache, state}}
end
Expand Down Expand Up @@ -361,6 +358,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
defp resolve_sharding_broadcast(axis, left_shards, false, right_shards, false) do
# We have a shard on both sides. We need to determine the intersection of the two.
# This is fine only if all shards are equal

{reverse_out_shards, all_shards_match} =
Enum.zip_reduce(left_shards, right_shards, {[], true}, fn left,
right,
Expand Down
Loading
Loading