
Commit

Merge branch 'main' into feat/example-docstrings
axl1313 committed Jul 17, 2023
2 parents 06b3e66 + a00820b commit ad01269
Showing 2 changed files with 61 additions and 30 deletions.
42 changes: 34 additions & 8 deletions cleanlab_studio/internal/api/api.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Callable, List, Optional, Tuple, Dict
+from typing import Callable, List, Optional, Tuple, Union, Any
 from cleanlab_studio.errors import APIError
 
 import requests
@@ -9,6 +9,13 @@
 import numpy as np
 import numpy.typing as npt
 
+try:
+    import pyspark.sql
+
+    pyspark_exists = True
+except ImportError:
+    pyspark_exists = False
+
 from cleanlab_studio.internal.types import JSONDict
 from cleanlab_studio.version import __version__

@@ -155,25 +162,44 @@ def get_label_column_of_project(api_key: str, project_id: str) -> str:
     return label_column
 
 
-def download_cleanlab_columns(api_key: str, cleanset_id: str, all: bool = False) -> pd.DataFrame:
+def download_cleanlab_columns(
+    api_key: str,
+    cleanset_id: str,
+    all: bool = True,
+    to_spark: bool = False,
+) -> Any:
     """
     Download all rows from specified Cleanlab columns
     :param api_key:
     :param cleanset_id:
     :param all: whether to download all Cleanlab columns or just the clean_label column
-    :return: return (rows, id_column)
+    :return: a pandas or pyspark DataFrame; typed as Any so that pyspark is not a required dependency
     """
     res = requests.get(
-        cli_base_url + f"/cleansets/{cleanset_id}/columns?all={all}",
+        cli_base_url + f"/cleansets/{cleanset_id}/columns",
+        params=dict(to_spark=to_spark, all=all),
         headers=_construct_headers(api_key),
     )
     handle_api_error(res)
-    cleanset_json: str = res.json()["cleanset_json"]
-    cleanset_df: pd.DataFrame = pd.read_json(cleanset_json, orient="table")
     id_col = get_id_column(api_key, cleanset_id)
-    cleanset_df.rename(columns={"id": id_col}, inplace=True)
-    return cleanset_df
+    cleanset_json: str = res.json()["cleanset_json"]
+    if to_spark:
+        if not pyspark_exists:
+            raise ImportError(
+                "pyspark is not installed. Please install pyspark to download cleanlab columns as a pyspark DataFrame."
+            )
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.builder.getOrCreate()
+        rdd = spark.sparkContext.parallelize([cleanset_json])
+        cleanset_pyspark: pyspark.sql.DataFrame = spark.read.json(rdd)
+        cleanset_pyspark = cleanset_pyspark.withColumnRenamed("id", id_col)
+        return cleanset_pyspark
+
+    cleanset_pd: pd.DataFrame = pd.read_json(cleanset_json, orient="table")
+    cleanset_pd.rename(columns={"id": id_col}, inplace=True)
+    return cleanset_pd
 
 
 def download_numpy(api_key: str, cleanset_id: str, name: str) -> npt.NDArray[np.float_]:
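
A minimal usage sketch of the updated function (the API key and cleanset ID are placeholders; to_spark=True assumes pyspark is installed):

    from cleanlab_studio.internal.api import api

    api_key = "<your-api-key>"          # placeholder
    cleanset_id = "<your-cleanset-id>"  # placeholder

    # Default: every Cleanlab column as a pandas DataFrame.
    cols_pd = api.download_cleanlab_columns(api_key, cleanset_id)

    # Same columns as a pyspark DataFrame; raises ImportError if pyspark is missing.
    cols_spark = api.download_cleanlab_columns(api_key, cleanset_id, to_spark=True)
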
49 changes: 27 additions & 22 deletions cleanlab_studio/studio/studio.py
@@ -1,19 +1,12 @@
"""
Python API for Cleanlab Studio.
"""
from typing import Any, List, Literal, Optional
from typing import Any, List, Literal, Optional, Union

import numpy as np
import numpy.typing as npt
import pandas as pd

try:
import pyspark.sql

_pyspark_exists = True
except ImportError:
_pyspark_exists = False

from . import clean, upload
from cleanlab_studio.internal.api import api
from cleanlab_studio.internal.util import (
@@ -24,6 +17,10 @@
 from cleanlab_studio.internal.settings import CleanlabSettings
 from cleanlab_studio.internal.types import FieldSchemaDict
 
+_pyspark_exists = api.pyspark_exists
+if _pyspark_exists:
+    import pyspark.sql
+
 
 class Studio:
     """Used to interact with Cleanlab Studio
@@ -87,7 +84,8 @@ def download_cleanlab_columns(
         self,
         cleanset_id: str,
         include_action: bool = False,
-    ) -> pd.DataFrame:
+        to_spark: bool = False,
+    ) -> Any:
         """
         Downloads Cleanlab columns for a cleanset
@@ -96,11 +94,17 @@
             include_action: Whether to include a column with any actions taken on the cleanset in the downloaded columns
         Returns:
-            Dataframe of downloaded columns
+            A pandas or pyspark DataFrame
+            Typed as Any so that pyspark is not a required dependency
         """
-        rows_df: pd.DataFrame = api.download_cleanlab_columns(self._api_key, cleanset_id, all=True)
+        rows_df = api.download_cleanlab_columns(
+            self._api_key, cleanset_id, all=True, to_spark=to_spark
+        )
         if not include_action:
-            rows_df.drop("action", inplace=True, axis=1)
+            if to_spark:
+                rows_df = rows_df.drop("action")
+            else:
+                rows_df.drop("action", inplace=True, axis=1)
         return rows_df
 
     def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool = False) -> Any:
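
At the Studio level the flag is simply forwarded; a short sketch, assuming the package's usual public entry point and placeholder credentials:

    from cleanlab_studio import Studio

    studio = Studio("<your-api-key>")  # placeholder key

    # pandas DataFrame with the "action" column dropped (the default)
    cols = studio.download_cleanlab_columns("<your-cleanset-id>")

    # pyspark DataFrame, keeping the "action" column
    cols_spark = studio.download_cleanlab_columns(
        "<your-cleanset-id>", include_action=True, to_spark=True
    )
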
@@ -118,26 +122,26 @@ def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool
         project_id = api.get_project_of_cleanset(self._api_key, cleanset_id)
         label_column = api.get_label_column_of_project(self._api_key, project_id)
         id_col = api.get_id_column(self._api_key, cleanset_id)
-        cl_cols = self.download_cleanlab_columns(cleanset_id, include_action=True)
         if _pyspark_exists and isinstance(dataset, pyspark.sql.DataFrame):
             from pyspark.sql.functions import udf
 
             spark = dataset.sparkSession
-            cl_cols_df = spark.createDataFrame(cl_cols)
-            corrected_ds = dataset.alias("corrected_ds")
-            if id_col not in corrected_ds.columns:
+            cl_cols = self.download_cleanlab_columns(
+                cleanset_id, include_action=True, to_spark=True
+            )
+            corrected_ds_spark = dataset.alias("corrected_ds")
+            if id_col not in corrected_ds_spark.columns:
                 from pyspark.sql.functions import (
                     row_number,
                     monotonically_increasing_id,
                 )
                 from pyspark.sql.window import Window
 
-                corrected_ds = corrected_ds.withColumn(
+                corrected_ds_spark = corrected_ds_spark.withColumn(
                     id_col,
                     row_number().over(Window.orderBy(monotonically_increasing_id())) - 1,
                 )
-            both = cl_cols_df.select([id_col, "action", "clean_label"]).join(
-                corrected_ds.select([id_col, label_column]),
+            both = cl_cols.select([id_col, "action", "clean_label"]).join(
+                corrected_ds_spark.select([id_col, label_column]),
                 on=id_col,
                 how="left",
             )
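
The id-column fallback above is a common pyspark pattern: monotonically_increasing_id() produces increasing but non-consecutive values, so ranking them with row_number() (which is 1-based, hence the trailing - 1) yields consecutive ids starting at 0. A self-contained sketch:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import monotonically_increasing_id, row_number
    from pyspark.sql.window import Window

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("a",), ("b",), ("c",)], ["label"])

    # Rank the non-consecutive ids to obtain consecutive 0-based row ids.
    df = df.withColumn(
        "id", row_number().over(Window.orderBy(monotonically_increasing_id())) - 1
    )
    df.show()  # rows receive ids 0, 1, 2
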
@@ -154,12 +158,13 @@ def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool
                 [id_col, "action", "__cleanlab_final_label"]
             ).withColumnRenamed("__cleanlab_final_label", label_column)
             return (
-                corrected_ds.drop(label_column)
+                corrected_ds_spark.drop(label_column)
                 .join(new_labels, on=id_col, how="right")
                 .where(new_labels["action"] != "exclude")
                 .drop("action")
             )
         elif isinstance(dataset, pd.DataFrame):
+            cl_cols = self.download_cleanlab_columns(cleanset_id, include_action=True)
             joined_ds: pd.DataFrame
             if id_col in dataset.columns:
                 joined_ds = dataset.join(cl_cols.set_index(id_col), on=id_col)
@@ -170,7 +175,7 @@ def apply_corrections(self, cleanset_id: str, dataset: Any, keep_excluded: bool
                 dataset[label_column].to_numpy(),
             )
 
-            corrected_ds = dataset.copy()
+            corrected_ds: pd.DataFrame = dataset.copy()
             corrected_ds[label_column] = joined_ds["__cleanlab_final_label"]
             if not keep_excluded:
                 corrected_ds = corrected_ds.loc[(joined_ds["action"] != "exclude").fillna(True)]
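
Taken together, the two files let apply_corrections accept either a pandas or a pyspark dataset; an end-to-end sketch with placeholder paths and credentials:

    from pyspark.sql import SparkSession
    from cleanlab_studio import Studio

    spark = SparkSession.builder.getOrCreate()
    dataset = spark.read.csv("<your-dataset.csv>", header=True)  # placeholder path

    studio = Studio("<your-api-key>")  # placeholder key
    corrected = studio.apply_corrections("<your-cleanset-id>", dataset)
    corrected.show()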
