Merge pull request #46 from Hexastack/feat/break-apart-nlu-joint-model
feat: break NLU JISF apart
marrouchi authored Sep 23, 2024
2 parents 7115426 + 6183bf3 commit 1e4fdec
Showing 10 changed files with 340 additions and 152 deletions.
3 changes: 2 additions & 1 deletion docker/.env.example
@@ -45,7 +45,8 @@ AUTH_TOKEN=token123
 LANGUAGE_CLASSIFIER=language-classifier
 INTENT_CLASSIFIERS=en,fr
 TFLC_REPO_ID=Hexastack/tflc
-JISF_REPO_ID=Hexastack/jisf
+INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier
+SLOT_FILLER_REPO_ID=Hexastack/slot-filler
 NLP_PORT=5000
 
 # Frontend (Next.js)
3 changes: 2 additions & 1 deletion nlu/.env.dev
@@ -2,4 +2,5 @@ AUTH_TOKEN=123
 LANGUAGE_CLASSIFIER=language-classifier
 INTENT_CLASSIFIERS=ar,fr,tn
 TFLC_REPO_ID=Hexastack/tflc
-JISF_REPO_ID=Hexastack/jisf
+INTENT_CLASSIFIER_REPO_ID=Hexastack/intent-classifier
+SLOT_FILLER_REPO_ID=Hexastack/slot-filler
4 changes: 2 additions & 2 deletions nlu/.env.example
@@ -1,5 +1,5 @@
 AUTH_TOKEN=
 LANGUAGE_CLASSIFIER=
 INTENT_CLASSIFIERS=
-TFLC_REPO_ID=
-JISF_REPO_ID=
+INTENT_CLASSIFIER_REPO_ID=
+SLOT_FILLER_REPO_ID=
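
Across all three environment files the pattern is identical: the single `JISF_REPO_ID` Hugging Face repo reference gives way to one repo per model. A minimal sketch of how the split pair is consumed, mirroring the `os.getenv` lines in `nlu/main.py` further down (the fail-fast guard is an illustration, not part of the commit):

```python
import os

# One Hugging Face Hub repo per model replaces the former JISF_REPO_ID.
INTENT_CLASSIFIER_REPO_ID = os.getenv("INTENT_CLASSIFIER_REPO_ID")
SLOT_FILLER_REPO_ID = os.getenv("SLOT_FILLER_REPO_ID")

# Illustrative guard: surface an incomplete split configuration early.
if not (INTENT_CLASSIFIER_REPO_ID and SLOT_FILLER_REPO_ID):
    raise RuntimeError("Set INTENT_CLASSIFIER_REPO_ID and SLOT_FILLER_REPO_ID")
```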
4 changes: 2 additions & 2 deletions nlu/README.md
@@ -40,7 +40,7 @@ pip install -r requirements.txt
 You should run `source env.sh` on each new shell session. This activates the virtualenv and creates a nice alias for `run.py`:
 ```bash
 $ cat env.sh
-source env/bin/activate
+source venv/bin/activate
 alias run='python run.py'
 ```

@@ -53,7 +53,7 @@ run fit myexperiment1 mlp mnist --batch_size=32 --learning_rate=0.1
 Examples :
 ```bash
 # Intent classification
-run fit intent-classifier-en-30072024 jisf --intent_num_labels=88 --slot_num_labels=17 --language=en
+run fit intent-classifier-en-30072024 intent_classifier --intent_num_labels=88 --slot_num_labels=17 --language=en
 run predict intent-classifier-fr-30072024 --intent_num_labels=7 --slot_num_labels=2 --language=fr
 
 # Language classification
52 changes: 37 additions & 15 deletions nlu/data_loaders/jisfdl.py
@@ -4,8 +4,8 @@
 import numpy as np
 from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
 
+
 import boilerplate as tfbp
-from utils.jisf_data_mapper import JisfDataMapper
 from utils.json_helper import JsonHelper
 
@@ -101,8 +101,11 @@ def parse_dataset_intents(self, data):
         # Filter examples by language
         lang = self.hparams.language
         all_examples = data["common_examples"]
-        examples = filter(lambda exp: any(
-            e['entity'] == 'language' and e['value'] == lang for e in exp['entities']), all_examples)
+
+        if not bool(lang):
+            examples = all_examples
+        else:
+            examples = filter(lambda exp: any(e['entity'] == 'language' and e['value'] == lang for e in exp['entities']), all_examples)
 
         # Parse raw data
         for exp in examples:
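
The new guard makes the language filter optional: when no `language` hyperparameter is set, the loader keeps every example rather than (presumably) filtering them all out. A self-contained sketch of the same logic, with sample data invented for illustration:

```python
common_examples = [
    {"text": "hello", "entities": [{"entity": "language", "value": "en"}]},
    {"text": "bonjour", "entities": [{"entity": "language", "value": "fr"}]},
]

def filter_by_language(examples, lang):
    # An empty/None language now means "use everything".
    if not lang:
        return list(examples)
    return [exp for exp in examples
            if any(e["entity"] == "language" and e["value"] == lang
                   for e in exp["entities"])]

assert len(filter_by_language(common_examples, "")) == 2
assert [e["text"] for e in filter_by_language(common_examples, "fr")] == ["bonjour"]
```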
@@ -145,7 +148,6 @@ def _transform_dataset(self, dataset: List[JointRawData], tokenizer: Union[PreTr
         # the classifier.
         texts = [d.text for d in dataset]
         encoded_texts = self.encode_texts(texts, tokenizer)
-
         # Map intents, load from the model (evaluate), recompute from dataset otherwise (train)
         intents = [d.intent for d in dataset]
         if not model_params:
@@ -161,19 +163,35 @@ def _transform_dataset(self, dataset: List[JointRawData], tokenizer: Union[PreTr
             # To handle those we need to add <PAD> to slots_names. It can be some other symbol as well.
             slot_names.insert(0, "<PAD>")
         else:
-            intent_names = model_params.intent_names
-            slot_names = model_params.slot_names
+            if "intent_names" in model_params:
+                intent_names = model_params["intent_names"]
+            else:
+                intent_names = None
+
+            if "slot_names" in model_params:
+                slot_names = model_params["slot_names"]
+            else:
+                slot_names = None
 
-        intent_map = dict()  # Dict : intent -> index
-        for idx, ui in enumerate(intent_names):
-            intent_map[ui] = idx
+        if intent_names:
+            intent_map = dict()  # Dict : intent -> index
+            for idx, ui in enumerate(intent_names):
+                intent_map[ui] = idx
+        else:
+            intent_map = None
 
         # Encode intents
-        encoded_intents = self.encode_intents(intents, intent_map)
+        if intent_map:
+            encoded_intents = self.encode_intents(intents, intent_map)
+        else:
+            encoded_intents = None
 
-        slot_map: Dict[str, int] = dict()  # slot -> index
-        for idx, us in enumerate(slot_names):
-            slot_map[us] = idx
+        if slot_names:
+            slot_map: Dict[str, int] = dict()  # slot -> index
+            for idx, us in enumerate(slot_names):
+                slot_map[us] = idx
+        else:
+            slot_map = None
 
         # Encode slots
         # Text : Add a tune to my elrow Guest List
@@ -183,8 +201,12 @@ def _transform_dataset(self, dataset: List[JointRawData], tokenizer: Union[PreTr
         max_len = len(encoded_texts["input_ids"][0])  # type: ignore
         all_slots = [td.slots for td in dataset]
         all_texts = [td.text for td in dataset]
-        encoded_slots = self.encode_slots(tokenizer,
-                                          all_slots, all_texts, slot_map, max_len)
+
+        if slot_map:
+            encoded_slots = self.encode_slots(tokenizer,
+                                              all_slots, all_texts, slot_map, max_len)
+        else:
+            encoded_slots = None
 
         return encoded_texts, encoded_intents, encoded_slots, intent_names, slot_names
 
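The thrust of the `jisfdl.py` changes: a saved checkpoint may now carry only `intent_names` (intent classifier) or only `slot_names` (slot filler), so `_transform_dataset` returns `None` for whichever encodings it cannot build instead of assuming both label spaces exist. A condensed sketch of that contract (a simplified stand-in, not the actual method, which also encodes the texts through the tokenizer):

```python
from typing import Dict, List, Optional, Tuple

def build_label_maps(model_params: dict) -> Tuple[Optional[Dict[str, int]],
                                                  Optional[Dict[str, int]]]:
    # Each split model persists only the label space it owns; the joint
    # model used to guarantee that both keys were present.
    intent_names: Optional[List[str]] = model_params.get("intent_names")
    slot_names: Optional[List[str]] = model_params.get("slot_names")
    intent_map = {n: i for i, n in enumerate(intent_names)} if intent_names else None
    slot_map = {n: i for i, n in enumerate(slot_names)} if slot_names else None
    return intent_map, slot_map

# An intent classifier checkpoint has no slots, so slot_map is None.
imap, smap = build_label_maps({"intent_names": ["greet", "bye"]})
assert imap == {"greet": 0, "bye": 1} and smap is None
```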
2 changes: 1 addition & 1 deletion nlu/data_loaders/tflcdl.py
@@ -29,7 +29,7 @@ def __init__(self, method=None, save_dir=None, **hparams):
 
         self.json_helper = JsonHelper("tflc")
         self._save_dir = save_dir
-        print(hparams)
+
         # We will opt for a TF-IDF representation of the data as the frequency of word
         # roots should give us a good idea about which language we're dealing with.
         if method == "fit":
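The comment kept in `tflcdl.py` states the design choice: word-root frequencies are distinctive enough per language that a TF-IDF representation with a light classifier suffices. A toy illustration of the idea using scikit-learn (an assumed stand-in; the repository's actual vectorizer and classifier may differ):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

texts = ["hello how are you", "where is the station",
         "bonjour comment allez-vous", "où est la gare"]
labels = ["en", "en", "fr", "fr"]

# Character n-grams approximate word roots and are robust to morphology.
clf = make_pipeline(TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4)),
                    LogisticRegression())
clf.fit(texts, labels)
print(clf.predict(["comment ça va"]))  # expected: ['fr']
```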
43 changes: 30 additions & 13 deletions nlu/main.py
@@ -15,8 +15,8 @@
 
 AVAILABLE_LANGUAGES = os.getenv("AVAILABLE_LANGUAGES", "en,fr").split(',')
 TFLC_REPO_ID = os.getenv("TFLC_REPO_ID")
-JISF_REPO_ID = os.getenv("JISF_REPO_ID")
-
+INTENT_CLASSIFIER_REPO_ID = os.getenv("INTENT_CLASSIFIER_REPO_ID")
+SLOT_FILLER_REPO_ID = os.getenv("SLOT_FILLER_REPO_ID")
 
 def load_language_classifier():
     # Init language classifier model
@@ -27,21 +27,31 @@ def load_language_classifier():
     logging.info(f'Successfully loaded the language classifier model')
     return model
 
-
 def load_intent_classifiers():
-    Model = tfbp.get_model("jisf")
-    models = {}
+    Model = tfbp.get_model("intent_classifier")
+    intent_classifiers = {}
     for language in AVAILABLE_LANGUAGES:
         kwargs = {}
-        models[language] = Model(save_dir=language, method="predict", repo_id=JISF_REPO_ID, **kwargs)
-        models[language].load_model()
+        intent_classifiers[language] = Model(save_dir=language, method="predict", repo_id=INTENT_CLASSIFIER_REPO_ID, **kwargs)
+        intent_classifiers[language].load_model()
         logging.info(f'Successfully loaded the intent classifier {language} model')
-    return models
+    return intent_classifiers
+
+def load_slot_classifiers():
+    Model = tfbp.get_model("slot_classifier")
+    slot_fillers = {}
+    for language in AVAILABLE_LANGUAGES:
+        kwargs = {}
+        slot_fillers[language] = Model(save_dir=language, method="predict", repo_id=SLOT_FILLER_REPO_ID, **kwargs)
+        slot_fillers[language].load_model()
+        logging.info(f'Successfully loaded the slot filler {language} model')
+    return slot_fillers
 
 
 def load_models():
     app.language_classifier = load_language_classifier()  # type: ignore
     app.intent_classifiers = load_intent_classifiers()  # type: ignore
+    app.slot_fillers = load_slot_classifiers()  # type: ignore
 
 app = FastAPI()
 
@@ -74,13 +84,20 @@ async def check_health():
 
 @app.post("/parse")
 def parse(input: ParseInput, is_authenticated: Annotated[str, Depends(authenticate)]):
-    if not hasattr(app, 'language_classifier') or not hasattr(app, 'intent_classifiers'):
+    if not hasattr(app, 'language_classifier') or not hasattr(app, 'intent_classifiers') or not hasattr(app, 'slot_fillers'):
         headers = {"Retry-After": "120"}  # Suggest retrying after 2 minutes
-        return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content={"message": "Models are loading, please retry later."}, headers=headers)
+        return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content={"message": "Models are still loading, please retry later."}, headers=headers)
 
     language = app.language_classifier.get_prediction(input.q)  # type: ignore
     lang = language.get("value")
-    prediction = app.intent_classifiers[lang].get_prediction(
+    intent_prediction = app.intent_classifiers[lang].get_prediction(
+        input.q)  # type: ignore
+    slot_prediction = app.slot_fillers[lang].get_prediction(
         input.q)  # type: ignore
-    prediction.get("entities").append(language)
-    return prediction
+    slot_prediction.get("entities").append(language)
+
+    return {
+        "text": input.q,
+        "intent": intent_prediction.get("intent"),
+        "entities": slot_prediction.get("entities"),
+    }
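
With the split in place, `/parse` runs the detected language's intent classifier and slot filler side by side and merges their outputs into one response. A hedged usage sketch (host and port are placeholders taken from `NLP_PORT=5000` in `docker/.env.example`; authentication is omitted, since only the name of the `authenticate` dependency is visible here):

```python
import requests

resp = requests.post(
    "http://localhost:5000/parse",          # NLP_PORT from docker/.env.example
    json={"q": "Book me a table for two"},  # ParseInput body
)

if resp.status_code == 503:
    # Models are still loading; the service hints when to retry.
    print("Retry after", resp.headers.get("Retry-After"), "seconds")
else:
    data = resp.json()
    # {"text": ..., "intent": ..., "entities": [...]} — the detected language
    # is appended to the entities list before the response is built.
    print(data["intent"], data["entities"])
```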