Skip to content

Commit

Permalink
Use the ai2-llm wandb team if possible (#223)
Browse files Browse the repository at this point in the history
* use ai2 entity

* set default wandb project

* add tags

* use the same wandb project

* automatically add some tags

* format

* quick fix
  • Loading branch information
vwxyzjn authored Aug 7, 2024
1 parent bf1a0ce commit 0a327ff
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
14 changes: 12 additions & 2 deletions open_instruct/dpo_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@
concatenated_forward,
dpo_loss,
)
from open_instruct.utils import ArgumentParserPlus, FlatArguments
from open_instruct.utils import (
ArgumentParserPlus,
FlatArguments,
get_wandb_tags,
maybe_use_ai2_wandb_entity,
)

logger = get_logger(__name__)

Expand Down Expand Up @@ -514,8 +519,13 @@ def load_model():
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"]
if args.wandb_entity is None:
args.wandb_entity = maybe_use_ai2_wandb_entity()
exp_name = os.path.basename(__file__)[: -len(".py")]
accelerator.init_trackers(
"open_instruct_dpo", experiment_config, init_kwargs={"wandb": {"entity": args.wandb_entity}}
"open_instruct_internal",
experiment_config,
init_kwargs={"wandb": {"entity": args.wandb_entity, "tags": [exp_name] + get_wandb_tags()}},
)

# Train!
Expand Down
15 changes: 13 additions & 2 deletions open_instruct/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@
get_scheduler,
)

from open_instruct.utils import ArgumentParserPlus, FlatArguments, get_datasets
from open_instruct.utils import (
ArgumentParserPlus,
FlatArguments,
get_datasets,
get_wandb_tags,
maybe_use_ai2_wandb_entity,
)

logger = get_logger(__name__)

Expand Down Expand Up @@ -545,8 +551,13 @@ def main(args: FlatArguments):
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"]
if args.wandb_entity is None:
args.wandb_entity = maybe_use_ai2_wandb_entity()
exp_name = os.path.basename(__file__)[: -len(".py")]
accelerator.init_trackers(
"open_instruct_sft", experiment_config, init_kwargs={"wandb": {"entity": args.wandb_entity}}
"open_instruct_internal",
experiment_config,
init_kwargs={"wandb": {"entity": args.wandb_entity, "tags": [exp_name] + get_wandb_tags()}},
)

# Train!
Expand Down
81 changes: 80 additions & 1 deletion open_instruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.

import dataclasses
import logging
import os
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Any, List, NewType, Optional, Tuple, Union

import requests
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
Expand Down Expand Up @@ -581,6 +584,82 @@ def __post_init__(self):
raise ValueError("Cannot provide two dataset selection mechanisms.")


def maybe_use_ai2_wandb_entity() -> Optional[str]:
    """Pick the `ai2-llm` W&B entity when the current user belongs to that team.

    Ai2 internal convenience: logs in to Weights & Biases, looks up the teams
    of the current viewer, and selects the shared team when it is listed.

    Returns:
        "ai2-llm" if that team appears in the viewer's teams, otherwise None.
    """
    import wandb

    wandb.login()
    viewer_teams = wandb.Api().viewer.teams
    return "ai2-llm" if "ai2-llm" in viewer_teams else None


def get_git_tag() -> str:
    """Try to get the latest Git tag (e.g., `no-tag-404-g98dc659` or `v1.0.0-4-g98dc659`).

    This is a best-effort lookup: it never raises because Git is missing or
    because the working directory is not a repository.

    Returns:
        The output of `git describe --tags` when it succeeds; otherwise a
        synthetic `no-tag-<commit count>-g<short hash>` tag; otherwise an
        empty string when no Git information can be obtained.
    """

    def _git_output(*git_args: str) -> str:
        # Run `git <git_args>` with stderr silenced and return stripped stdout.
        raw = subprocess.check_output(["git", *git_args], stderr=subprocess.DEVNULL)
        # Tag names may contain non-ASCII characters; never fail on decoding.
        return raw.decode("utf-8", errors="replace").strip()

    # CalledProcessError: git ran but failed (no tags, not a repository).
    # OSError: git could not be started at all (e.g. executable not installed).
    git_errors = (subprocess.CalledProcessError, OSError)

    git_tag = ""
    try:
        git_tag = _git_output("describe", "--tags")
    except git_errors as e:
        logging.debug(f"Failed to get Git tag: {e}")

    # If no Git tag found, create a custom tag based on commit count and hash
    if not git_tag:
        try:
            count = int(_git_output("rev-list", "--count", "HEAD"))
            short_hash = _git_output("rev-parse", "--short", "HEAD")
            git_tag = f"no-tag-{count}-g{short_hash}"
        except git_errors + (ValueError,) as e:
            # ValueError covers an unexpected non-numeric commit count.
            logging.debug(f"Failed to get commit count and hash: {e}")

    return git_tag


def get_pr_tag() -> str:
    """Try to find associated pull request on GitHub (e.g., `pr-123`).

    Resolves the current `HEAD` commit with Git and searches the
    `allenai/open-instruct` repository on GitHub for a pull request that
    mentions it. The lookup is best-effort: any Git or network failure is
    logged at debug level and results in an empty tag.

    Returns:
        `pr-<number>` for the first matching pull request, or an empty string
        when the commit cannot be resolved or no pull request is found.
    """
    pr_tag = ""

    try:
        git_commit = (
            subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"], stderr=subprocess.DEVNULL)
            .decode("ascii")
            .strip()
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # Not a Git checkout, or git is not installed: nothing to search for.
        logging.debug(f"Failed to get Git commit: {e}")
        return pr_tag

    try:
        # try finding the pull request number on github
        response = requests.get(
            f"https://api.github.com/search/issues?q=repo:allenai/open-instruct+is:pr+{git_commit}",
            # Without a timeout a stalled connection would block start-up forever.
            timeout=10,
        )
        if response.status_code == 200:
            items = response.json()["items"]
            if len(items) > 0:
                pr_number = items[0]["number"]
                pr_tag = f"pr-{pr_number}"
    except Exception as e:
        # Deliberately broad: tagging must never abort a training run.
        logging.debug(f"Failed to get PR number: {e}")

    return pr_tag


def get_wandb_tags() -> List[str]:
    """Get tags for Weights & Biases (e.g., `["no-tag-404-g98dc659", "pr-123"]`).

    Combines any tags already supplied through the `WANDB_TAGS` environment
    variable with the Git tag and the pull-request tag of the current checkout.

    Returns:
        A list of non-empty tags: the user-provided tags first, followed by the
        Git tag and the PR tag when they are available.
    """
    # WANDB_TAGS is a comma-separated list; split it so each entry becomes its
    # own tag instead of one combined "a,b" tag.
    env_tags = [tag.strip() for tag in os.environ.get("WANDB_TAGS", "").split(",")]
    candidates = [*env_tags, get_git_tag(), get_pr_tag()]
    return [tag for tag in candidates if tag]


class ArgumentParserPlus(HfArgumentParser):
def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
"""
Expand Down Expand Up @@ -639,7 +718,7 @@ def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = N

return outputs

def parse(self) -> DataClassType | Tuple[DataClassType]:
def parse(self) -> Union[DataClassType, Tuple[DataClassType]]:
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
# If we pass only one argument to the script and it's the path to a YAML file,
# let's parse it to get our arguments.
Expand Down

0 comments on commit 0a327ff

Please sign in to comment.