Skip to content

Commit

Permalink
Add test for toml-based pp split points
Browse files Browse the repository at this point in the history
ghstack-source-id: fd34e39af69b4285e5c9e633e6a51f5d6f08ca48
Pull Request resolved: #664
  • Loading branch information
wconstab committed Oct 31, 2024
1 parent 2a785e9 commit e1fbced
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
1 change: 1 addition & 0 deletions .ci/docker/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ expecttest==0.1.6
pytest==7.3.2
pytest-cov
pre-commit
tomli-w >= 1.1.0
72 changes: 72 additions & 0 deletions test/test_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile

import pytest
import tomli_w
from torchtitan.config_manager import JobConfig


Expand Down Expand Up @@ -44,6 +45,77 @@ def test_job_config_file_cmd_overrides(self):
)
assert config.job.dump_folder == "/tmp/test_tt/"

def test_parse_pp_split_points(self):

toml_splits = ["layers.2", "layers.4", "layers.6"]
toml_split_str = ",".join(toml_splits)
cmdline_splits = ["layers.1", "layers.3", "layers.5"]
cmdline_split_str = ",".join(cmdline_splits)
# no split points specified
config = JobConfig()
config.parse_args(
[
"--job.config_file",
"./train_configs/debug_model.toml",
]
)
assert config.experimental.pipeline_parallel_split_points == []

# toml has no split points, but cmdline splits are specified
config = JobConfig()
config.parse_args(
[
"--job.config_file",
"./train_configs/debug_model.toml",
"--experimental.pipeline_parallel_split_points",
f"{cmdline_split_str}",
]
)
assert (
config.experimental.pipeline_parallel_split_points == cmdline_splits
), config.experimental.pipeline_parallel_split_points

# toml has split points, cmdline does not
with tempfile.NamedTemporaryFile() as fp:
with open(fp.name, "wb") as f:
tomli_w.dump(
{
"experimental": {
"pipeline_parallel_split_points": toml_split_str,
}
},
f,
)
config = JobConfig()
config.parse_args(["--job.config_file", fp.name])
assert (
config.experimental.pipeline_parallel_split_points == toml_splits
), config.experimental.pipeline_parallel_split_points

# toml has split points, cmdline overrides them
with tempfile.NamedTemporaryFile() as fp:
with open(fp.name, "wb") as f:
tomli_w.dump(
{
"experimental": {
"pipeline_parallel_split_points": toml_split_str,
}
},
f,
)
config = JobConfig()
config.parse_args(
[
"--job.config_file",
fp.name,
"--experimental.pipeline_parallel_split_points",
f"{cmdline_split_str}",
]
)
assert (
config.experimental.pipeline_parallel_split_points == cmdline_splits
), config.experimental.pipeline_parallel_split_points

def test_print_help(self):
config = JobConfig()
parser = config.parser
Expand Down

0 comments on commit e1fbced

Please sign in to comment.