Fix cache max_seq_len (#568)
* fix max_seq_len

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* another one

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix max new tokens

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] authored Sep 19, 2024
1 parent 711209e commit ad55185
Showing 4 changed files with 22 additions and 7 deletions.
1 change: 1 addition & 0 deletions install_env.bat
@@ -144,6 +144,7 @@ call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^

endlocal
echo "Environment Check: Success."
+:end
pause

goto :EOF
4 changes: 2 additions & 2 deletions tools/api.py
@@ -220,7 +220,7 @@ def inference(req: ServeTTSRequest):
compile=args.compile,
iterative_prompt=req.chunk_length > 0,
chunk_length=req.chunk_length,
-max_length=2048,
+max_length=4096,
prompt_tokens=prompt_tokens,
prompt_text=prompt_texts,
)
@@ -424,7 +424,7 @@ async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
text="Hello world.",
references=[],
reference_id=None,
-max_new_tokens=1024,
+max_new_tokens=0,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.2,
20 changes: 17 additions & 3 deletions tools/llama/generate.py
@@ -237,6 +237,16 @@ def generate(
# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(1)

+if max_new_tokens:
+    if T + max_new_tokens > model.config.max_seq_len:
+        max_new_tokens = model.config.max_seq_len - T
+        logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+    T_new = T + max_new_tokens
+else:
+    T_new = model.config.max_seq_len
+    max_new_tokens = T_new - T
+
device, dtype = prompt.device, prompt.dtype

codebook_dim = 1 + model.config.num_codebooks
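
For reference, here is a small self-contained sketch of the budget logic added above; the function name and `prompt_len` are illustrative, not from the repository. An explicit `max_new_tokens` is clamped so the prompt plus new tokens fit within `model.config.max_seq_len`, while 0 falls back to the full remaining window.

```python
def token_budget(prompt_len: int, max_new_tokens: int, max_seq_len: int) -> int:
    """How many new tokens may be generated (standalone sketch of the logic above)."""
    if max_new_tokens:
        # An explicit request is clamped so prompt + new tokens fit the context window.
        return min(max_new_tokens, max_seq_len - prompt_len)
    # 0 means "no per-request cap": use whatever room is left in the window.
    return max_seq_len - prompt_len


print(token_budget(1500, 3000, 4096))  # 2596 -- clamped to the remaining window
print(token_budget(1500, 0, 4096))     # 2596 -- the new default of 0 fills the window
print(token_budget(1500, 1024, 4096))  # 1024 -- the old fixed default
```

This is also why the defaults elsewhere in this commit move from 1024 to 0: 0 now means "generate up to the cache limit" rather than a fixed cap.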
@@ -565,7 +575,9 @@ def worker():
)
with torch.device(device):
model.setup_caches(
-    max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+    max_batch_size=1,
+    max_seq_len=model.config.max_seq_len,
+    dtype=next(model.parameters()).dtype,
)
init_event.set()

@@ -607,7 +619,7 @@ def worker():
multiple=True,
)
@click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=1024)
+@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@@ -654,7 +666,9 @@ def main(
)
with torch.device(device):
model.setup_caches(
-    max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+    max_batch_size=1,
+    max_seq_len=model.config.max_seq_len,
+    dtype=next(model.parameters()).dtype,
)
if torch.cuda.is_available():
torch.cuda.synchronize()
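
For intuition on the setup_caches change, here is a toy sketch of why the cache is now sized from model.config.max_seq_len instead of a hard-coded 2048; the class and tensor shapes are assumptions for illustration, not the repository's KVCache. A pre-allocated KV cache is a fixed-size buffer, so generation can never run past its length.

```python
import torch

class TinyKVCache:
    """Toy fixed-size KV cache: its length is a hard ceiling on prompt + generated tokens."""

    def __init__(self, max_seq_len: int, n_heads: int = 8, head_dim: int = 64):
        self.k = torch.zeros(1, n_heads, max_seq_len, head_dim)
        self.v = torch.zeros(1, n_heads, max_seq_len, head_dim)

    def update(self, pos: int, k: torch.Tensor, v: torch.Tensor) -> None:
        # Writing at pos >= max_seq_len fails, so the cache must be allocated for
        # the longest sequence the model is configured to handle.
        self.k[:, :, pos] = k
        self.v[:, :, pos] = v


cache = TinyKVCache(max_seq_len=4096)
cache.update(3000, torch.randn(1, 8, 64), torch.randn(1, 8, 64))  # fits in a 4096-entry cache
# With a hard-coded 2048-entry cache the same position would be out of range.
```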
4 changes: 2 additions & 2 deletions tools/webui.py
@@ -286,7 +286,7 @@ def build_app():
label=i18n("Maximum tokens per batch, 0 means no limit"),
minimum=0,
maximum=2048,
-value=1024, # 0 means no limit
+value=0, # 0 means no limit
step=8,
)

@@ -505,7 +505,7 @@ def parse_args():
enable_reference_audio=False,
reference_audio=None,
reference_text="",
-max_new_tokens=1024,
+max_new_tokens=0,
chunk_length=200,
top_p=0.7,
repetition_penalty=1.2,
