-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add script to determine ideal number of shards
This re-uses some code from the `monitoring/` dir with random modifications (maybe in the future we could de-duplicate these) but for now the code is fairly spaghetti-esque. `determine_shards.py` helps generate content for PRs like apache/tvm#12473
- Loading branch information
Showing
9 changed files
with
931 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.httpcache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# developer scripts | ||
|
||
This is a collection of one-off scripts that are helpful for developing TVM's CI. As one-off scripts, these do not conform to the usual TVM quality standards, which is why they are stored out of tree. | 
|
||
## `determine_shards.py` | ||
|
||
Given a goal runtime for each test shard and a Jenkins job, print out the number of shards that should be used for each step. | ||
|
||
```bash | ||
# print out number of shards per test step | ||
python determine_shards.py --runtime-goal-m 90 --branch PR-12473 | ||
|
||
# see bottleneck steps individually | ||
python determine_shards.py --runtime-goal-m 90 --branch PR-12473 --list-steps | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import argparse | ||
import asyncio | ||
import re | ||
import statistics | ||
import math | ||
import rich | ||
|
||
from typing import * | ||
|
||
from utils import forward | ||
from utils.forward import * | ||
|
||
|
||
def is_parallelizable(name: str, desc: str) -> bool:
    """Return True if the named Jenkins test step can be split across shards.

    Args:
        name: The step's display name as reported by Jenkins.
        desc: Unused; kept so the signature stays compatible with callers.
    """
    # Steps known to be shardable; every other step is treated as fixed
    # per-shard overhead by analyze_stages.
    parallelizable_steps = {
        "Run CPU integration tests",
        "Run Hexagon tests",
        "Run Python GPU integration tests",
        "Run Python GPU unit tests",
        "Run Python frontend tests",
        "Run Python unit tests",
        "Run VTA tests in FSIM",
        "Run VTA tests in TSIM",
        "Run i386 integration tests",
        "Run test_arm_compute_lib test",
        "Run TOPI tests",
        "Run microTVM tests",
    }
    return name in parallelizable_steps
|
||
|
||
def analyze_stages(stage_name: str, stages: List[Stage], goal_runtime_m: float):
    """Print how many shards a stage needs to finish in ``goal_runtime_m``.

    Steps with the same name are gathered across all shards of the stage.
    Parallelizable steps contribute their total runtime (that work can be
    divided among shards); all other steps contribute their median runtime
    as a fixed per-shard cost.
    """
    # Group steps by name across every shard of this stage.
    steps_by_name: Dict[str, list] = {}
    for stage in stages:
        for step in stage.steps:
            steps_by_name.setdefault(step.name, []).append(step)

    fixed_runtime_m = 0
    parallelizable_runtime_m = 0
    for name, steps in steps_by_name.items():
        if is_parallelizable(name, ""):
            # Shardable work: the sum over all shards is what gets divided up.
            parallelizable_runtime_m += (
                sum([step.duration_ms for step in steps]) / 1000.0 / 60.0
            )
        else:
            # Fixed overhead paid by every shard: use the median as typical.
            fixed_runtime_m += (
                statistics.median([step.duration_ms for step in steps]) / 1000.0 / 60.0
            )

    print(stage_name)
    parallel_part = goal_runtime_m - fixed_runtime_m
    if parallel_part <= 0:
        # Even with infinite shards the fixed cost alone blows the budget.
        print(
            f" fixed runtime is too long ({round(fixed_runtime_m, 2)}), cannot reach goal time"
        )
        return

    num_shards = math.ceil(parallelizable_runtime_m / parallel_part)

    print(f" fixed runtime (m): {round(fixed_runtime_m, 2)}")
    print(f" parallel runtime (m): {round(parallelizable_runtime_m, 2)}")
    print(f" required shards: {num_shards}")
|
||
|
||
def list_steps(build: Build):
    """Print each stage's total runtime and per-step runtimes, sorted ascending.

    Step names are color-coded: bold red for steps slower than the median of
    the above-median steps (a rough 75th percentile), magenta for steps above
    the median, plain otherwise.
    """

    def total_rt(stage: Stage):
        # Total wall time of all steps in the stage, in milliseconds.
        return sum(step.duration_ms for step in stage.steps)

    build.stages = sorted(build.stages, key=total_rt)
    print("For build at", build.blue_url)
    for stage in build.stages:
        # "Build"/"Test"/"Deploy" are parent stages, not actual test shards.
        if stage.name in {"Build", "Test", "Deploy"}:
            continue
        if len(stage.steps) == 0:
            rich.print(f"{stage.name}: skipped")
            continue
        total = sum(step.duration_ms for step in stage.steps)
        durations = [step.duration_ms for step in stage.steps]
        median = statistics.median(durations)
        above_median = [d for d in durations if d > median]
        # Guard: when no step exceeds the median (single-step stage, or all
        # durations equal) the filtered list is empty and statistics.median
        # would raise StatisticsError; +inf disables the "bold red" tier.
        m75 = statistics.median(above_median) if above_median else float("inf")
        rich.print(f"{stage.name}: {round(total /1000.0/60.0)}m")
        for step in stage.steps:
            if step.duration_ms > m75:
                rich.print(
                    f" [bold red]{step.name}[/bold red]: {round(step.duration_ms / 1000.0 / 60.0, 2)}"
                )
            elif step.duration_ms > median:
                rich.print(
                    f" [magenta]{step.name}[/magenta]: {round(step.duration_ms / 1000.0 / 60.0, 2)}"
                )
            else:
                rich.print(
                    f" {step.name}: {round(step.duration_ms / 1000.0 / 60.0, 2)}"
                )
|
||
|
||
def analyze(build: Build, goal_runtime_m: float):
    """Analyze each test stage of ``build`` against the goal runtime.

    Collects the stages that run between the "Test" and "Deploy" parent
    stages, merges shards of the same logical stage (names matching
    "<base> N of M"), and prints the shard count each needs to finish
    within ``goal_runtime_m`` minutes.
    """
    # Gather only the stages that run between "Test" and "Deploy".
    test_stages: List[Stage] = []
    should_add = False
    for stage in build.stages:
        if stage.name == "Test":
            should_add = True
        elif stage.name == "Deploy":
            should_add = False
        elif should_add:
            test_stages.append(stage)

    # Merge shards like "foo 1 of 4" under their base name "foo".
    # (The original also built an unused names_to_stages dict; removed.)
    merged_shards: Dict[str, List[Stage]] = {}
    for stage in test_stages:
        m = re.match(r"(.*) \d+ of \d+", stage.name)
        base_name = m.groups()[0] if m else stage.name
        merged_shards.setdefault(base_name, []).append(stage)

    for name, stages in merged_shards.items():
        analyze_stages(name, stages, goal_runtime_m)
|
||
|
||
async def main(args):
    """Fetch the Jenkins branch data for ``args.branch`` and return it.

    NOTE(review): ``aiohttp`` and ``fetch_branch`` are not imported by name
    in this file; presumably they come in via ``from utils.forward import *``
    — confirm against utils/forward.py.
    """
    async with aiohttp.ClientSession() as s:
        # Hand the shared session to the forward module so its fetch
        # helpers reuse one connection pool.
        forward.SESSION = s
        data = await fetch_branch(name=args.branch)
        return data
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Determine number of Jenkins shards to use"
    )
    # Target wall-clock time, in minutes, for each test shard.
    parser.add_argument("--runtime-goal-m", required=True)
    # When set, list per-step runtimes instead of computing shard counts.
    parser.add_argument("--list-steps", action="store_true")
    parser.add_argument("--branch", default="main")
    parser.add_argument("--build", default="4082")
    args = parser.parse_args()
    # init/init_log are not defined here — presumably provided by the
    # `from utils.forward import *` star import (init sets up the on-disk
    # HTTP cache dir; NOTE(review): inferred from usage, confirm).
    init(dir=".httpcache")
    init_log()

    # Only the most recent build on the branch is analyzed below.
    branch = asyncio.run(main(args))
    build = branch.builds[0]

    if args.list_steps:
        list_steps(build)
    else:
        print(f"To reach goal runtime of {args.runtime_goal_m} for tests:")
        analyze(build, goal_runtime_m=float(args.runtime_goal_m))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
from sqlalchemy import create_engine | ||
|
||
from sqlalchemy.dialects.postgresql import insert | ||
|
||
|
||
def connection_string(db="tvm"):
    """Build a PostgreSQL connection URL from the db_* environment variables.

    Args:
        db: Database name appended to the URL, or None to return a URL
            without a database component.

    Raises:
        KeyError: if db_host, db_password, or db_user is unset.
    """
    # Read in the same order as before so a missing variable raises the
    # same KeyError first.
    host = os.environ["db_host"]
    password = os.environ["db_password"]
    user = os.environ["db_user"]

    base = f"postgresql://{user}:{password}@{host}"
    return base if db is None else f"{base}/{db}"
|
||
|
||
engine = None | ||
|
||
|
||
def get_engine(connection_string: str):
    """Return the process-wide SQLAlchemy engine, creating it on first use.

    Note: ``connection_string`` is only honored on the first call; later
    calls return the already-cached engine regardless of the argument.
    """
    global engine
    if engine is not None:
        return engine
    # A non-empty ECHO env var enables SQL statement logging.
    echo = bool(os.getenv("ECHO", False))
    engine = create_engine(connection_string, echo=echo)
    return engine
|
||
|
||
def clear_engine():
    """Drop the cached engine so the next get_engine() call rebuilds it."""
    global engine
    engine = None
|
||
|
||
def upsert(engine, model, insert_dict):
    """Insert or update a row on an engine backed by PostgreSQL.

    (The original docstring said MySQL, but the active code path uses the
    PostgreSQL dialect's ON CONFLICT DO UPDATE.)

    Args:
        engine: SQLAlchemy engine to execute against.
        model: Table/model whose ``_pks`` attribute lists the primary-key
            columns used for conflict detection.
        insert_dict: Column -> value mapping to insert; on a primary-key
            conflict the existing row is updated with these values.

    Returns:
        The ``lastrowid`` reported by the driver for the executed statement.
    """
    inserted = insert(model).values(**insert_dict)
    # MySQL version (kept for reference):
    # upserted = inserted.on_duplicate_key_update(
    #     **{k: inserted.inserted[k] for k, v in insert_dict.items()}
    # )

    # Postgres version:
    upserted = inserted.on_conflict_do_update(
        index_elements=model._pks,
        # index_where=my_table.c.user_email.like("%@gmail.com"),
        set_=insert_dict,
    )
    res = engine.execute(upserted)
    return res.lastrowid
Oops, something went wrong.