Add lora_path parameter to Llama model
abetlen committed Apr 18, 2023
1 parent 35abf89 commit eb7f278
Showing 1 changed file with 15 additions and 0 deletions.
llama_cpp/llama.py (15 additions, 0 deletions)
@@ -39,6 +39,7 @@ def __init__(
         n_threads: Optional[int] = None,
         n_batch: int = 8,
         last_n_tokens_size: int = 64,
+        lora_path: Optional[str] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -57,6 +58,7 @@ def __init__(
             n_threads: Number of threads to use. If None, the number of threads is automatically determined.
             n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            lora_path: Path to a LoRA file to apply to the model.
             verbose: Print verbose output to stderr.

         Raises:
@@ -108,6 +110,17 @@ def __init__(
             self.model_path.encode("utf-8"), self.params
         )

+        self.lora_path = None
+        if lora_path:
+            self.lora_path = lora_path
+            if llama_cpp.llama_apply_lora_from_file(
+                self.ctx,
+                self.lora_path.encode("utf-8"),
+                self.model_path.encode("utf-8"),
+                llama_cpp.c_int(self.n_threads),
+            ):
+                raise RuntimeError(f"Failed to apply LoRA from path: {self.lora_path}")
+
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

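Usage note: a minimal sketch of the new parameter (the file paths below are placeholders for illustration, not part of this commit). The base model is loaded first, then the LoRA adapter is applied in __init__; a non-zero return from llama_apply_lora_from_file raises RuntimeError.

from llama_cpp import Llama

# Placeholder paths for illustration only.
llm = Llama(
    model_path="./models/7B/ggml-model.bin",    # base ggml model
    lora_path="./models/loras/my-adapter.bin",  # LoRA adapter applied at load time
)
# If the adapter cannot be applied, __init__ raises:
# RuntimeError: Failed to apply LoRA from path: ./models/loras/my-adapter.bin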
@@ -802,6 +815,7 @@ def __getstate__(self):
             last_n_tokens_size=self.last_n_tokens_size,
             n_batch=self.n_batch,
             n_threads=self.n_threads,
+            lora_path=self.lora_path,
         )

     def __setstate__(self, state):
@@ -819,6 +833,7 @@ def __setstate__(self, state):
             n_threads=state["n_threads"],
             n_batch=state["n_batch"],
             last_n_tokens_size=state["last_n_tokens_size"],
+            lora_path=state["lora_path"],
             verbose=state["verbose"],
        )
