Skip to content

Commit

Permalink
tiling: Track the branch completion status via the output stages
Browse files Browse the repository at this point in the history
Tiles may produce output on only a single branch when multiple branches
are active. Track the completion using a new BranchComplete() helper,
which traverses the pipeline till the output stage and returns a flag
if the stage has completed all its output.

In the split stage, do not inject tiles into a branch if the output is
complete.

Signed-off-by: Naushir Patuck <[email protected]>
  • Loading branch information
naushir committed Jul 15, 2024
1 parent 8f30e45 commit ee4585c
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 10 deletions.
20 changes: 17 additions & 3 deletions src/libpisp/backend/tiling/output_stage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using namespace tiling;
// the offset/length from the LH edge).

OutputStage::OutputStage(char const *name, Stage *upstream, Config const &config, int struct_offset)
: BasicStage(name, upstream->GetPipeline(), upstream, struct_offset), config_(config)
: BasicStage(name, upstream->GetPipeline(), upstream, struct_offset), config_(config), branch_complete_(false)
{
pipeline_->AddOutputStage(this);
}
Expand Down Expand Up @@ -118,7 +118,21 @@ void OutputStage::PushCropDown(Interval interval, [[maybe_unused]] Dir dir)
PISP_LOG(debug, "(" << name_ << ") Exit with interval " << output_interval_);
}

bool OutputStage::Done(Dir dir) const
void OutputStage::Reset()
{
return output_interval_.End() >= GetOutputImageSize()[dir];
BasicStage::Reset();
branch_complete_ = false;
}

bool OutputStage::BranchComplete() const
{
return branch_complete_;
}

bool OutputStage::Done(Dir dir)
{
if (output_interval_.End() >= GetOutputImageSize()[dir])
branch_complete_ = true;

return branch_complete_;
}
5 changes: 4 additions & 1 deletion src/libpisp/backend/tiling/output_stage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ class OutputStage : public BasicStage
virtual int PushEndDown(int input_end, Dir dir);
virtual void PushEndUp(int output_end, Dir dir);
virtual void PushCropDown(Interval interval, Dir dir);
bool Done(Dir dir) const;
virtual void Reset();
bool BranchComplete() const;
bool Done(Dir dir);

private:
Config config_;
bool branch_complete_;
};

} // namespace tiling
5 changes: 4 additions & 1 deletion src/libpisp/backend/tiling/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ int Pipeline::tileDirection(Dir dir, void *mem, size_t num_items, size_t item_si
if (num_tiles == num_items)
throw std::runtime_error("Too many tiles!");
for (auto s : outputs_)
s->PushStartUp(s->GetOutputInterval().End(), dir);
{
if (!s->BranchComplete())
s->PushStartUp(s->GetOutputInterval().End(), dir);
}
for (auto s : inputs_)
s->PushEndDown(s->GetInputInterval().offset + config_.max_tile_size[dir], dir);
for (auto s : inputs_)
Expand Down
23 changes: 21 additions & 2 deletions src/libpisp/backend/tiling/split_stage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ void SplitStage::PushStartUp(int output_start, Dir dir)
input_interval_ = Interval(output_start);
else
input_interval_ |= output_start;
if (count_ == downstream_.size())

unsigned int branch_incomplete_count = 0;
for (auto const &d : downstream_)
if (!d->BranchComplete())
branch_incomplete_count++;

if (count_ == branch_incomplete_count)
{
count_ = 0;
PISP_LOG(debug, "(" << name_ << ") Exit - call PushStartUp with " << input_interval_.offset);
Expand All @@ -67,6 +73,8 @@ int SplitStage::PushEndDown(int input_end, Dir dir)
input_interval_.SetEnd(0);
for (auto d : downstream_)
{
if (d->BranchComplete())
continue;
int branch_end = d->PushEndDown(input_end, dir);
// (It is OK for a branch to make no progress at all - so long as another branch does!)
if (branch_end > input_interval_.End())
Expand All @@ -82,7 +90,8 @@ int SplitStage::PushEndDown(int input_end, Dir dir)
}

for (auto d : downstream_)
d->PushEndDown(input_interval_.End(), dir);
if (!d->BranchComplete())
d->PushEndDown(input_interval_.End(), dir);

PushEndUp(input_interval_.End(), dir);
return input_interval_.End();
Expand All @@ -105,6 +114,8 @@ void SplitStage::PushCropDown(Interval interval, Dir dir)
input_interval_ = interval;
for (auto d : downstream_)
{
if (d->BranchComplete())
continue;
PISP_LOG(debug, "(" << name_ << ") Exit with interval " << interval);
d->PushCropDown(interval, dir);
}
Expand All @@ -113,3 +124,11 @@ void SplitStage::PushCropDown(Interval interval, Dir dir)
void SplitStage::CopyOut([[maybe_unused]] void *dest, [[maybe_unused]] Dir dir)
{
}

bool SplitStage::BranchComplete() const
{
bool done = true;
for (auto d : downstream_)
done &= d->BranchComplete();
return done;
}
1 change: 1 addition & 0 deletions src/libpisp/backend/tiling/split_stage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class SplitStage : public Stage
virtual void PushEndUp(int output_end, Dir dir);
virtual void PushCropDown(Interval interval, Dir dir);
virtual void CopyOut(void *dest, Dir dir);
virtual bool BranchComplete() const;

private:
Stage *upstream_;
Expand Down
20 changes: 17 additions & 3 deletions src/libpisp/backend/tiling/stages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,22 @@ void BasicStage::CopyOut(void *dest, Dir dir)
if (struct_offset_ >= 0)
{
Region *region = (Region *)((uint8_t *)dest + struct_offset_);
region->input[dir] = input_interval_;
region->crop[dir] = crop_;
region->output[dir] = output_interval_;
if (!BranchComplete())
{
region->input[dir] = input_interval_;
region->crop[dir] = crop_;
region->output[dir] = output_interval_;
}
else
{
region->input[dir] = {};
region->crop[dir] = {};
region->output[dir] = {};
}
}
}

bool BasicStage::BranchComplete() const
{
return downstream_->BranchComplete();
}
2 changes: 2 additions & 0 deletions src/libpisp/backend/tiling/stages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Stage
virtual void PushEndUp(int output_end, Dir dir) = 0;
virtual void PushCropDown(Interval interval, Dir dir) = 0;
virtual void CopyOut(void *dest, Dir dir) = 0;
virtual bool BranchComplete() const = 0;
void MergeRegions(void *dest, void *x_src, void *y_src) const;

protected:
Expand All @@ -55,6 +56,7 @@ class BasicStage : public Stage
virtual void SetDownstream(Stage *downstream);
virtual void Reset();
virtual void CopyOut(void *dest, Dir dir);
virtual bool BranchComplete() const;

protected:
Stage *upstream_;
Expand Down

0 comments on commit ee4585c

Please sign in to comment.