diff --git a/app.py b/app.py
index c26a4f3d9..52144a767 100644
--- a/app.py
+++ b/app.py
@@ -15,13 +15,13 @@
logger = logging.getLogger(__name__)
-def on_startup():
+def on_startup() -> None:
initialize_logging()
logger.info("STARTING APP")
@app("/", on_startup=on_startup)
-async def serve(q: Q):
+async def serve(q: Q) -> None:
"""Serving function."""
# Chat is still being streamed but user clicks on another button.
diff --git a/llm_studio/app_utils/config.py b/llm_studio/app_utils/config.py
index eeea1b218..0d6fb2554 100644
--- a/llm_studio/app_utils/config.py
+++ b/llm_studio/app_utils/config.py
@@ -83,7 +83,6 @@ def get_size(x):
"validation_dataframe",
],
"user_settings": {
- "theme_dark": True,
"credential_saver": ".env File",
"default_aws_bucket_name": f"{os.getenv('AWS_BUCKET', 'bucket_name')}",
"default_aws_access_key": os.getenv("AWS_ACCESS_KEY_ID", ""),
diff --git a/llm_studio/app_utils/handlers.py b/llm_studio/app_utils/handlers.py
index aeb829032..0a5f0e796 100644
--- a/llm_studio/app_utils/handlers.py
+++ b/llm_studio/app_utils/handlers.py
@@ -424,7 +424,7 @@ async def handle(q: Q) -> None:
)
-async def experiment_delete_all_artifacts(q: Q, experiment_ids: List[int]):
+async def experiment_delete_all_artifacts(q: Q, experiment_ids: List[int]) -> None:
await experiment_stop(q, experiment_ids)
await experiment_delete(q, experiment_ids)
await list_current_experiments(q)
diff --git a/llm_studio/app_utils/sections/chat.py b/llm_studio/app_utils/sections/chat.py
index d380699e5..4b64faad1 100644
--- a/llm_studio/app_utils/sections/chat.py
+++ b/llm_studio/app_utils/sections/chat.py
@@ -191,6 +191,7 @@ def gpu_is_blocked(q, gpu_id):
def load_cfg_model_tokenizer(
experiment_path: str, merge: bool = False, device: str = "cuda:0"
):
+ """Loads the model, tokenizer and configuration from the experiment path."""
cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml"))
cfg.architecture.pretrained = False
cfg.architecture.gradient_checkpointing = False
diff --git a/llm_studio/app_utils/sections/chat_update.py b/llm_studio/app_utils/sections/chat_update.py
index 63111adac..833641359 100644
--- a/llm_studio/app_utils/sections/chat_update.py
+++ b/llm_studio/app_utils/sections/chat_update.py
@@ -267,7 +267,7 @@ async def show_stream_is_aborted_dialog(q):
await q.page.save()
-async def is_app_blocked_while_streaming(q: Q):
+async def is_app_blocked_while_streaming(q: Q) -> bool:
"""
Check whether the app is blocked with current answer generation.
"""
diff --git a/llm_studio/app_utils/sections/common.py b/llm_studio/app_utils/sections/common.py
index 634a58343..2ea9ca9d0 100644
--- a/llm_studio/app_utils/sections/common.py
+++ b/llm_studio/app_utils/sections/common.py
@@ -10,6 +10,84 @@
logger = logging.getLogger(__name__)
+# https://github.com/highlightjs/highlight.js/blob/main/src/styles/atom-one-dark.css.
+css = """
+.hljs {
+ color: #abb2bf;
+}
+
+.hljs-comment,
+.hljs-quote {
+ color: #5c6370;
+ font-style: italic;
+}
+
+.hljs-doctag,
+.hljs-keyword,
+.hljs-formula {
+ color: #c678dd;
+}
+
+.hljs-section,
+.hljs-name,
+.hljs-selector-tag,
+.hljs-deletion,
+.hljs-subst {
+ color: #e06c75;
+}
+
+.hljs-literal {
+ color: #56b6c2;
+}
+
+.hljs-string,
+.hljs-regexp,
+.hljs-addition,
+.hljs-attribute,
+.hljs-meta .hljs-string {
+ color: #98c379;
+}
+
+.hljs-attr,
+.hljs-variable,
+.hljs-template-variable,
+.hljs-type,
+.hljs-selector-class,
+.hljs-selector-attr,
+.hljs-selector-pseudo,
+.hljs-number {
+ color: #d19a66;
+}
+
+.hljs-symbol,
+.hljs-bullet,
+.hljs-link,
+.hljs-meta,
+.hljs-selector-id,
+.hljs-title {
+ color: #61aeee;
+}
+
+.hljs-built_in,
+.hljs-title.class_,
+.hljs-class .hljs-title {
+ color: #e6c07b;
+}
+
+.hljs-emphasis {
+ font-style: italic;
+}
+
+.hljs-strong {
+ font-weight: bold;
+}
+
+.hljs-link {
+ text-decoration: underline;
+}
+"""
+
+
async def meta(q: Q) -> None:
if q.client["keep_meta"]: # Do not reset meta, keep current dialog opened
q.client["keep_meta"] = False
@@ -37,52 +115,12 @@ async def meta(q: Q) -> None:
scripts=[
ui.script(source, asynchronous=True) for source in q.app["script_sources"]
],
- stylesheet=ui.inline_stylesheet(
- """
- .ms-MessageBar {
- padding-top: 3px;
- padding-bottom: 3px;
- min-height: 18px;
- }
- div[data-test="nav_bar"] .ms-Nav-groupContent {
- margin-bottom: 0;
- }
-
- div[data-test="experiment/display/deployment/top_right"],
- div[data-test="experiment/display/deployment/top_right"]
- div[data-visible="true"]:last-child > div > div {
- display: flex;
- }
-
- div[data-test="experiment/display/deployment/top_right"]
- div[data-visible="true"]:last-child,
- div[data-test="experiment/display/deployment/top_right"]
- div[data-visible="true"]:last-child > div {
- display: flex;
- flex-grow: 1;
- }
-
- div[data-test="experiment/display/deployment/top_right"]
- div[data-visible="true"]:last-child > div > div > div {
- display: flex;
- flex-grow: 1;
- flex-direction: column;
- }
-
- div[data-test="experiment/display/deployment/top_right"]
- div[data-visible="true"]:last-child > div > div > div > div {
- flex-grow: 1;
- }
- """
- ),
+ stylesheet=ui.inline_stylesheet(css),
script=None,
notification_bar=notification_bar,
)
- if q.client.theme_dark:
- q.page["meta"].theme = "h2o-dark"
- else:
- q.page["meta"].theme = "light"
+ q.page["meta"].theme = "h2o-dark"
def heap_analytics(
diff --git a/llm_studio/app_utils/sections/dataset.py b/llm_studio/app_utils/sections/dataset.py
index dfc585bbe..899d85ea4 100644
--- a/llm_studio/app_utils/sections/dataset.py
+++ b/llm_studio/app_utils/sections/dataset.py
@@ -1100,7 +1100,7 @@ async def dataset_delete_single(q: Q, dataset_id: int):
async def dataset_display(q: Q) -> None:
"""Display a selected dataset."""
- dataset_id = q.client["dataset/list/df_datasets"]["id"].iloc[
+ dataset_id: int = q.client["dataset/list/df_datasets"]["id"].iloc[
q.client["dataset/display/id"]
]
dataset: Dataset = q.client.app_db.get_dataset(dataset_id)
@@ -1196,7 +1196,7 @@ async def show_data_tab(q, cfg, filename: str):
q.client.delete_cards.add("dataset/display/data")
-async def show_visualization_tab(q, cfg):
+async def show_visualization_tab(q: Q, cfg):
try:
plot = cfg.logging.plots_class.plot_data(cfg)
except Exception as error:
@@ -1236,7 +1236,7 @@ async def show_visualization_tab(q, cfg):
q.client.delete_cards.add("dataset/display/visualization")
-async def show_summary_tab(q, dataset_id):
+async def show_summary_tab(q: Q, dataset_id: int) -> None:
dataset_df = get_datasets(q)
dataset_df = dataset_df[dataset_df.id == dataset_id]
stat_list_items: List[StatListItem] = []
@@ -1253,7 +1253,9 @@ async def show_summary_tab(q, dataset_id):
q.client.delete_cards.add("dataset/display/summary")
-async def show_statistics_tab(q, dataset_filename, config_filename):
+async def show_statistics_tab(
+ q: Q, dataset_filename: str, config_filename: str
+) -> None:
cfg_hash = hashlib.md5(open(config_filename, "rb").read()).hexdigest()
stats_dict = compute_dataset_statistics(dataset_filename, config_filename, cfg_hash)
@@ -1320,7 +1322,7 @@ async def show_statistics_tab(q, dataset_filename, config_filename):
@functools.lru_cache()
-def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str):
+def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str) -> dict:
"""
Compute various statistics for a dataset.
- text length distribution for prompts and answers
@@ -1362,7 +1364,7 @@ def compute_dataset_statistics(dataset_path: str, cfg_path: str, cfg_hash: str):
return stats_dict
-async def dataset_import_uploaded_file(q: Q):
+async def dataset_import_uploaded_file(q: Q) -> None:
local_path = await q.site.download(
q.args["dataset/import/local_upload"][0],
f"{get_data_dir(q)}/"
@@ -1378,7 +1380,7 @@ async def dataset_import_uploaded_file(q: Q):
await dataset_import(q, step=1, error=error)
-async def dataset_delete_current_datasets(q: Q):
+async def dataset_delete_current_datasets(q: Q) -> None:
dataset_ids = list(
q.client["dataset/list/df_datasets"]["id"].iloc[
list(map(int, q.client["dataset/list/table"]))
diff --git a/llm_studio/app_utils/sections/experiment.py b/llm_studio/app_utils/sections/experiment.py
index 76db8da41..d76816849 100644
--- a/llm_studio/app_utils/sections/experiment.py
+++ b/llm_studio/app_utils/sections/experiment.py
@@ -1421,6 +1421,7 @@ async def summary_tab(experiment_id, q):
box=ui.box(zone="third"),
title="",
content=content,
+ compact=False,
)
q.client.delete_cards.add(card_name)
@@ -1495,7 +1496,7 @@ def unite_validation_metric_charts(charts_list):
return charts_list
-async def charts_tab(q, charts_list, legend_labels):
+async def charts_tab(q, charts_list, legend_labels) -> None:
charts_list = unite_validation_metric_charts(charts_list)
box = ["first", "first", "second", "second"]
diff --git a/llm_studio/app_utils/sections/settings.py b/llm_studio/app_utils/sections/settings.py
index df9d5db5c..4a1fa2d2a 100644
--- a/llm_studio/app_utils/sections/settings.py
+++ b/llm_studio/app_utils/sections/settings.py
@@ -54,17 +54,6 @@ async def settings(q: Q) -> None:
""",
),
ui.separator("Appearance"),
- ui.inline(
- items=[
- ui.label("Dark Mode", width=label_width),
- ui.toggle(
- name="theme_dark",
- value=q.client["theme_dark"],
- tooltip="Enables Dark Mode as theme.",
- trigger=True,
- ),
- ]
- ),
ui.inline(
items=[
ui.label("Delete Dialogs", width=label_width),
diff --git a/llm_studio/app_utils/wave_utils.py b/llm_studio/app_utils/wave_utils.py
index e23d7e807..4b10d6ee9 100644
--- a/llm_studio/app_utils/wave_utils.py
+++ b/llm_studio/app_utils/wave_utils.py
@@ -13,20 +13,14 @@
class ThemeColors(TypedDict):
- light: dict
- dark: dict
+ primary: str
+ background_color: str
class WaveTheme:
_theme_colors: ThemeColors = {
- "light": {
- "primary": "#000000",
- "background_color": "#ffffff",
- },
- "dark": {
- "primary": "#FEC925",
- "background_color": "#121212",
- },
+ "primary": "#FEC925",
+ "background_color": "#121212",
}
states = {
@@ -43,20 +37,12 @@ class WaveTheme:
def __repr__(self) -> str:
return "WaveTheme"
- def get_value_by_key(self, q: Q, key: str):
- value = (
- self._theme_colors["dark"][key]
- if q.client.theme_dark
- else self._theme_colors["light"][key]
- )
- return value
-
def get_primary_color(self, q: Q):
- primary_color = self.get_value_by_key(q, "primary")
+ primary_color = self._theme_colors["primary"]
return primary_color
def get_background_color(self, q: Q):
- background_color = self.get_value_by_key(q, "background_color")
+ background_color = self._theme_colors["background_color"]
return background_color
diff --git a/llm_studio/python_configs/base.py b/llm_studio/python_configs/base.py
index d0350cb33..e1a5af575 100644
--- a/llm_studio/python_configs/base.py
+++ b/llm_studio/python_configs/base.py
@@ -204,6 +204,7 @@ class DefaultConfigProblemBase(DefaultConfig):
experiment_name: str
output_directory: str
llm_backbone: str
+ _parent_experiment: str
dataset: Any
tokenizer: Any
diff --git a/llm_studio/src/order.py b/llm_studio/src/order.py
index 3f7a4ef14..2bd1e471b 100644
--- a/llm_studio/src/order.py
+++ b/llm_studio/src/order.py
@@ -20,12 +20,12 @@ def __init__(self, keys: Optional[List[str]] = None):
else:
self._list = list()
- def _unique_guard(self, *keys: str):
+ def _unique_guard(self, *keys: str) -> None:
for key in keys:
if key in self._list:
raise ValueError(f"`{key}` is already in the list!")
- def append(self, key: str):
+ def append(self, key: str) -> None:
"""
Append a key at the end of the list:
diff --git a/llm_studio/src/plots/text_causal_language_modeling_plots.py b/llm_studio/src/plots/text_causal_language_modeling_plots.py
index a2aeda47b..c743c4cc7 100644
--- a/llm_studio/src/plots/text_causal_language_modeling_plots.py
+++ b/llm_studio/src/plots/text_causal_language_modeling_plots.py
@@ -7,11 +7,7 @@
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
-from llm_studio.src.utils.plot_utils import (
- PlotData,
- format_for_markdown_visualization,
- list_to_markdown_representation,
-)
+from llm_studio.src.utils.plot_utils import PlotData, list_to_markdown_representation
class Plots:
@@ -90,10 +86,6 @@ def plot_data(cls, cfg) -> PlotData:
]
i += 1
- df_transposed["Content"] = df_transposed["Content"].apply(
- format_for_markdown_visualization
- )
-
df_transposed.to_parquet(path)
return PlotData(path, encoding="df")
@@ -153,11 +145,6 @@ def plot_validation_predictions(
"Predicted Text": predicted_texts,
}
)
- df[input_text_column_name] = df[input_text_column_name].apply(
- format_for_markdown_visualization
- )
- df["Target Text"] = df["Target Text"].apply(format_for_markdown_visualization)
- df["Predicted Text"] = df["Predicted Text"].apply(format_for_markdown_visualization)
if val_outputs.get("metrics") is not None:
metric_column_name = f"Metric ({cfg.prediction.metric})"
@@ -195,7 +182,7 @@ def create_batch_prediction_df(
]
}
)
- df["Prompt Text"] = df["Prompt Text"].apply(format_for_markdown_visualization)
+
if labels_column in batch.keys():
df["Answer Text"] = [
tokenizer.decode(
diff --git a/llm_studio/src/plots/text_dpo_modeling_plots.py b/llm_studio/src/plots/text_dpo_modeling_plots.py
index 41cee6c78..0b464e946 100644
--- a/llm_studio/src/plots/text_dpo_modeling_plots.py
+++ b/llm_studio/src/plots/text_dpo_modeling_plots.py
@@ -11,7 +11,7 @@
plot_validation_predictions,
)
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
-from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
+from llm_studio.src.utils.plot_utils import PlotData
from llm_studio.src.utils.utils import PatchedAttribute
@@ -129,9 +129,6 @@ def plot_data(cls, cfg) -> PlotData:
]
i += 1
- df_transposed["Content"] = df_transposed["Content"].apply(
- format_for_markdown_visualization
- )
df_transposed.to_parquet(path)
return PlotData(path, encoding="df")
diff --git a/llm_studio/src/utils/plot_utils.py b/llm_studio/src/utils/plot_utils.py
index b2d3bf1e0..f83989867 100644
--- a/llm_studio/src/utils/plot_utils.py
+++ b/llm_studio/src/utils/plot_utils.py
@@ -1,5 +1,4 @@
import html
-import re
from dataclasses import dataclass
from typing import List
@@ -66,31 +65,6 @@ def decode_bytes(chunks: List[bytes]):
return decoded_tokens
-def format_for_markdown_visualization(text: str) -> str:
- """
- Convert newlines to <br /> tags, except for those inside code blocks.
- This is needed because the markdown_table_cell_type() function does not
- convert newlines to <br /> tags, so we have to do it ourselves.
-
- This function is rather simple and may fail on text that uses `
- in some other context than marking code cells or uses ` within
- the code itself (as this function).
- """
- code_block_regex = r"(```.*?```|``.*?``)"
- parts = re.split(code_block_regex, text, flags=re.DOTALL)
- for i in range(len(parts)):
- # Only substitute for text outside matched code blocks
- if "`" not in parts[i]:
- parts[i] = parts[i].replace("\n", "<br />").strip()
- text = "".join(parts)
-
- # Restore newlines around code blocks, needed for correct rendering
- for x in ["```", "``", "`"]:
- text = text.replace(f"<br />{x}", f"\n{x}")
- text = text.replace(f"{x}<br />", f"{x}\n")
- return html.escape(text.replace("<br />", "\n"))
-
-
def list_to_markdown_representation(
tokens: List[str], masks: List[bool], pad_token: int, num_chars: int = 65
):