diff --git a/cleanlab_studio/internal/api/api.py b/cleanlab_studio/internal/api/api.py
index d01abc7d..1b267897 100644
--- a/cleanlab_studio/internal/api/api.py
+++ b/cleanlab_studio/internal/api/api.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Callable, List, Optional, Tuple, Dict
+from typing import Callable, List, Optional, Tuple, Union, Any
 
 from cleanlab_studio.errors import APIError
 import requests
@@ -9,6 +9,13 @@
 import numpy as np
 import numpy.typing as npt
 
+try:
+    import pyspark.sql
+
+    pyspark_exists = True
+except ImportError:
+    pyspark_exists = False
+
 from cleanlab_studio.internal.types import JSONDict
 from cleanlab_studio.version import __version__
 
@@ -155,25 +162,44 @@ def get_label_column_of_project(api_key: str, project_id: str) -> str:
     return label_column
 
 
-def download_cleanlab_columns(api_key: str, cleanset_id: str, all: bool = False) -> pd.DataFrame:
+def download_cleanlab_columns(
+    api_key: str,
+    cleanset_id: str,
+    all: bool = True,
+    to_spark: bool = False,
+) -> Any:
     """
     Download all rows from specified Cleanlab columns
 
     :param api_key:
     :param cleanset_id:
     :param all: whether to download all Cleanlab columns or just the clean_label column
-    :return: return (rows, id_column)
+    :return: a pandas or pyspark DataFrame (typed as Any so that pyspark is not a required dependency)
     """
     res = requests.get(
-        cli_base_url + f"/cleansets/{cleanset_id}/columns?all={all}",
+        cli_base_url + f"/cleansets/{cleanset_id}/columns",
+        params=dict(to_spark=to_spark, all=all),
         headers=_construct_headers(api_key),
     )
     handle_api_error(res)
-    cleanset_json: str = res.json()["cleanset_json"]
-    cleanset_df: pd.DataFrame = pd.read_json(cleanset_json, orient="table")
     id_col = get_id_column(api_key, cleanset_id)
-    cleanset_df.rename(columns={"id": id_col}, inplace=True)
-    return cleanset_df
+    cleanset_json: str = res.json()["cleanset_json"]
+    if to_spark:
+        if not pyspark_exists:
+            raise ImportError(
+                "pyspark is not installed. Please install pyspark to download cleanlab columns as a pyspark DataFrame."
+            )
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.builder.getOrCreate()
+        rdd = spark.sparkContext.parallelize([cleanset_json])
+        cleanset_pyspark: pyspark.sql.DataFrame = spark.read.json(rdd)
+        cleanset_pyspark = cleanset_pyspark.withColumnRenamed("id", id_col)
+        return cleanset_pyspark
+
+    cleanset_pd: pd.DataFrame = pd.read_json(cleanset_json, orient="table")
+    cleanset_pd.rename(columns={"id": id_col}, inplace=True)
+    return cleanset_pd
 
 
 def download_numpy(api_key: str, cleanset_id: str, name: str) -> npt.NDArray[np.float_]:
diff --git a/cleanlab_studio/studio/studio.py b/cleanlab_studio/studio/studio.py
index e11747c5..98fe1389 100644
--- a/cleanlab_studio/studio/studio.py
+++ b/cleanlab_studio/studio/studio.py
@@ -1,19 +1,12 @@
 """
 Python API for Cleanlab Studio.
 """
-from typing import Any, List, Literal, Optional
+from typing import Any, List, Literal, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
 
-try:
-    import pyspark.sql
-
-    _pyspark_exists = True
-except ImportError:
-    _pyspark_exists = False
-
 from . import clean, upload
 from cleanlab_studio.internal.api import api
 from cleanlab_studio.internal.util import (
@@ -24,6 +17,10 @@
 from cleanlab_studio.internal.settings import CleanlabSettings
 from cleanlab_studio.internal.types import FieldSchemaDict
 
+_pyspark_exists = api.pyspark_exists
+if _pyspark_exists:
+    import pyspark.sql
+
 
 class Studio:
     """Used to interact with Cleanlab Studio
@@ -87,7 +84,8 @@ def download_cleanlab_columns(
         self,
         cleanset_id: str,
         include_action: bool = False,
-    ) -> pd.DataFrame:
+        to_spark: bool = False,
+    ) -> Any:
         """
         Downloads Cleanlab columns for a cleanset
 
@@ -96,11 +94,17 @@
             include_action: Whether to include a column with any actions taken on the cleanset in the downloaded columns
 
         Returns:
-            Dataframe of downloaded columns
+            A pandas or pyspark DataFrame of the downloaded columns
+            (typed as Any so that pyspark is not a required dependency)
         """
-        rows_df: pd.DataFrame = api.download_cleanlab_columns(self._api_key, cleanset_id, all=True)
+        rows_df = api.download_cleanlab_columns(
+            self._api_key, cleanset_id, all=True, to_spark=to_spark
+        )
         if not include_action:
-            rows_df.drop("action", inplace=True, axis=1)
+            if to_spark:
+                rows_df = rows_df.drop("action")
+            else:
+                rows_df.drop("action", inplace=True, axis=1)
         return rows_df
 
     def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool = False) -> Any:
@@ -118,26 +122,26 @@
         project_id = api.get_project_of_cleanset(self._api_key, cleanset_id)
         label_column = api.get_label_column_of_project(self._api_key, project_id)
         id_col = api.get_id_column(self._api_key, cleanset_id)
-        cl_cols = self.download_cleanlab_columns(cleanset_id, include_action=True)
         if _pyspark_exists and isinstance(dataset, pyspark.sql.DataFrame):
             from pyspark.sql.functions import udf
 
-            spark = dataset.sparkSession
-            cl_cols_df = spark.createDataFrame(cl_cols)
-            corrected_ds = dataset.alias("corrected_ds")
-            if id_col not in corrected_ds.columns:
+            cl_cols = self.download_cleanlab_columns(
+                cleanset_id, include_action=True, to_spark=True
+            )
+            corrected_ds_spark = dataset.alias("corrected_ds")
+            if id_col not in corrected_ds_spark.columns:
                 from pyspark.sql.functions import (
                     row_number,
                     monotonically_increasing_id,
                 )
                 from pyspark.sql.window import Window
 
-                corrected_ds = corrected_ds.withColumn(
+                corrected_ds_spark = corrected_ds_spark.withColumn(
                     id_col,
                     row_number().over(Window.orderBy(monotonically_increasing_id())) - 1,
                 )
-            both = cl_cols_df.select([id_col, "action", "clean_label"]).join(
-                corrected_ds.select([id_col, label_column]),
+            both = cl_cols.select([id_col, "action", "clean_label"]).join(
+                corrected_ds_spark.select([id_col, label_column]),
                 on=id_col,
                 how="left",
             )
@@ -154,12 +158,13 @@
                 [id_col, "action", "__cleanlab_final_label"]
             ).withColumnRenamed("__cleanlab_final_label", label_column)
             return (
-                corrected_ds.drop(label_column)
+                corrected_ds_spark.drop(label_column)
                 .join(new_labels, on=id_col, how="right")
                 .where(new_labels["action"] != "exclude")
                 .drop("action")
             )
         elif isinstance(dataset, pd.DataFrame):
+            cl_cols = self.download_cleanlab_columns(cleanset_id, include_action=True)
             joined_ds: pd.DataFrame
             if id_col in dataset.columns:
                 joined_ds = dataset.join(cl_cols.set_index(id_col), on=id_col)
@@ -170,7 +175,7 @@
                 dataset[label_column].to_numpy(),
             )
 
-            corrected_ds = dataset.copy()
+            corrected_ds: pd.DataFrame = dataset.copy()
             corrected_ds[label_column] = joined_ds["__cleanlab_final_label"]
             if not keep_excluded:
                 corrected_ds = corrected_ds.loc[(joined_ds["action"] != "exclude").fillna(True)]
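
For reviewers, a minimal usage sketch of the new to_spark path. It is not part of the patch: the API key, cleanset ID, and dataset path below are placeholders, and it assumes pyspark is installed and that cleanlab_studio exposes Studio at the package root.

from pyspark.sql import SparkSession

from cleanlab_studio import Studio

spark = SparkSession.builder.getOrCreate()
studio = Studio("<YOUR_API_KEY>")  # placeholder API key

# Download Cleanlab columns directly as a pyspark DataFrame instead of pandas.
cleanlab_cols = studio.download_cleanlab_columns("<CLEANSET_ID>", to_spark=True)
cleanlab_cols.show(5)

# apply_corrections on a pyspark dataset now fetches the Cleanlab columns as
# Spark internally instead of converting a pandas DataFrame on the driver.
dataset = spark.read.csv("<PATH_TO_DATASET>.csv", header=True)
corrected = studio.apply_corrections("<CLEANSET_ID>", dataset)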