Skip to content

Commit

Permalink
lazy import sklearn
Browse files Browse the repository at this point in the history
  • Loading branch information
huiwengoh committed Sep 25, 2024
1 parent ea998e0 commit 905c629
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
5 changes: 5 additions & 0 deletions cleanlab_studio/studio/trustworthy_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,11 @@ class TLMOptions(TypedDict):
log (List[str], default = []): optionally specify additional logs or metadata to return.
For instance, include "explanation" here to get explanations of why a response is scored with low trustworthiness.
custom_eval_criteria (List[Dict[str, Any]], default = []): optionally specify custom evalution criteria.
The expected input format is a list of dictionaries, where each dictionary has the following keys:
- name: name of the evaluation criteria
- criteria: the instruction for the evaluation criteria
"""

model: NotRequired[str]
Expand Down
20 changes: 17 additions & 3 deletions cleanlab_studio/utils/tlm_calibrated.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

from cleanlab_studio.errors import ValidationError, TlmNotCalibratedError
from cleanlab_studio.internal.types import TLMQualityPreset
Expand All @@ -36,6 +33,14 @@ def __init__(
Use `Studio.TLMCalibrated()` instead of this method to initialize a TLMCalibrated object.
lazydocs: ignore
"""
try:
from sklearn.ensemble import RandomForestRegressor
except ImportError:
raise ImportError(
"Cannot import scikit-learn which is required to use TLMCalibrated. "
"Please install it using `pip install scikit-learn` and try again."
)

self._api_key = api_key

if quality_preset not in {"base", "low", "medium"}:
Expand Down Expand Up @@ -108,6 +113,15 @@ def get_trustworthiness_score(
Similar to [`TLM.get_trustworthiness_score()`](../trustworthy_language_model/#method-get_trustworthiness_score),
view documentation there for expected input arguments and outputs.
"""
try:
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted
except ImportError:
raise ImportError(
"Cannot import scikit-learn which is required to use TLMCalibrated. "
"Please install it using `pip install scikit-learn` and try again."
)

try:
check_is_fitted(self._rf_model)
except NotFittedError:
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
"openpyxl>=3.0.0,!=3.1.0",
"validators>=0.20.0",
"matplotlib>=3.4.0",
"scikit-learn",
],
entry_points="""
[console_scripts]
Expand Down

0 comments on commit 905c629

Please sign in to comment.