Skip to content

Commit

Permalink
Merge pull request #33 from eosphoros-ai/generate
Browse files Browse the repository at this point in the history
Updated the predict method and modified the code
  • Loading branch information
csunny authored Aug 3, 2023
2 parents b2fcc3b + b80a7cd commit 0145be9
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 266 deletions.
6 changes: 4 additions & 2 deletions dbgpt_hub/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .data_args import DataArguments
from .gen_args import GenerationArguments
from .infer_args import ModelInferenceArguments
from .lora_args import LoraArguments
from .model_args import ModelArguments
from .quant_args import QuantArguments
from .train_args import TrainingArguments

# Public API of the configs package: one *Arguments dataclass per concern.
# (The pasted diff showed two overlapping __all__ lists — old and new — which
# is not valid Python; this is the merged, post-commit list including the new
# ModelInferenceArguments export.)
__all__ = [
    'DataArguments', 'GenerationArguments', 'ModelArguments',
    'TrainingArguments', 'ModelInferenceArguments', 'LoraArguments',
    'QuantArguments'
]
2 changes: 1 addition & 1 deletion dbgpt_hub/configs/gen_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class GenerationArguments:
default=None,
metadata={"help": "Minimum number of new tokens to generate."}
)

# Generation strategy
do_sample: Optional[bool] = field(default=False)
num_beams: Optional[int] = field(default=1)
Expand Down
28 changes: 28 additions & 0 deletions dbgpt_hub/configs/infer_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from dataclasses import dataclass, field
from typing import Optional
import os

model_path = os.path.join("./model", os.listdir("model")[1])

@dataclass
class ModelInferenceArguments:
cache_dir: Optional[str] = field(default=None)
model_name_or_path: Optional[str] = field(
default=model_path,
metadata={'help': 'Path to pre-trained model'})
model_max_length: int = field(
default=1024,
metadata={
'help':
'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
},
)
prompt_template: str = field(
default='default',
metadata={
'help':
'Prompt template name. Such as vanilla, alpaca, llama2, vicuna..., etc.'
})
source_prefix: Optional[str] = field(
default=None,
metadata={'help': 'Prefix to prepend to every source text.'})
2 changes: 1 addition & 1 deletion dbgpt_hub/configs/train_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@dataclass
class TrainingArguments(TrainingArguments):
class TrainingArguments(Seq2SeqTrainingArguments):
cache_dir: Optional[str] = field(default=None)
train_on_source: Optional[bool] = field(
default=False,
Expand Down
14 changes: 7 additions & 7 deletions dbgpt_hub/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

# Prompt templates for Text-to-SQL generation.  {instruction} carries the
# task description; {input} carries the extra context, when present.
# NOTE(review): the pasted diff interleaved old and new literal lines and the
# new ones were wrapped mid-string; a single space is assumed at each wrap
# point — confirm against the committed file.
SQL_PROMPT_DICT = {
    "prompt_input": (
        "I want you to act as a SQL terminal in front of an example database, "
        "you need to return the sql command to me.Below is an instruction that "
        "describes a task, Write a response that appropriately completes the "
        "request. The instruction is {instruction}, So please tell me {input} "
        "Response:"
    ),
    "prompt_no_input": (
        "I want you to act as a SQL terminal in front of an example database, "
        "you need to return the sql command to me.Below is an instruction that "
        "describes a task, Write a response that appropriately completes the "
        "request. The instruction is {instruction}, Response:"
    ),
}

Expand Down Expand Up @@ -226,4 +226,4 @@ def format_dataset(dataset, dataset_format):
eval_dataset=eval_dataset if args.do_eval else None,
predict_dataset=eval_dataset if args.do_predict else None,
data_collator=data_collator
)
)
Loading

0 comments on commit 0145be9

Please sign in to comment.