Commit

Merge pull request #34 from eosphoros-ai/zhanghy-sketchzh-patch-2
Update sql_data_process.py
csunny authored Aug 4, 2023
2 parents 0145be9 + 2be3890 commit 1840e79
Showing 4 changed files with 21 additions and 17 deletions.
12 changes: 9 additions & 3 deletions dbgpt_hub/configs/gen_args.py
@@ -8,15 +8,15 @@ class GenerationArguments:
# https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
# Length arguments
max_new_tokens: Optional[int] = field(
-default=256,
+default=128,
metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops"
"if predict_with_generate is set."}
)
min_new_tokens : Optional[int] = field(
default=None,
metadata={"help": "Minimum number of new tokens to generate."}
)

# Generation strategy
do_sample: Optional[bool] = field(default=False)
num_beams: Optional[int] = field(default=1)
@@ -32,4 +32,10 @@ class GenerationArguments:
diversity_penalty: Optional[float] = field(default=0.0)
repetition_penalty: Optional[float] = field(default=1.0)
length_penalty: Optional[float] = field(default=1.0)
-no_repeat_ngram_size: Optional[int] = field(default=0)
+no_repeat_ngram_size: Optional[int] = field(default=0)
+
+    def to_dict(self) -> Dict[str, Any]:
+        args = asdict(self)
+        if args.get('max_new_tokens', None):
+            args.pop('max_length', None)
+        return args
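The added to_dict helper pops max_length whenever max_new_tokens is set, so only one length control is forwarded to generation. A minimal sketch of how the class might be used downstream; the trimmed dataclass and the commented generate call are illustrative, not the repository's actual wiring:

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

@dataclass
class GenerationArguments:
    # Trimmed-down copy of the dataclass above, for illustration only.
    max_new_tokens: Optional[int] = field(default=128)
    min_new_tokens: Optional[int] = field(default=None)
    do_sample: Optional[bool] = field(default=False)
    num_beams: Optional[int] = field(default=1)

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        # If max_new_tokens is set, drop a max_length entry (if present)
        # so the two length controls cannot both be passed to generate().
        if args.get("max_new_tokens", None):
            args.pop("max_length", None)
        return args

gen_args = GenerationArguments()
# Hypothetical call site: outputs = model.generate(**inputs, **gen_args.to_dict())
print(gen_args.to_dict())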
4 changes: 2 additions & 2 deletions dbgpt_hub/configs/train_args.py
@@ -4,7 +4,7 @@


@dataclass
-class TrainingArguments(Seq2SeqTrainingArguments):
+class TrainingArguments(transformers.Seq2SeqTrainingArguments):
cache_dir: Optional[str] = field(default=None)
train_on_source: Optional[bool] = field(
default=False,
@@ -56,4 +56,4 @@ class TrainingArguments(Seq2SeqTrainingArguments):
group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'})
save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
-save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
+save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
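The substantive change here is qualifying the base class, which implies the module relies on an "import transformers" statement rather than importing Seq2SeqTrainingArguments directly. A minimal sketch of the pattern, assuming transformers (with torch/accelerate) is installed; the field list and instantiation values are illustrative:

from dataclasses import dataclass, field
from typing import Optional

import transformers

@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    # Extra fields layered on top of the stock Hugging Face arguments.
    cache_dir: Optional[str] = field(default=None)
    train_on_source: Optional[bool] = field(default=False)

# Hypothetical values; every stock Seq2SeqTrainingArguments field still applies.
args = TrainingArguments(output_dir="./adapter_out", save_strategy="steps", save_steps=250)
print(args.cache_dir, args.save_steps)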
4 changes: 2 additions & 2 deletions dbgpt_hub/utils/sql_data_process.py
@@ -905,10 +905,10 @@ def serialize_schema_natural_language(
# get the table and column of x
db_column_names_table_id =[x["table_id"] for x in db_column_names]
x_table_name = db_table_name_strs[db_column_names_table_id[x]]
-x_column_name = db_column_name_strs[x]
+x_column_name = db_column_name_strs[x-1]
# get the table and column of y
y_table_name = db_table_name_strs[db_column_names_table_id[y]]
-y_column_name = db_column_name_strs[y]
+y_column_name = db_column_name_strs[y-1]
foreign_key_description_str = foreign_key_description(x_table_name, x_column_name, y_table_name, y_column_name)
descriptions.append(foreign_key_description_str)
return " ".join(descriptions)
18 changes: 8 additions & 10 deletions predict_lora.py
@@ -12,14 +12,16 @@

SQL_PROMPT_DICT = {
"prompt_input": (
"I want you to act as a SQL terminal in front of an example database,
you need to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
The instruction is {instruction}, So please tell me {input} Response:"
"I want you to act as a SQL terminal in front of an example database, \
you need only to return the sql command to me.Below is an instruction that describes a task, \
Write a response that appropriately completes the request. \
The instruction is {instruction}, So please tell me {input}, ###Response:"
),
"prompt_no_input": (
"I want you to act as a SQL terminal in front of an example database,
you need to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.
The instruction is {instruction}, Response:"
"I want you to act as a SQL terminal in front of an example database, \
you need only to return the sql command to me.Below is an instruction that describes a task, \
Write a response that appropriately completes the request. \
The instruction is {instruction}, ###Response:"
),
}
def extract_sql_dataset(example):
@@ -132,10 +132,6 @@ def predict():
# result.append(response.replace("\n", ""))
return result





if __name__ == "__main__":

result = predict()
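The reworked templates escape the line breaks with backslash continuations, ask only for the SQL command, and end in a ###Response: marker for the model to complete. A minimal sketch of how such a template might be selected and filled per example; the dict below is a lightly normalized copy, the sample record is invented, and the actual selection logic lives in extract_sql_dataset, which is not shown here:

# Lightly normalized copies of the templates above, for a standalone demo.
SQL_PROMPT_DICT = {
    "prompt_input": (
        "I want you to act as a SQL terminal in front of an example database, "
        "you need only to return the sql command to me. Below is an instruction that describes a task, "
        "Write a response that appropriately completes the request. "
        "The instruction is {instruction}, So please tell me {input}, ###Response:"
    ),
    "prompt_no_input": (
        "I want you to act as a SQL terminal in front of an example database, "
        "you need only to return the sql command to me. Below is an instruction that describes a task, "
        "Write a response that appropriately completes the request. "
        "The instruction is {instruction}, ###Response:"
    ),
}

# Invented example record; real records come from the Text-to-SQL dataset.
example = {
    "instruction": "The database schema is: singer(singer_id, name, age)",
    "input": "How many singers are there?",
}
key = "prompt_input" if example.get("input") else "prompt_no_input"
prompt = SQL_PROMPT_DICT[key].format(**example)
print(prompt)  # the model's SQL completion is expected after "###Response:"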
