Commit 99cecb6
removed # pylint: disable=useless-parent-delegation in FinetuneDialogAgent
zyzhang1130 committed Aug 23, 2024
1 parent 19ed176 commit 99cecb6
Showing 3 changed files with 18 additions and 13 deletions.
@@ -43,14 +43,14 @@ def __init__(
Note:
Refer to `class DialogAgent(AgentBase)` for more information.
"""
# pylint: disable=useless-parent-delegation
super().__init__(
name,
sys_prompt,
model_config_name,
use_memory,
memory_config,
)
self.finetune = True

def load_model(
self,
@@ -9,7 +9,10 @@
Features include model and tokenizer loading,
and fine-tuning on the lima dataset with adjustable parameters.
"""

# This import is necessary for AgentScope to properly use
# HuggingFaceWrapper even though it's not explicitly used in this file.
# To remove the pylint disable without causing issues,
# HuggingFaceWrapper needs to be placed under src/agentscope/agents.
# pylint: disable=unused-import
from huggingface_model import HuggingFaceWrapper
from FinetuneDialogAgent import FinetuneDialogAgent
@@ -52,6 +55,7 @@ def main() -> None:
# loading a model saved as lora model
"fine_tune_config": {
"continue_lora_finetuning": False,
"max_seq_length": 4096,
"lora_config": {
"r": 16,
"lora_alpha": 32,
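For reference, the new "max_seq_length" key shown above sits alongside the existing fine-tuning options in the example script's config. A minimal sketch of the fine_tune_config block after this change; only the keys visible in the diff come from the source, anything else would be an assumption:

fine_tune_config = {
    # loading a model saved as lora model
    "continue_lora_finetuning": False,
    "max_seq_length": 4096,  # new in this commit: overrides the previously hard-coded value
    "lora_config": {
        "r": 16,
        "lora_alpha": 32,
    },
}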
@@ -80,10 +80,13 @@ def __init__(
fine-tuning the model.
**kwargs: Additional keyword arguments.
"""
super().__init__(config_name=config_name)
super().__init__(
config_name=config_name,
model_name=pretrained_model_name_or_path,
)
self.model = None
self.tokenizer = None
self.max_length = max_length # Set max_length as an attribute
self.max_length = max_length
self.pretrained_model_name_or_path = pretrained_model_name_or_path
self.local_model_path = local_model_path
self.lora_config = None
@@ -358,6 +361,8 @@ def load_tokenizer(
f" from '{local_model_path}'",
)
self.tokenizer.padding_side = "right"
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

except Exception as e:
# Handle exceptions during model loading,
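The pad-token fallback added to load_tokenizer above is a common safeguard for causal LMs whose tokenizers ship without a padding token. A minimal standalone sketch, assuming the Hugging Face transformers API and using "gpt2" purely as an illustrative checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    # Reuse the end-of-sequence token so batched padding works during fine-tuning.
    tokenizer.pad_token = tokenizer.eos_token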
@@ -666,10 +671,13 @@ def fine_tune_training(
"optim": "paged_adamw_8bit",
"logging_steps": 1,
}
max_seq_length_default = 4096

lora_config_default = {}

if fine_tune_config is not None:
if fine_tune_config.get("max_seq_length") is not None:
max_seq_length_default = fine_tune_config["max_seq_length"]
if fine_tune_config.get("training_args") is not None:
training_defaults.update(fine_tune_config["training_args"])
if fine_tune_config.get("lora_config") is not None:
@@ -709,7 +717,7 @@ def fine_tune_training(
else {}
),
args=trainer_args,
max_seq_length=4096,
max_seq_length=max_seq_length_default,
)

logger.info(
@@ -718,14 +726,7 @@
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)

try:
trainer.train()
except Exception as e:
import traceback

logger.error(f"Error during training: {e}")
traceback.print_exc()
raise
trainer.train()

now = datetime.now()
time_string = now.strftime("%Y-%m-%d_%H-%M-%S")
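Taken together, the fine_tune_training changes follow a defaults-plus-override pattern: a default of 4096 that a user-supplied fine_tune_config can replace, with the resolved value forwarded to trl's SFTTrainer instead of a hard-coded literal. A minimal sketch of that flow; the SFTTrainer arguments other than args and max_seq_length are placeholders, not taken from the diff:

from trl import SFTTrainer

max_seq_length_default = 4096
if fine_tune_config is not None:
    if fine_tune_config.get("max_seq_length") is not None:
        max_seq_length_default = fine_tune_config["max_seq_length"]

trainer = SFTTrainer(
    model,                  # placeholder: the model being fine-tuned
    train_dataset=dataset,  # placeholder: the formatted training split
    args=trainer_args,      # TrainingArguments built from training_defaults
    max_seq_length=max_seq_length_default,  # configurable instead of the old literal 4096
)
trainer.train()

The commit also drops the local try/except around trainer.train(), so training errors now propagate to the caller instead of being logged and re-raised in place.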
