batch inference, tp wip

jaysonfrancis committed Oct 30, 2024
1 parent afa54b7 commit 93685bc
Showing 3 changed files with 12 additions and 12 deletions.
7 changes: 3 additions & 4 deletions test/generate/run_llama_pred.sh
@@ -12,8 +12,7 @@ set -ex
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
 NGPU=${NGPU:-"2"}
 LOG_RANK=${LOG_RANK:-0,1}
-# CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
-CONFIG_FILE=${CONFIG_FILE:-"./train_configs/llama3_3b.toml"}
+CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
 CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"}
 PROMPT=${PROMPT:-"Hello!"}

@@ -22,8 +21,8 @@ if [ $# -ne 0 ]; then
   overrides="$*"
 fi

-export NCCL_DEBUG=INFO
-export NCCL_DEBUG_SUBSYS=ALL
+# export NCCL_DEBUG=INFO
+# export NCCL_DEBUG_SUBSYS=ALL
 # export NCCL_BLOCKING_WAIT=1
 # export NCCL_ASYNC_ERROR_HANDLING=1
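
Note: this commit flips the verbose NCCL logging off by commenting the exports out. For a one-off debug run, the same variables can also be set from Python instead of editing the script; a minimal sketch, assuming torch.distributed is in use (the variable names come from the script above, the rest is illustrative):

    import os

    # NCCL reads these when communicators are created, so they must be in
    # the environment before init_process_group() runs.
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    os.environ.setdefault("NCCL_DEBUG_SUBSYS", "ALL")

    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
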
16 changes: 9 additions & 7 deletions test/generate/test_generate.py
@@ -27,7 +27,7 @@ def example_generate(
     *,
     device: str = "cuda",
     temperature: float = 1.0,
-    max_generated_tokens: int = 32,
+    max_new_tokens: int = 32,
     top_k: Optional[int] = None,
     seed: Optional[int] = None,
 ):
@@ -82,9 +82,8 @@ def example_generate(
         model,
         input_ids,
         temperature=temperature,
-        max_generated_tokens=max_generated_tokens,
+        max_new_tokens=max_new_tokens,
         top_k=top_k,
-        # seed=seed,
     )

     logger.info(f"Generation completed in {time.monotonic() - begin:.2f} seconds.")
@@ -115,7 +114,6 @@ def example_generate(
         "--device",
         type=str,
         default="cuda",
-        choices=["cpu", "cuda"],
         help="Device to load model on. Default is 'cuda'",
     )
     parser.add_argument(
@@ -125,21 +123,25 @@ def example_generate(
         help="Sampling temperature. Default is 1.0",
     )
     parser.add_argument(
-        "--max_generated_tokens",
+        "--max_new_tokens",
         type=int,
         default=32,
         help="Max number of tokens to generate. Default is 32",
     )
+    parser.add_argument(
+        "--batch_size", type=int, default=1, help="Number of samples to run in batch"
+    )
     parser.add_argument(
         "--top_k", type=int, help="Prune to select from top_k probabilities. Optional"
     )
     parser.add_argument(
         "--seed", type=int, default=42, help="Random seed for reproducibility"
     )
+
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Hello! How are you?",
+        default="Hello! How are",
         help="Input prompt for generation",
     )

@@ -151,7 +153,7 @@ def example_generate(
         prompt=args.prompt,
         device=args.device,
         temperature=args.temperature,
-        max_generated_tokens=args.max_generated_tokens,
+        max_new_tokens=args.max_new_tokens,
         top_k=args.top_k,
         seed=args.seed,
     )
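
Together with the new --batch_size flag, the rename from max_generated_tokens to max_new_tokens points at batched decoding. A rough sketch of what such a loop can look like, assuming a model that maps (batch, seq) token IDs to (batch, seq, vocab) logits; generate's keyword names match the call site above, but the body is illustrative, not the repository's actual implementation:

    import torch

    @torch.no_grad()
    def generate(model, input_ids, *, temperature=1.0, max_new_tokens=32, top_k=None):
        tokens = input_ids  # (batch, seq) prompt IDs, one row per sample
        for _ in range(max_new_tokens):
            logits = model(tokens)[:, -1, :]        # next-token logits per sample
            logits = logits / max(temperature, 1e-5)
            if top_k is not None:
                kth = torch.topk(logits, top_k).values[:, [-1]]
                logits[logits < kth] = -float("inf")
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens

A single prompt becomes a batch by repetition, e.g. input_ids.repeat(args.batch_size, 1), which is presumably what --batch_size feeds.
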
1 change: 0 additions & 1 deletion test/generate/test_generate_dist.py
@@ -45,7 +45,6 @@ def example_generate(
     batch_size: int = 1,
     top_k: Optional[int] = None,
     seed: Optional[int] = None,
-    save_path: Optional[str] = None,
 ):
     init_logger()
     color = utils.Color
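
The "tp wip" half of the commit message refers to tensor parallelism in this distributed variant. For orientation, a minimal sketch of wiring a TP mesh with PyTorch's stock APIs; the mesh size, the attention.wq / attention.wo plan entries, and the model variable are illustrative assumptions, not necessarily what test_generate_dist.py does:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    dist.init_process_group(backend="nccl")
    tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    # Shard each block's attention projections across the mesh:
    # column-parallel on the way in, row-parallel on the way out.
    for block in model.layers.values():  # hypothetical Llama-style ModuleDict
        parallelize_module(
            block,
            tp_mesh,
            {
                "attention.wq": ColwiseParallel(),
                "attention.wo": RowwiseParallel(),
            },
        )

This has to run under torchrun (the NGPU / LOG_RANK knobs in run_llama_pred.sh) so that init_process_group() has ranks to join.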
