Misc Improvement (#550)

* ...

* minor

junrushao authored Dec 14, 2021
1 parent f998006 commit 495074f

Showing 5 changed files with 82 additions and 53 deletions.
21 changes: 12 additions & 9 deletions python/tvm/meta_schedule/cost_model/xgb_model.py
@@ -319,8 +319,11 @@ def load(self, path: str) -> None:
         Since XGBoost model trains from scratch, each time we can only load the model without the
         previous cached features / results so any call of update won't use previous training data.
         """
+        import xgboost as xgb  # pylint: disable=import-outside-toplevel
+
         with tempfile.TemporaryDirectory() as tmp_dir:
             untar(path, tmp_dir)
+            self.booster = xgb.Booster()
             self.booster.load_model(os.path.join(tmp_dir, "model.bin"))
             self.cached_features = list(
                 np.load(os.path.join(tmp_dir, "cached_features.npy"), allow_pickle=True)
@@ -346,23 +349,24 @@ def save(self, path: str) -> None:
         import xgboost as xgb  # pylint: disable=import-outside-toplevel

         if self.booster is None:
-            # save all the paramaters
+            # save all the parameters
             self.booster = xgb.Booster(self.config.to_dict())
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            self.booster.save_model(os.path.join(tmpdirname, "model.bin"))
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            self.booster.save_model(os.path.join(tmp_dir, "model.bin"))
             np.save(
-                os.path.join(tmpdirname, "cached_features.npy"),
+                os.path.join(tmp_dir, "cached_features.npy"),
                 np.array(self.cached_features, dtype=object),
             )
-            np.save(os.path.join(tmpdirname, "cached_mean_costs.npy"), self.cached_mean_costs)
+            np.save(os.path.join(tmp_dir, "cached_mean_costs.npy"), self.cached_mean_costs)
             tar(
                 path,
                 [
-                    os.path.join(tmpdirname, "model.bin"),
-                    os.path.join(tmpdirname, "cached_features.npy"),
-                    os.path.join(tmpdirname, "cached_mean_costs.npy"),
+                    os.path.join(tmp_dir, "model.bin"),
+                    os.path.join(tmp_dir, "cached_features.npy"),
+                    os.path.join(tmp_dir, "cached_mean_costs.npy"),
                 ],
             )
+        logger.info("Saved XGBModel to %s", path)

     def update(
         self,
@@ -491,7 +495,6 @@ def average_peak_score(
         )

         del self.d_train
-        # todo(zxybazh): measure callback to save the model

     def _predict(  # type: ignore # pylint: disable=invalid-name
         self,
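The save/load pair above serializes the booster and the cached features/costs into one tarball, so the whole cost-model state travels as a single artifact. Below is a minimal, self-contained sketch of the same round trip; it substitutes the standard tarfile module for TVM's tar/untar helpers, and the function names are hypothetical:

import os
import tarfile
import tempfile

import numpy as np
import xgboost as xgb

_FILES = ["model.bin", "cached_features.npy", "cached_mean_costs.npy"]

def save_bundle(booster, features, mean_costs, path):
    """Write the booster and cached arrays into a single tarball."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        booster.save_model(os.path.join(tmp_dir, "model.bin"))
        np.save(os.path.join(tmp_dir, "cached_features.npy"),
                np.array(features, dtype=object))
        np.save(os.path.join(tmp_dir, "cached_mean_costs.npy"), mean_costs)
        with tarfile.open(path, "w:gz") as tar_f:
            for name in _FILES:
                tar_f.add(os.path.join(tmp_dir, name), arcname=name)

def load_bundle(path):
    """Restore a fresh Booster and the cached arrays from the tarball."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with tarfile.open(path, "r:gz") as tar_f:
            tar_f.extractall(tmp_dir)
        booster = xgb.Booster()
        booster.load_model(os.path.join(tmp_dir, "model.bin"))
        features = list(np.load(os.path.join(tmp_dir, "cached_features.npy"),
                                allow_pickle=True))
        mean_costs = np.load(os.path.join(tmp_dir, "cached_mean_costs.npy"))
    return booster, features, mean_costs

if __name__ == "__main__":
    # Train a tiny throwaway model so the round trip has something to carry.
    rng = np.random.default_rng(0)
    x, y = rng.random((32, 4)), rng.random(32)
    booster = xgb.train({"max_depth": 2}, xgb.DMatrix(x, label=y), num_boost_round=2)
    save_bundle(booster, [rng.random(5), rng.random(7)], y, "/tmp/xgb_bundle.tar.gz")
    restored, feats, costs = load_bundle("/tmp/xgb_bundle.tar.gz")
    assert len(feats) == 2 and costs.shape == (32,)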
15 changes: 7 additions & 8 deletions python/tvm/meta_schedule/tune.py
@@ -256,10 +256,10 @@ def _runner(runner: Optional[Runner]) -> Runner:
         return runner

     @staticmethod
-    def _database(database: Union[None, Database], path: str) -> Database:
+    def _database(database: Union[None, Database], task_name: str, path: str) -> Database:
         if database is None:
-            path_workload = os.path.join(path, "workload.json")
-            path_tuning_record = os.path.join(path, "tuning_record.json")
+            path_workload = os.path.join(path, f"{task_name}_database_workload.json")
+            path_tuning_record = os.path.join(path, f"{task_name}_database_tuning_record.json")
             logger.info(
                 "Creating JSONDatabase. Workload at: %s. Tuning records at: %s",
                 path_workload,
@@ -269,8 +269,6 @@ def _database(database: Union[None, Database], path: str) -> Database:
                 path_workload=path_workload,
                 path_tuning_record=path_tuning_record,
             )
-        elif callable(database):
-            database = database(path)
         if not isinstance(database, Database):
             raise TypeError(f"Expected `database` to be Database, but gets: {database}")
         return database
@@ -496,7 +494,7 @@ def tune_tir(
     logger.info("Working directory: %s", work_dir)
     # pylint: disable=protected-access
     mod = Parse._mod(mod)
-    database = Parse._database(database, work_dir)
+    database = Parse._database(database, task_name, work_dir)
     tune_context = Parse._tune_context(
         tune_context=None,
         mod=mod,
@@ -529,6 +527,7 @@ def tune_tir(
     assert len(bests) == 1
     sch = Schedule(mod)
     bests[0].trace.apply_to_schedule(sch, remove_postproc=False)
+    task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb"))
     return sch


@@ -663,7 +662,7 @@ def tune_relay(
     # pylint: disable=protected-access
     tune_contexts = []
     target = Parse._target(target)
-    database = Parse._database(database, work_dir)
+    database = Parse._database(database, task_name, work_dir)
     for task in extracted_tasks:
         assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now"
         mod = Parse._mod(task.dispatched[0])
@@ -692,7 +691,7 @@ def tune_relay(
     )
     # pylint: enable=protected-access
     task_scheduler.tune()
-    schs = []
+    schs: List[Schedule] = []
     for task in tune_contexts:
         mod = task.mod
         workload = database.commit_workload(mod)
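The net effect in tune.py: _database now takes the task name and prefixes the JSON file names with it, the callable-database branch is gone, and tune_tir checkpoints the tuned cost model as {task_name}.xgb in the work directory. A small sketch of the resulting per-task layout (hypothetical helper; the real code builds these paths inline):

import os

def task_artifact_paths(work_dir, task_name):
    """Where one task's tuning artifacts land after this change."""
    return {
        "workload": os.path.join(work_dir, f"{task_name}_database_workload.json"),
        "tuning_record": os.path.join(work_dir, f"{task_name}_database_tuning_record.json"),
        "cost_model": os.path.join(work_dir, f"{task_name}.xgb"),
    }

# Two tasks sharing one work_dir no longer clobber each other's files:
# task_artifact_paths("/tmp/ms", "matmul") and task_artifact_paths("/tmp/ms", "conv2d")
# yield disjoint paths ("matmul" / "conv2d" are placeholder task names).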
45 changes: 33 additions & 12 deletions src/meta_schedule/utils.h
@@ -278,17 +278,23 @@ inline int GetTargetNumCores(const Target& target) {
   return num_cores;
 }

+/*!
+ * \brief A helper data structure that replays a trace and collects failure counts
+ * for each postprocessor
+ */
 struct ThreadedTraceApply {
-  const Array<Postproc>& postprocs;
-  std::vector<std::unique_ptr<std::atomic<int>>> fail_counter;
-
+  /*! \brief Constructor */
   explicit ThreadedTraceApply(const Array<Postproc>& postprocs)
-      : postprocs(postprocs), fail_counter(postprocs.size()) {
-    for (std::unique_ptr<std::atomic<int>>& p : fail_counter) {
-      p = std::make_unique<std::atomic<int>>(0);
+      : n_(postprocs.size()), items_(new Item[n_]) {
+    for (int i = 0; i < n_; ++i) {
+      items_[i].postproc = postprocs[i];
+      items_[i].fail_counter = 0;
     }
   }

+  /*! \brief Destructor */
+  ~ThreadedTraceApply() { delete[] items_; }
+
   /*!
    * \brief Apply the trace and postprocessors to an IRModule
    * \param mod The IRModule to be applied
@@ -305,23 +311,38 @@ struct ThreadedTraceApply {
                               /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
     trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
     sch->EnterPostproc();
-    for (int i = 0, n = postprocs.size(); i < n; ++i) {
-      if (!postprocs[i]->Apply(sch)) {
-        ++*fail_counter[i];
+    for (int i = 0; i < n_; ++i) {
+      Item& item = items_[i];
+      if (!item.postproc->Apply(sch)) {
+        ++item.fail_counter;
         return NullOpt;
       }
     }
     return sch;
   }

+  /*! \brief Returns a string summarizing the failures on each postprocessor */
   std::string SummarizeFailures() const {
     std::ostringstream os;
-    for (int i = 0, n = postprocs.size(); i < n; ++i) {
-      os << "Postproc #" << i << " [" << postprocs[i]  //
-         << "]: " << *fail_counter[i] << " failure(s)\n";
+    for (int i = 0; i < n_; ++i) {
+      const Item& item = items_[i];
+      os << "Postproc #" << i << " [" << item.postproc  //
+         << "]: " << item.fail_counter.load() << " failure(s)";
+      if (i != n_ - 1) {
+        os << "\n";
+      }
     }
     return os.str();
   }
+
+ private:
+  struct Item {
+    Postproc postproc{nullptr};
+    std::atomic<int> fail_counter{0};
+  };
+
+  int n_;
+  Item* items_;
 };

 }  // namespace meta_schedule
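The rewrite drops the borrowed postprocs reference and the vector of individually heap-allocated atomic counters in favor of one owned array of Item records, so each postprocessor sits next to its own failure counter. A rough Python analogue of the bookkeeping pattern (hypothetical names; the authoritative version is the C++ above):

import threading

class ThreadedTraceApplyCounters:
    """Per-postproc failure counting; a lock stands in for C++ std::atomic."""

    def __init__(self, postprocs):
        self.items = [{"postproc": p, "fail_counter": 0} for p in postprocs]
        self._lock = threading.Lock()

    def apply(self, sch):
        """Run each postproc; on the first failure, count it and bail out."""
        for item in self.items:
            if not item["postproc"](sch):
                with self._lock:
                    item["fail_counter"] += 1
                return None
        return sch

    def summarize_failures(self):
        return "\n".join(
            f"Postproc #{i} [{item['postproc'].__name__}]: "
            f"{item['fail_counter']} failure(s)"
            for i, item in enumerate(self.items)
        )

# Example: a postproc is any callable taking a schedule and returning bool.
def always_fails(_sch):
    return False

counters = ThreadedTraceApplyCounters([always_fails])
assert counters.apply(object()) is None
print(counters.summarize_failures())  # Postproc #0 [always_fails]: 1 failure(s)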
11 changes: 7 additions & 4 deletions tests/python/meta_schedule/run_meta_schedule_cuda.sh
@@ -1,25 +1,28 @@
-# set -euxo pipefail
+set -euxo pipefail

 RPC_HOST="192.168.6.66"
 RPC_PORT="4445"
 RPC_KEY="jetson-agx-xavier"
 TARGET="nvidia/jetson-agx-xavier"
-LOG_DIR=$HOME/logs/ms-cuda/
-NUM_TRIALS=800
+LOG_DIR=/tmp/logs/ms-cuda/
+NUM_TRIALS=2000

 mkdir -p $LOG_DIR

 run () {
   name=$1
+  work_dir=$LOG_DIR/$name/
+  mkdir -p $work_dir
   echo "Running workload $name"
   python tests/python/meta_schedule/test_meta_schedule.py \
     --workload "$name" \
     --target "$TARGET" \
+    --work-dir "$work_dir" \
     --rpc-host "$RPC_HOST" \
     --rpc-port "$RPC_PORT" \
     --rpc-key "$RPC_KEY" \
     --num-trials $NUM_TRIALS \
-    2>&1 | tee "$LOG_DIR/$name.log"
+    2>&1 | tee "$work_dir/$name.log"
 }

 # Single op
43 changes: 23 additions & 20 deletions tests/python/meta_schedule/test_meta_schedule.py
@@ -18,7 +18,6 @@
 import argparse
 import logging
 from os import cpu_count
-import tempfile

 import tvm
 from tvm import meta_schedule as ms
@@ -43,6 +42,11 @@ def _parse_args():
         type=int,
         required=True,
     )
+    args.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+    )
     args.add_argument(
         "--rpc-host",
         type=str,
@@ -85,25 +89,24 @@ def main():
         alloc_repeat=3,
         max_workers=ARGS.rpc_workers,
     )
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch: tir.Schedule = ms.tune_tir(
-            mod=create_te_workload(ARGS.workload, 0),
-            target=ARGS.target,
-            config=ms.EvolutionarySearchConfig(
-                num_trials_per_iter=64,
-                num_trials_total=ARGS.num_trials,
-                init_max_fail_count=1024,
-            ),
-            runner=runner,
-            task_name=ARGS.workload,
-            work_dir=work_dir,
-            num_threads=cpu_count(),
-        )
-        if sch is None:
-            print("No valid schedule found!")
-        else:
-            print(sch.mod.script())
-            print(sch.trace)
+    sch: tir.Schedule = ms.tune_tir(
+        mod=create_te_workload(ARGS.workload, 0),
+        target=ARGS.target,
+        config=ms.EvolutionarySearchConfig(
+            num_trials_per_iter=64,
+            num_trials_total=ARGS.num_trials,
+            init_max_fail_count=8192,
+        ),
+        runner=runner,
+        task_name=ARGS.workload,
+        work_dir=ARGS.work_dir,
+        num_threads=cpu_count(),
+    )
+    if sch is None:
+        print("No valid schedule found!")
+    else:
+        print(sch.mod.script())
+        print(sch.trace)


 if __name__ == "__main__":
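With --work-dir now required, main() tunes into a caller-supplied directory instead of a throwaway TemporaryDirectory, so the database JSONs and the {task_name}.xgb checkpoint survive the run. A sketch of an equivalent direct invocation (values copied from the shell script above; the workload name "C1D" is a placeholder):

import subprocess

subprocess.run(
    [
        "python", "tests/python/meta_schedule/test_meta_schedule.py",
        "--workload", "C1D",  # placeholder workload name
        "--target", "nvidia/jetson-agx-xavier",
        "--work-dir", "/tmp/logs/ms-cuda/C1D/",
        "--rpc-host", "192.168.6.66",
        "--rpc-port", "4445",
        "--rpc-key", "jetson-agx-xavier",
        "--num-trials", "2000",
    ],
    check=True,  # raise if tuning exits non-zero
)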
