Fix display
seanmor5 committed Oct 21, 2024
1 parent cc7dec6 commit a8d3149
Showing 1 changed file with 53 additions and 28 deletions.
lib/axon/display.ex (81 changes: 53 additions & 28 deletions)
@@ -56,9 +56,24 @@ defmodule Axon.Display do
vertical_symbol: "|"
)
|> then(&(&1 <> "Total Parameters: #{model_info.num_params}\n"))
|> then(&(&1 <> "Total Parameters Memory: #{model_info.total_param_byte_size} bytes\n"))
|> then(
&(&1 <> "Total Parameters Memory: #{readable_size(model_info.total_param_byte_size)}\n")
)
end

defp readable_size(n) when n < 1_000, do: "#{n} bytes"

defp readable_size(n) when n >= 1_000 and n < 1_000_000,
do: "#{float_format(n / 1_000)} kilobytes"

defp readable_size(n) when n >= 1_000_000 and n < 1_000_000_000,
do: "#{float_format(n / 1_000_000)} megabytes"

defp readable_size(n) when n >= 1_000_000_000 and n < 1_000_000_000_000,
do: "#{float_format(n / 1_000_000_000)} gigabytes"

defp float_format(value), do: :io_lib.format("~.2f", [value])

defp assert_table_rex!(fn_name) do
unless Code.ensure_loaded?(TableRex) do
raise RuntimeError, """
@@ -93,7 +108,6 @@ defmodule Axon.Display do
defp do_axon_to_rows(
%Axon.Node{
id: id,
op: structure,
op_name: :container,
parent: [parents],
name: name_fn
@@ -104,7 +118,7 @@
op_counts,
model_info
) do
{input_names, {cache, op_counts, model_info}} =
{_, {cache, op_counts, model_info}} =
Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
parent_id, {cache, op_counts, model_info} ->
{_, name, _shape, cache, op_counts, model_info} =
@@ -119,11 +133,11 @@
shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)

row = [
"#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )",
"#{name} ( #{op_string} )",
"#{inspect({})}",
"#{inspect(shape)}",
render_output_shape(shape),
render_options([]),
render_parameters(%{}, [])
render_parameters(nil, %{}, [])
]

{row, name, shape, cache, op_counts, model_info}
@@ -136,7 +150,7 @@
parameters: params,
name: name_fn,
opts: opts,
policy: %{params: {_, bitsize}},
policy: %{params: params_policy},
op_name: op_name
},
nodes,
@@ -145,6 +159,12 @@
op_counts,
model_info
) do
bitsize =
case params_policy do
nil -> 32
{_, bitsize} -> bitsize
end

{input_names_and_shapes, {cache, op_counts, model_info}} =
Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
parent_id, {cache, op_counts, model_info} ->
@@ -154,39 +174,34 @@
{{name, shape}, {cache, op_counts, model_info}}
end)

{input_names, input_shapes} = Enum.unzip(input_names_and_shapes)
{_, input_shapes} = Enum.unzip(input_names_and_shapes)

inputs =
Map.new(input_names_and_shapes, fn {name, shape} ->
{name, render_output_shape(shape)}
end)

num_params =
Enum.reduce(params, 0, fn
%Parameter{shape: {:tuple, shapes}}, acc ->
Enum.reduce(shapes, acc, &(Nx.size(apply(&1, input_shapes)) + &2))

%Parameter{shape: shape_fn}, acc ->
%Parameter{template: shape_fn}, acc when is_function(shape_fn) ->
acc + Nx.size(apply(shape_fn, input_shapes))
end)

param_byte_size = num_params * div(bitsize, 8)

op_inspect = Atom.to_string(op_name)

inputs =
case input_names do
[] ->
""

[_ | _] = input_names ->
"#{inspect(input_names)}"
end

name = name_fn.(op_name, op_counts)
shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)

row = [
"#{name} ( #{op_inspect}#{inputs} )",
"#{inspect(input_shapes)}",
"#{inspect(shape)}",
"#{name} ( #{op_inspect} )",
"#{inspect(inputs)}",
render_output_shape(shape),
render_options(opts),
render_parameters(params, input_shapes)
render_parameters(params_policy, params, input_shapes)
]

model_info =
@@ -200,6 +215,14 @@
{row, name, shape, cache, op_counts, model_info}
end

defp render_output_shape(%Nx.Tensor{} = template) do
type = type_str(Nx.type(template))
shape = shape_string(Nx.shape(template))
"#{type}#{shape}"
end

defp type_str({type, size}), do: "#{Atom.to_string(type)}#{size}"

defp render_options(opts) do
opts
|> Enum.map(fn {key, val} ->
Expand All @@ -209,21 +232,23 @@ defmodule Axon.Display do
|> Enum.join("\n")
end

defp render_parameters(params, input_shapes) do
defp render_parameters(policy, params, input_shapes) do
type = policy || {:f, 32}

params
|> Enum.map(fn
%Parameter{name: name, shape: {:tuple, shape_fns}} ->
shapes =
shape_fns
|> Enum.map(&apply(&1, input_shapes))
|> Enum.map(fn shape -> "f32#{shape_string(shape)}" end)
|> Enum.map(fn shape -> "#{type_str(type)}#{shape_string(shape)}" end)
|> List.to_tuple()

"#{name}: tuple#{inspect(shapes)}"

%Parameter{name: name, shape: shape_fn} ->
shape = apply(shape_fn, input_shapes)
"#{name}: f32#{shape_string(shape)}"
%Parameter{name: name, template: shape_fn} when is_function(shape_fn) ->
shape = Nx.shape(apply(shape_fn, input_shapes))
"#{name}: #{type_str(type)}#{shape_string(shape)}"
end)
|> Enum.join("\n")
end
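
For context, a minimal, self-contained sketch of how the new readable_size/1 and float_format/1 helpers format the parameter-memory line, assuming only the decimal thresholds shown in the diff above. The module name and sample values are hypothetical, for experimenting in iex; this is not part of the commit itself.

defmodule ReadableSizeSketch do
  # Mirrors the private helpers added to lib/axon/display.ex in this commit:
  # sizes below 1_000 stay in bytes, larger sizes are scaled to kilo/mega/gigabytes
  # and formatted with two decimal places via :io_lib.format/2.
  def readable_size(n) when n < 1_000, do: "#{n} bytes"

  def readable_size(n) when n >= 1_000 and n < 1_000_000,
    do: "#{float_format(n / 1_000)} kilobytes"

  def readable_size(n) when n >= 1_000_000 and n < 1_000_000_000,
    do: "#{float_format(n / 1_000_000)} megabytes"

  def readable_size(n) when n >= 1_000_000_000 and n < 1_000_000_000_000,
    do: "#{float_format(n / 1_000_000_000)} gigabytes"

  # :io_lib.format/2 returns iodata; string interpolation converts it to a binary.
  defp float_format(value), do: :io_lib.format("~.2f", [value])
end

# ReadableSizeSketch.readable_size(512)         #=> "512 bytes"
# ReadableSizeSketch.readable_size(4_196)       #=> "4.20 kilobytes"
# ReadableSizeSketch.readable_size(25_000_000)  #=> "25.00 megabytes"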
