From f664eb861396f9724331daf257c4f2770f47943e Mon Sep 17 00:00:00 2001
From: luchun <71970539+zhanghy-sketchzh@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:08:15 +0800
Subject: [PATCH 1/4] Update sql_data_process.py

---
 dbgpt_hub/utils/sql_data_process.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dbgpt_hub/utils/sql_data_process.py b/dbgpt_hub/utils/sql_data_process.py
index 378a8c3..6c751c1 100644
--- a/dbgpt_hub/utils/sql_data_process.py
+++ b/dbgpt_hub/utils/sql_data_process.py
@@ -905,10 +905,10 @@ def serialize_schema_natural_language(
         # get the table and column of x
         db_column_names_table_id =[x["table_id"] for x in db_column_names]
         x_table_name = db_table_name_strs[db_column_names_table_id[x]]
-        x_column_name = db_column_name_strs[x]
+        x_column_name = db_column_name_strs[x-1]
         # get the table and column of y
         y_table_name = db_table_name_strs[db_column_names_table_id[y]]
-        y_column_name = db_column_name_strs[y]
+        y_column_name = db_column_name_strs[y-1]
         foreign_key_description_str = foreign_key_description(x_table_name, x_column_name, y_table_name, y_column_name)
         descriptions.append(foreign_key_description_str)
     return " ".join(descriptions)

From 4c53d4fc3ca33afc3cb504409936ba6cc0114cba Mon Sep 17 00:00:00 2001
From: luchun <71970539+zhanghy-sketchzh@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:10:11 +0800
Subject: [PATCH 2/4] Update gen_args.py

---
 dbgpt_hub/configs/gen_args.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/dbgpt_hub/configs/gen_args.py b/dbgpt_hub/configs/gen_args.py
index 1d0f012..d8512f0 100644
--- a/dbgpt_hub/configs/gen_args.py
+++ b/dbgpt_hub/configs/gen_args.py
@@ -8,7 +8,7 @@ class GenerationArguments:
     # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
     # Length arguments
     max_new_tokens: Optional[int] = field(
-        default=256,
+        default=128,
         metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops"
                           "if predict_with_generate is set."}
     )
@@ -16,7 +16,7 @@ class GenerationArguments:
         default=None,
         metadata={"help": "Minimum number of new tokens to generate."}
     )
-
+    
     # Generation strategy
     do_sample: Optional[bool] = field(default=False)
     num_beams: Optional[int] = field(default=1)
@@ -32,4 +32,10 @@ class GenerationArguments:
     diversity_penalty: Optional[float] = field(default=0.0)
     repetition_penalty: Optional[float] = field(default=1.0)
     length_penalty: Optional[float] = field(default=1.0)
-    no_repeat_ngram_size: Optional[int] = field(default=0)
\ No newline at end of file
+    no_repeat_ngram_size: Optional[int] = field(default=0)
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        if args.get('max_new_tokens', None):
+            args.pop('max_length', None)
+        return args

From 96d1231248a3f27f0287b1d8c953990c82af230a Mon Sep 17 00:00:00 2001
From: luchun <71970539+zhanghy-sketchzh@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:13:32 +0800
Subject: [PATCH 3/4] Update predict_lora.py

---
 predict_lora.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/predict_lora.py b/predict_lora.py
index a3cb482..bb6b0ab 100644
--- a/predict_lora.py
+++ b/predict_lora.py
@@ -12,14 +12,16 @@
 SQL_PROMPT_DICT = {
     "prompt_input": (
-        "I want you to act as a SQL terminal in front of an example database,
-        you need to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
-        The instruction is {instruction}, So please tell me {input} Response:"
+        "I want you to act as a SQL terminal in front of an example database, \
+        you need only to return the sql command to me.Below is an instruction that describes a task, \
+        Write a response that appropriately completes the request. \
+        The instruction is {instruction}, So please tell me {input}, ###Response:"
     ),
     "prompt_no_input": (
-        "I want you to act as a SQL terminal in front of an example database,
-        you need to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
-        The instruction is {instruction}, Response:"
+        "I want you to act as a SQL terminal in front of an example database, \
+        you need only to return the sql command to me.Below is an instruction that describes a task, \
+        Write a response that appropriately completes the request. \
+        The instruction is {instruction}, ###Response:"
     ),
 }

 def extract_sql_dataset(example):
@@ -132,10 +134,6 @@ def predict():
         # result.append(response.replace("\n", ""))
     return result

-
-
-
-

 if __name__ == "__main__":
     result = predict()

From 2be3890cd00d267cd4c150c91e354619c4333c7b Mon Sep 17 00:00:00 2001
From: luchun <71970539+zhanghy-sketchzh@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:15:22 +0800
Subject: [PATCH 4/4] Update train_args.py

---
 dbgpt_hub/configs/train_args.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dbgpt_hub/configs/train_args.py b/dbgpt_hub/configs/train_args.py
index aa517fc..1004cd7 100644
--- a/dbgpt_hub/configs/train_args.py
+++ b/dbgpt_hub/configs/train_args.py
@@ -4,7 +4,7 @@


 @dataclass
-class TrainingArguments(Seq2SeqTrainingArguments):
+class TrainingArguments(transformers.Seq2SeqTrainingArguments):
     cache_dir: Optional[str] = field(default=None)
     train_on_source: Optional[bool] = field(
         default=False,
@@ -56,4 +56,4 @@ class TrainingArguments(Seq2SeqTrainingArguments):
     group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
     save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'})
     save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
-    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
\ No newline at end of file
+    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})