diff --git a/app.py b/app.py
index c26a4f3d9..52144a767 100644
--- a/app.py
+++ b/app.py
@@ -15,13 +15,13 @@
 logger = logging.getLogger(__name__)


-def on_startup():
+def on_startup() -> None:
     initialize_logging()
     logger.info("STARTING APP")


 @app("/", on_startup=on_startup)
-async def serve(q: Q):
+async def serve(q: Q) -> None:
     """Serving function."""

     # Chat is still being streamed but user clicks on another button.
diff --git a/llm_studio/app_utils/config.py b/llm_studio/app_utils/config.py
index eeea1b218..0d6fb2554 100644
--- a/llm_studio/app_utils/config.py
+++ b/llm_studio/app_utils/config.py
@@ -83,7 +83,6 @@ def get_size(x):
         "validation_dataframe",
     ],
     "user_settings": {
-        "theme_dark": True,
         "credential_saver": ".env File",
         "default_aws_bucket_name": f"{os.getenv('AWS_BUCKET', 'bucket_name')}",
         "default_aws_access_key": os.getenv("AWS_ACCESS_KEY_ID", ""),
diff --git a/llm_studio/app_utils/handlers.py b/llm_studio/app_utils/handlers.py
index aeb829032..0a5f0e796 100644
--- a/llm_studio/app_utils/handlers.py
+++ b/llm_studio/app_utils/handlers.py
@@ -424,7 +424,7 @@ async def handle(q: Q) -> None:
         )


-async def experiment_delete_all_artifacts(q: Q, experiment_ids: List[int]):
+async def experiment_delete_all_artifacts(q: Q, experiment_ids: List[int]) -> None:
     await experiment_stop(q, experiment_ids)
     await experiment_delete(q, experiment_ids)
     await list_current_experiments(q)
diff --git a/llm_studio/app_utils/sections/chat.py b/llm_studio/app_utils/sections/chat.py
index d380699e5..4b64faad1 100644
--- a/llm_studio/app_utils/sections/chat.py
+++ b/llm_studio/app_utils/sections/chat.py
@@ -191,6 +191,7 @@ def gpu_is_blocked(q, gpu_id):
 def load_cfg_model_tokenizer(
     experiment_path: str, merge: bool = False, device: str = "cuda:0"
 ):
+    """Loads the model, tokenizer and configuration from the experiment path."""
     cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml"))
     cfg.architecture.pretrained = False
     cfg.architecture.gradient_checkpointing = False
diff --git a/llm_studio/app_utils/sections/chat_update.py b/llm_studio/app_utils/sections/chat_update.py
index 63111adac..833641359 100644
--- a/llm_studio/app_utils/sections/chat_update.py
+++ b/llm_studio/app_utils/sections/chat_update.py
@@ -267,7 +267,7 @@ async def show_stream_is_aborted_dialog(q):
     await q.page.save()


-async def is_app_blocked_while_streaming(q: Q):
+async def is_app_blocked_while_streaming(q: Q) -> bool:
     """
     Check whether the app is blocked with current answer generation.
     """
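Note: `is_app_blocked_while_streaming` is now typed as returning `bool`, which makes the intended call-site pattern explicit. A minimal sketch of such a guard, assuming a handler shaped like `serve` above (the early-return wiring is illustrative, not part of this diff):

```python
from h2o_wave import Q, app

from llm_studio.app_utils.handlers import handle
from llm_studio.app_utils.sections.chat_update import is_app_blocked_while_streaming


@app("/")
async def serve(q: Q) -> None:
    """Serving function."""
    # Chat is still being streamed but the user clicked another button:
    # skip regular dispatch until generation finishes or is aborted.
    if await is_app_blocked_while_streaming(q):
        return
    await handle(q)
```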
+css = """ +.hljs { + color: #abb2bf; +} + +.hljs-comment, +.hljs-quote { + color: #5c6370; + font-style: italic; +} + +.hljs-doctag, +.hljs-keyword, +.hljs-formula { + color: #c678dd; +} + +.hljs-section, +.hljs-name, +.hljs-selector-tag, +.hljs-deletion, +.hljs-subst { + color: #e06c75; +} + +.hljs-literal { + color: #56b6c2; +} + +.hljs-string, +.hljs-regexp, +.hljs-addition, +.hljs-attribute, +.hljs-meta .hljs-string { + color: #98c379; +} + +.hljs-attr, +.hljs-variable, +.hljs-template-variable, +.hljs-type, +.hljs-selector-class, +.hljs-selector-attr, +.hljs-selector-pseudo, +.hljs-number { + color: #d19a66; +} + +.hljs-symbol, +.hljs-bullet, +.hljs-link, +.hljs-meta, +.hljs-selector-id, +.hljs-title { + color: #61aeee; +} + +.hljs-built_in, +.hljs-title.class_, +.hljs-class .hljs-title { + color: #e6c07b; +} + +.hljs-emphasis { + font-style: italic; +} + +.hljs-strong { + font-weight: bold; +} + +.hljs-link { + text-decoration: underline; +} +""" + + async def meta(q: Q) -> None: if q.client["keep_meta"]: # Do not reset meta, keep current dialog opened q.client["keep_meta"] = False @@ -37,52 +115,12 @@ async def meta(q: Q) -> None: scripts=[ ui.script(source, asynchronous=True) for source in q.app["script_sources"] ], - stylesheet=ui.inline_stylesheet( - """ - .ms-MessageBar { - padding-top: 3px; - padding-bottom: 3px; - min-height: 18px; - } - div[data-test="nav_bar"] .ms-Nav-groupContent { - margin-bottom: 0; - } - - div[data-test="experiment/display/deployment/top_right"], - div[data-test="experiment/display/deployment/top_right"] - div[data-visible="true"]:last-child > div > div { - display: flex; - } - - div[data-test="experiment/display/deployment/top_right"] - div[data-visible="true"]:last-child, - div[data-test="experiment/display/deployment/top_right"] - div[data-visible="true"]:last-child > div { - display: flex; - flex-grow: 1; - } - - div[data-test="experiment/display/deployment/top_right"] - div[data-visible="true"]:last-child > div > div > div { - display: flex; - flex-grow: 1; - flex-direction: column; - } - - div[data-test="experiment/display/deployment/top_right"] - div[data-visible="true"]:last-child > div > div > div > div { - flex-grow: 1; - } - """ - ), + stylesheet=ui.inline_stylesheet(css), script=None, notification_bar=notification_bar, ) - if q.client.theme_dark: - q.page["meta"].theme = "h2o-dark" - else: - q.page["meta"].theme = "light" + q.page["meta"].theme = "h2o-dark" def heap_analytics( diff --git a/llm_studio/app_utils/sections/dataset.py b/llm_studio/app_utils/sections/dataset.py index dfc585bbe..899d85ea4 100644 --- a/llm_studio/app_utils/sections/dataset.py +++ b/llm_studio/app_utils/sections/dataset.py @@ -1100,7 +1100,7 @@ async def dataset_delete_single(q: Q, dataset_id: int): async def dataset_display(q: Q) -> None: """Display a selected dataset.""" - dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[ + dataset_id: int = q.client["dataset/list/df_datasets"]["id"].iloc[ q.client["dataset/display/id"] ] dataset: Dataset = q.client.app_db.get_dataset(dataset_id) @@ -1196,7 +1196,7 @@ async def show_data_tab(q, cfg, filename: str): q.client.delete_cards.add("dataset/display/data") -async def show_visualization_tab(q, cfg): +async def show_visualization_tab(q: Q, cfg): try: plot = cfg.logging.plots_class.plot_data(cfg) except Exception as error: @@ -1236,7 +1236,7 @@ async def show_visualization_tab(q, cfg): q.client.delete_cards.add("dataset/display/visualization") -async def show_summary_tab(q, dataset_id): +async def 
diff --git a/llm_studio/app_utils/sections/dataset.py b/llm_studio/app_utils/sections/dataset.py
index dfc585bbe..899d85ea4 100644
--- a/llm_studio/app_utils/sections/dataset.py
+++ b/llm_studio/app_utils/sections/dataset.py
@@ -1100,7 +1100,7 @@ async def dataset_delete_single(q: Q, dataset_id: int):

 async def dataset_display(q: Q) -> None:
     """Display a selected dataset."""
-    dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[
+    dataset_id: int = q.client["dataset/list/df_datasets"]["id"].iloc[
         q.client["dataset/display/id"]
     ]
     dataset: Dataset = q.client.app_db.get_dataset(dataset_id)
@@ -1196,7 +1196,7 @@ async def show_data_tab(q, cfg, filename: str):
     q.client.delete_cards.add("dataset/display/data")


-async def show_visualization_tab(q, cfg):
+async def show_visualization_tab(q: Q, cfg):
     try:
         plot = cfg.logging.plots_class.plot_data(cfg)
     except Exception as error:
@@ -1236,7 +1236,7 @@
     q.client.delete_cards.add("dataset/display/visualization")


-async def show_summary_tab(q, dataset_id):
+async def show_summary_tab(q: Q, dataset_id: int) -> None:
     dataset_df = get_datasets(q)
     dataset_df = dataset_df[dataset_df.id == dataset_id]
     stat_list_items: List[StatListItem] = []
@@ -1253,7 +1253,9 @@ async def show_summary_tab(q, dataset_id):
     q.client.delete_cards.add("dataset/display/summary")


-async def show_statistics_tab(q, dataset_filename, config_filename):
+async def show_statistics_tab(
+    q: Q, dataset_filename: str, config_filename: str
+) -> None:
     cfg_hash = hashlib.md5(open(config_filename, "rb").read()).hexdigest()
     stats_dict = compute_dataset_statistics(dataset_filename, config_filename, cfg_hash)

@@ -1320,7 +1322,7 @@


 @functools.lru_cache()
-def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str):
+def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str) -> dict:
     """
     Compute various statistics for a dataset.
     - text length distribution for prompts and answers
@@ -1362,7 +1364,7 @@ def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str):
     return stats_dict


-async def dataset_import_uploaded_file(q: Q):
+async def dataset_import_uploaded_file(q: Q) -> None:
     local_path = await q.site.download(
         q.args["dataset/import/local_upload"][0], f"{get_data_dir(q)}/"
     )
@@ -1378,7 +1380,7 @@
     await dataset_import(q, step=1, error=error)


-async def dataset_delete_current_datasets(q: Q):
+async def dataset_delete_current_datasets(q: Q) -> None:
     dataset_ids = list(
         q.client["dataset/list/df_datasets"]["id"].iloc[
             list(map(int, q.client["dataset/list/table"]))
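Note: `compute_dataset_statistics` is wrapped in `functools.lru_cache()`, so the `cfg_hash` argument exists only to key the cache; when the config file's bytes change, the digest changes and forces a recompute. A standalone sketch of the pattern (hypothetical names, not this module's code):

```python
import functools
import hashlib


@functools.lru_cache()
def expensive_stats(path: str, content_hash: str) -> dict:
    # content_hash is never read in the body; it only invalidates
    # the cached entry whenever the underlying file changes.
    return {"path": path}


def stats_for(path: str) -> dict:
    with open(path, "rb") as f:
        digest = hashlib.md5(f.read()).hexdigest()
    return expensive_stats(path, digest)
```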
diff --git a/llm_studio/app_utils/sections/experiment.py b/llm_studio/app_utils/sections/experiment.py
index 76db8da41..d76816849 100644
--- a/llm_studio/app_utils/sections/experiment.py
+++ b/llm_studio/app_utils/sections/experiment.py
@@ -1421,6 +1421,7 @@ async def summary_tab(experiment_id, q):
         box=ui.box(zone="third"),
         title="",
         content=content,
+        compact=False,
     )
     q.client.delete_cards.add(card_name)

@@ -1495,7 +1496,7 @@ def unite_validation_metric_charts(charts_list):
     return charts_list


-async def charts_tab(q, charts_list, legend_labels):
+async def charts_tab(q, charts_list, legend_labels) -> None:
     charts_list = unite_validation_metric_charts(charts_list)

     box = ["first", "first", "second", "second"]
diff --git a/llm_studio/app_utils/sections/settings.py b/llm_studio/app_utils/sections/settings.py
index df9d5db5c..4a1fa2d2a 100644
--- a/llm_studio/app_utils/sections/settings.py
+++ b/llm_studio/app_utils/sections/settings.py
@@ -54,17 +54,6 @@ async def settings(q: Q) -> None:
                 """,
             ),
             ui.separator("Appearance"),
-            ui.inline(
-                items=[
-                    ui.label("Dark Mode", width=label_width),
-                    ui.toggle(
-                        name="theme_dark",
-                        value=q.client["theme_dark"],
-                        tooltip="Enables Dark Mode as theme.",
-                        trigger=True,
-                    ),
-                ]
-            ),
             ui.inline(
                 items=[
                     ui.label("Delete Dialogs", width=label_width),
diff --git a/llm_studio/app_utils/wave_utils.py b/llm_studio/app_utils/wave_utils.py
index e23d7e807..4b10d6ee9 100644
--- a/llm_studio/app_utils/wave_utils.py
+++ b/llm_studio/app_utils/wave_utils.py
@@ -13,20 +13,14 @@


 class ThemeColors(TypedDict):
-    light: dict
-    dark: dict
+    primary: str
+    background_color: str


 class WaveTheme:
     _theme_colors: ThemeColors = {
-        "light": {
-            "primary": "#000000",
-            "background_color": "#ffffff",
-        },
-        "dark": {
-            "primary": "#FEC925",
-            "background_color": "#121212",
-        },
+        "primary": "#FEC925",
+        "background_color": "#121212",
     }

     states = {
@@ -43,20 +37,12 @@ class WaveTheme:
     def __repr__(self) -> str:
         return "WaveTheme"

-    def get_value_by_key(self, q: Q, key: str):
-        value = (
-            self._theme_colors["dark"][key]
-            if q.client.theme_dark
-            else self._theme_colors["light"][key]
-        )
-        return value
-
     def get_primary_color(self, q: Q):
-        primary_color = self.get_value_by_key(q, "primary")
+        primary_color = self._theme_colors["primary"]
         return primary_color

     def get_background_color(self, q: Q):
-        background_color = self.get_value_by_key(q, "background_color")
+        background_color = self._theme_colors["background_color"]
         return background_color

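Note: with the light palette removed, `ThemeColors` flattens to a single mapping and the accessors no longer consult `q`. A self-contained sketch of the resulting shape (the unused `q` parameter is made optional here for brevity, which is an assumption; the diff keeps it as `q: Q`):

```python
from typing import TypedDict


class ThemeColors(TypedDict):
    primary: str
    background_color: str


class WaveTheme:
    _theme_colors: ThemeColors = {
        "primary": "#FEC925",
        "background_color": "#121212",
    }

    def get_primary_color(self, q=None):
        # q is ignored; every client now gets the dark palette.
        return self._theme_colors["primary"]


assert WaveTheme().get_primary_color() == "#FEC925"
```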
diff --git a/llm_studio/python_configs/base.py b/llm_studio/python_configs/base.py
index d0350cb33..e1a5af575 100644
--- a/llm_studio/python_configs/base.py
+++ b/llm_studio/python_configs/base.py
@@ -204,6 +204,7 @@ class DefaultConfigProblemBase(DefaultConfig):
     experiment_name: str
     output_directory: str
     llm_backbone: str
+    _parent_experiment: str

     dataset: Any
     tokenizer: Any
diff --git a/llm_studio/src/order.py b/llm_studio/src/order.py
index 3f7a4ef14..2bd1e471b 100644
--- a/llm_studio/src/order.py
+++ b/llm_studio/src/order.py
@@ -20,12 +20,12 @@ def __init__(self, keys: Optional[List[str]] = None):
         else:
             self._list = list()

-    def _unique_guard(self, *keys: str):
+    def _unique_guard(self, *keys: str) -> None:
         for key in keys:
             if key in self._list:
                 raise ValueError(f"`{key}` is already in the list!")

-    def append(self, key: str):
+    def append(self, key: str) -> None:
         """
         Append a key at the end of the list:

diff --git a/llm_studio/src/plots/text_causal_language_modeling_plots.py b/llm_studio/src/plots/text_causal_language_modeling_plots.py
index a2aeda47b..c743c4cc7 100644
--- a/llm_studio/src/plots/text_causal_language_modeling_plots.py
+++ b/llm_studio/src/plots/text_causal_language_modeling_plots.py
@@ -7,11 +7,7 @@
 from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
 from llm_studio.src.datasets.text_utils import get_tokenizer
 from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
-from llm_studio.src.utils.plot_utils import (
-    PlotData,
-    format_for_markdown_visualization,
-    list_to_markdown_representation,
-)
+from llm_studio.src.utils.plot_utils import PlotData, list_to_markdown_representation


 class Plots:
@@ -90,10 +86,6 @@ def plot_data(cls, cfg) -> PlotData:
                 ]
             i += 1

-        df_transposed["Content"] = df_transposed["Content"].apply(
-            format_for_markdown_visualization
-        )
-
         df_transposed.to_parquet(path)

         return PlotData(path, encoding="df")
@@ -153,11 +145,6 @@ def plot_validation_predictions(
             "Predicted Text": predicted_texts,
         }
     )
-    df[input_text_column_name] = df[input_text_column_name].apply(
-        format_for_markdown_visualization
-    )
-    df["Target Text"] = df["Target Text"].apply(format_for_markdown_visualization)
-    df["Predicted Text"] = df["Predicted Text"].apply(format_for_markdown_visualization)

     if val_outputs.get("metrics") is not None:
         metric_column_name = f"Metric ({cfg.prediction.metric})"
@@ -195,7 +182,7 @@ def create_batch_prediction_df(
             ]
         }
     )
-    df["Prompt Text"] = df["Prompt Text"].apply(format_for_markdown_visualization)
+
     if labels_column in batch.keys():
         df["Answer Text"] = [
             tokenizer.decode(
diff --git a/llm_studio/src/plots/text_dpo_modeling_plots.py b/llm_studio/src/plots/text_dpo_modeling_plots.py
index 41cee6c78..0b464e946 100644
--- a/llm_studio/src/plots/text_dpo_modeling_plots.py
+++ b/llm_studio/src/plots/text_dpo_modeling_plots.py
@@ -11,7 +11,7 @@
     plot_validation_predictions,
 )
 from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
-from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
+from llm_studio.src.utils.plot_utils import PlotData
 from llm_studio.src.utils.utils import PatchedAttribute


@@ -129,9 +129,6 @@ def plot_data(cls, cfg) -> PlotData:
             ]
             i += 1

-        df_transposed["Content"] = df_transposed["Content"].apply(
-            format_for_markdown_visualization
-        )
         df_transposed.to_parquet(path)

         return PlotData(path, encoding="df")
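Note: both plot modules now persist the transposed frame without the markdown pass. A minimal sketch of the simplified path (the DataFrame content is illustrative; how the UI escapes or highlights the raw text afterwards is an assumption):

```python
import pandas as pd

# Illustrative stand-in for df_transposed in the plot_data() methods above.
df_transposed = pd.DataFrame({"Content": ["Hello\nWorld", "```print(1)```"]})

# Cells are written as-is instead of being piped through
# format_for_markdown_visualization(); rendering happens downstream.
df_transposed.to_parquet("plot_data.parquet")
```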
diff --git a/llm_studio/src/utils/plot_utils.py b/llm_studio/src/utils/plot_utils.py
index b2d3bf1e0..f83989867 100644
--- a/llm_studio/src/utils/plot_utils.py
+++ b/llm_studio/src/utils/plot_utils.py
@@ -1,5 +1,4 @@
 import html
-import re
 from dataclasses import dataclass
 from typing import List

@@ -66,31 +65,6 @@ def decode_bytes(chunks: List[bytes]):
     return decoded_tokens


-def format_for_markdown_visualization(text: str) -> str:
-    """
-    Convert newlines to <br /> tags, except for those inside code blocks.
-    This is needed because the markdown_table_cell_type() function does not
-    convert newlines to <br /> tags, so we have to do it ourselves.
-
-    This function is rather simple and may fail on text that uses `
-    in some other context than marking code cells or uses ` within
-    the code itself (as this function).
-    """
-    code_block_regex = r"(```.*?```|``.*?``)"
-    parts = re.split(code_block_regex, text, flags=re.DOTALL)
-    for i in range(len(parts)):
-        # Only substitute for text outside matched code blocks
-        if "`" not in parts[i]:
-            parts[i] = parts[i].replace("\n", "<br />").strip()
-    text = "".join(parts)
-
-    # Restore newlines around code blocks, needed for correct rendering
-    for x in ["```", "``", "`"]:
-        text = text.replace(f"<br />{x}", f"\n{x}")
-        text = text.replace(f"{x}<br />", f"{x}\n")
-    return html.escape(text.replace("<br />", "\n"))
-
-
 def list_to_markdown_representation(
     tokens: List[str], masks: List[bool], pad_token: int, num_chars: int = 65
 ):