Skip to content

Commit

Permalink
Add support for UBJSON
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Oct 2, 2024
1 parent 265e2f4 commit 49a5f45
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 8 deletions.
14 changes: 8 additions & 6 deletions docs/model_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ instance_group [{ kind: KIND_AUTO }]
parameters [
{
key: "model_type"
value: { string_value: "xgboost_json" }
value: { string_value: "xgboost_ubj" }
},
{
key: "output_class"
Expand Down Expand Up @@ -185,23 +185,25 @@ Treelite's checkpoint format. For more information, see [Model
Support](model_support.md).

The `model_type` option is used to indicate which of these serialization
formats your model uses: `xgboost` for XGBoost binary, `xgboost_json` for
XGBoost JSON, `lightgbm` for LightGBM, or `treelite_checkpoint` for
Treelite:
formats your model uses: `xgboost_ubj` for XGBoost UBJSON [^1], `xgboost_json` for
XGBoost JSON, `xgboost` for XGBoost binary (legacy), `lightgbm` for LightGBM,
or `treelite_checkpoint` for Treelite:

```
parameters [
{
key: "model_type"
value: { string_value: "xgboost_json" }
value: { string_value: "xgboost_ubj" }
}
]
```
[^1] Default format in XGBoost 2.1+

#### Model Filenames
For each model type, Triton expects a particular default filename:
- `xgboost.model` for XGBoost Binary
- `xgboost.ubj` for XGBoost UBJSON [^1]
- `xgboost.json` for XGBoost JSON
- `xgboost.model` for XGBoost Binary (Legacy)
- `model.txt` for LightGBM
- `checkpoint.tl` for Treelite
It is recommended that you use these filenames, but custom filenames can be
Expand Down
8 changes: 7 additions & 1 deletion qa/L0_e2e/generate_example_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,19 @@ def generate_model(

def serialize_model(model, directory, output_format="xgboost"):
if output_format == "xgboost":
model_path = os.path.join(directory, "xgboost.model")
model_path = os.path.join(directory, "xgboost.deprecated")
model.save_model(model_path)
new_model_path = os.path.join(directory, "xgboost.model")
os.rename(model_path, new_model_path)
return model_path
if output_format == "xgboost_json":
model_path = os.path.join(directory, "xgboost.json")
model.save_model(model_path)
return model_path
if output_format == "xgboost_ubj":
model_path = os.path.join(directory, "xgboost.ubj")
model.save_model(model_path)
return model_path
if output_format == "lightgbm":
model_path = os.path.join(directory, "model.txt")
model.save_model(model_path)
Expand Down
13 changes: 13 additions & 0 deletions qa/generate_example_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ then
models+=( $name )
fi

name=xgboost_ubj
if [ $RETRAIN -ne 0 ] || [ ! -d "${MODEL_REPO}/${name}" ]
then
${GENERATOR_SCRIPT} \
--name $name \
--format xgboost_ubj \
--depth 7 \
--trees 500 \
--features 500 \
--predict_proba
models+=( $name )
fi

name=xgboost_shap
if [ $RETRAIN -ne 0 ] || [ ! -d "${MODEL_REPO}/${name}" ]
then
Expand Down
3 changes: 3 additions & 0 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ struct RapidsModel : rapids::Model<RapidsSharedState> {
case SerializationFormat::xgboost_json:
path /= "xgboost.json";
break;
case SerializationFormat::xgboost_ubj:
path /= "xgboost.ubj";
break;
case SerializationFormat::lightgbm:
path /= "model.txt";
break;
Expand Down
13 changes: 12 additions & 1 deletion src/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@

namespace triton { namespace backend { namespace NAMESPACE {

enum struct SerializationFormat { xgboost, xgboost_json, lightgbm, treelite };
enum struct SerializationFormat {
xgboost,
xgboost_json,
xgboost_ubj,
lightgbm,
treelite
};

inline auto
string_to_serialization(std::string const& type_string)
Expand All @@ -34,6 +40,8 @@ string_to_serialization(std::string const& type_string)
result = SerializationFormat::xgboost;
} else if (type_string == "xgboost_json") {
result = SerializationFormat::xgboost_json;
} else if (type_string == "xgboost_ubj") {
result = SerializationFormat::xgboost_ubj;
} else if (type_string == "lightgbm") {
result = SerializationFormat::lightgbm;
} else if (type_string == "treelite_checkpoint") {
Expand All @@ -60,6 +68,9 @@ serialization_to_string(SerializationFormat format)
case SerializationFormat::xgboost_json:
result = "xgboost_json";
break;
case SerializationFormat::xgboost_ubj:
result = "xgboost_ubj";
break;
case SerializationFormat::lightgbm:
result = "lightgbm";
break;
Expand Down
8 changes: 8 additions & 0 deletions src/tl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ load_tl_base_model(
model_file, config_str);
break;
}
case SerializationFormat::xgboost_ubj: {
auto config_str =
std::string("{\"allow_unknown_field\": ") +
std::string(xgboost_allow_unknown_field ? "true" : "false") + "}";
result = treelite::model_loader::LoadXGBoostModelUBJSON(
model_file, config_str);
break;
}
case SerializationFormat::lightgbm:
result = treelite::model_loader::LoadLightGBMModel(model_file);
break;
Expand Down

0 comments on commit 49a5f45

Please sign in to comment.