Skip to content

Commit

Permalink
[TIR] Reorder the block iters of the blocks generated by RFactor (#561)
Browse files Browse the repository at this point in the history
* RFactor block iter reordering

* Add class name back
  • Loading branch information
MasterJH5574 authored Dec 28, 2021
1 parent fad1415 commit 4179e8c
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 151 deletions.
71 changes: 40 additions & 31 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,13 +485,11 @@ class LoopPropertyError : public ScheduleError {
if (loop.get() == rf_loop) {
throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop);
}
continue;
} else if (reduction_touched) {
if (!meet_reduction_loop) {
CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get()));
meet_reduction_loop = true;
}
continue;
} else if (meet_reduction_loop && !is_one(loop->extent)) {
throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop);
}
Expand Down Expand Up @@ -560,10 +558,13 @@ class BaseBlockCreator {
}

void CreateBlock() {
CreateAdditionalIter();
for (int i = 0; i < n_block_iters_; ++i) {
CreateNormalIters(i);
}
if (!additional_iter_.defined()) {
ICHECK(arith::Analyzer().CanProveEqual(rf_loop_->extent, Integer(1)));
CreateAdditionalIter();
}
CreateReductionUpdate();
CreateReadWriteRegions();

Expand All @@ -589,8 +590,8 @@ class BaseBlockCreator {
}

private:
// Customization hooks implemented by each concrete block creator and driven by CreateBlock():
// - CreateNormalIters(idx) is called once for every block iter index of the old block.
// - CreateAdditionalIter() creates the block iter that corresponds to the rfactor loop.
// - CreateReductionUpdate() / CreateReadWriteRegions() are called afterwards, in that order;
//   NOTE(review): their exact outputs are subclass-specific and only partially visible here.
virtual void CreateAdditionalIter() = 0;
virtual void CreateNormalIters(int idx) = 0;
virtual void CreateAdditionalIter() = 0;
virtual void CreateReductionUpdate() = 0;
virtual void CreateReadWriteRegions() = 0;

Expand All @@ -601,6 +602,8 @@ class BaseBlockCreator {
BlockRealize new_block_realize_;
/*! \brief The indices used to access the intermediate rfactor buffer */
Array<PrimExpr> rf_buf_access_indices_;
/*! \brief The additional block iter of the newly created block for the rfactor loop. */
IterVar additional_iter_;

protected:
/*! \brief The old block-realize */
Expand Down Expand Up @@ -672,15 +675,6 @@ class RFactorBlockCreator : public BaseBlockCreator {
combiner_rhs_(std::move(combiner_rhs)) {}

private:
void CreateAdditionalIter() final {
// Create a new data parallel block iter for the rfactor loop.
additional_iter_ =
IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar);
loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var;
iter_vars_.push_back(additional_iter_);
iter_values_.push_back(rf_loop_->loop_var);
}

void CreateNormalIters(int idx) final {
IterVar old_iter = old_block_realize_->block->iter_vars[idx];
PrimExpr old_binding = old_block_realize_->iter_values[idx];
Expand All @@ -706,20 +700,31 @@ class RFactorBlockCreator : public BaseBlockCreator {
}
const For& loop = it->second;
if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) {
// We haven't created the new block iter for `var`. So here we create it, append it
// and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively.
IterVar new_iter_var =
IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce);
// - We haven't created the new block iter for `var`. So here we create it, append it
// and its binding to `iter_vars_` and `iter_values_` respectively.
// - If the loop is the rfactor loop, invoke `CreateAdditionalIter()`.
if (loop.same_as(rf_loop_)) {
CreateAdditionalIter();
continue;
}
IterVar new_iter_var = IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, kCommReduce);
loop_var2block_binding_[var.get()] = new_iter_var->var;
iter_vars_.push_back(new_iter_var);
iter_values_.push_back(var);
}
}
// Substitute the original binding with new block iters. Store the result expression
// in `rf_var_map` for future substitution.
// in `var_map_` for future substitution.
var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
}

void CreateAdditionalIter() final {
additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kDataPar);
iter_vars_.insert(iter_vars_.end(), additional_iter_);
iter_values_.insert(iter_values_.end(), rf_loop_->loop_var);
loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_;
}

void CreateReductionUpdate() final {
rf_buf_access_indices_ = old_reduction_update_->indices;
rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
Expand Down Expand Up @@ -754,10 +759,6 @@ class RFactorBlockCreator : public BaseBlockCreator {
return new_regions;
}

public:
/*! \brief The generated additional block iter in rfactor block for the rfactor loop */
IterVar additional_iter_;

private:
/*!
* \brief A mapping which maps a loop var to its corresponding For loop for all the reduction
Expand Down Expand Up @@ -797,25 +798,33 @@ class WriteBackBlockCreator : public BaseBlockCreator {
}

private:
void CreateAdditionalIter() final {
// Create a new reduction block iter for the rfactor loop.
IterVar wb_new_block_iter =
IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce);
iter_vars_.push_back(wb_new_block_iter);
iter_values_.push_back(rf_loop_->loop_var);
var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var);
}

void CreateNormalIters(int idx) final {
IterVar old_block_iter = old_block_realize_->block->iter_vars[idx];
if (old_block_iter->iter_type == IterVarType::kDataPar) {
iter_vars_.emplace_back(old_block_iter->dom, old_block_iter->var.copy_with_suffix(""),
kDataPar);
iter_values_.push_back(old_block_realize_->iter_values[idx]);
var_map_.Set(old_block_iter->var, iter_vars_.back());
return;
}

ICHECK(old_block_iter->iter_type == IterVarType::kCommReduce);
// If the old block iter touches the reduction loop and we have not created a new reduction
// block iter for the rfactor loop, create one now.
if (!additional_iter_.defined() &&
UsesVar(old_block_realize_->iter_values[idx],
[v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) {
CreateAdditionalIter();
}
}

void CreateAdditionalIter() final {
additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce);
iter_vars_.insert(iter_vars_.end(), additional_iter_);
iter_values_.insert(iter_values_.end(), rf_loop_->loop_var);
var_map_.Set(rf_additional_iter_->var, additional_iter_->var);
}

void CreateReductionUpdate() final {
wb_lhs_ = Downcast<BufferLoad>(Substitute(combiner_lhs_, var_map_));
wb_rhs_ =
Expand Down
Loading

0 comments on commit 4179e8c

Please sign in to comment.