From e1fbcedbc68b6f4d516b938c195ae07b46cdebff Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Wed, 30 Oct 2024 18:29:01 -0700
Subject: [PATCH] Add test for toml-based pp split points

ghstack-source-id: fd34e39af69b4285e5c9e633e6a51f5d6f08ca48
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/664
---
 .ci/docker/dev-requirements.txt |  1 +
 test/test_job_config.py         | 72 +++++++++++++++++++++++++++++++++
 2 files changed, 73 insertions(+)

diff --git a/.ci/docker/dev-requirements.txt b/.ci/docker/dev-requirements.txt
index 770301a0..bd611222 100644
--- a/.ci/docker/dev-requirements.txt
+++ b/.ci/docker/dev-requirements.txt
@@ -2,3 +2,4 @@ expecttest==0.1.6
 pytest==7.3.2
 pytest-cov
 pre-commit
+tomli-w >= 1.1.0
diff --git a/test/test_job_config.py b/test/test_job_config.py
index e4ef04ba..aed007ad 100644
--- a/test/test_job_config.py
+++ b/test/test_job_config.py
@@ -7,6 +7,7 @@
 import tempfile
 
 import pytest
+import tomli_w
 from torchtitan.config_manager import JobConfig
 
 
@@ -44,6 +45,77 @@ def test_job_config_file_cmd_overrides(self):
         )
         assert config.job.dump_folder == "/tmp/test_tt/"
 
+    def test_parse_pp_split_points(self):
+        """Split points parse from TOML and CLI, with the CLI value winning."""
+        toml_splits = ["layers.2", "layers.4", "layers.6"]
+        toml_split_str = ",".join(toml_splits)
+        cmdline_splits = ["layers.1", "layers.3", "layers.5"]
+        cmdline_split_str = ",".join(cmdline_splits)
+        # no split points specified
+        config = JobConfig()
+        config.parse_args(
+            [
+                "--job.config_file",
+                "./train_configs/debug_model.toml",
+            ]
+        )
+        assert config.experimental.pipeline_parallel_split_points == []
+
+        # toml has no split points, but cmdline splits are specified
+        config = JobConfig()
+        config.parse_args(
+            [
+                "--job.config_file",
+                "./train_configs/debug_model.toml",
+                "--experimental.pipeline_parallel_split_points",
+                cmdline_split_str,
+            ]
+        )
+        assert (
+            config.experimental.pipeline_parallel_split_points == cmdline_splits
+        ), config.experimental.pipeline_parallel_split_points
+
+        # toml has split points, cmdline does not
+        with tempfile.NamedTemporaryFile() as fp:
+            tomli_w.dump(
+                {
+                    "experimental": {
+                        "pipeline_parallel_split_points": toml_split_str,
+                    }
+                },
+                fp,
+            )
+            fp.flush()  # make the TOML visible to the reader opened by name
+            config = JobConfig()
+            config.parse_args(["--job.config_file", fp.name])
+            assert (
+                config.experimental.pipeline_parallel_split_points == toml_splits
+            ), config.experimental.pipeline_parallel_split_points
+
+        # toml has split points, cmdline overrides them
+        with tempfile.NamedTemporaryFile() as fp:
+            tomli_w.dump(
+                {
+                    "experimental": {
+                        "pipeline_parallel_split_points": toml_split_str,
+                    }
+                },
+                fp,
+            )
+            fp.flush()  # make the TOML visible to the reader opened by name
+            config = JobConfig()
+            config.parse_args(
+                [
+                    "--job.config_file",
+                    fp.name,
+                    "--experimental.pipeline_parallel_split_points",
+                    cmdline_split_str,
+                ]
+            )
+            assert (
+                config.experimental.pipeline_parallel_split_points == cmdline_splits
+            ), config.experimental.pipeline_parallel_split_points
+
     def test_print_help(self):
         config = JobConfig()
         parser = config.parser