
Commit

Merge pull request #977 from ZenithalHourlyRate:522-reduction-analysis-save-constant

PiperOrigin-RevId: 679202441
copybara-github committed Sep 26, 2024
2 parents 5f3e093 + e090d58 commit b79bccf
Showing 6 changed files with 363 additions and 27 deletions.
51 changes: 34 additions & 17 deletions lib/Analysis/RotationAnalysis/RotationAnalysis.cpp
@@ -138,27 +138,44 @@ void RotationAnalysis::run(Operation *op) {
         Value rhs = arithOp.getRhs();
         Value newRoot = arithOp.getResult();
         OperationName opName = arithOp.getOperation()->getName();
+        bool canJoin = false;
 
-        // TODO(#522): support these non-tensor-extract operands by
-        // saving the values, and applying them again to the final
-        // result.
-        if (!rootToPartialReductions.contains(lhs) ||
-            !rootToPartialReductions.contains(rhs)) {
-          return;
+        // Both lhs and rhs are in a reduction tree and may be joined.
+        if (rootToPartialReductions.contains(lhs) &&
+            rootToPartialReductions.contains(rhs)) {
+          // This is inefficient, but what can we do better here? I suspect a
+          // better approach may be to identify cases in which only one of
+          // these reductions needs to be kept because it's "the best"
+          // according to some metric (e.g., it monotonically increases the
+          // number of indices and all else stays the same). But for now even
+          // on the box_blur_64x64 example this is far from the bottleneck.
+          for (const auto &lhsReduction : rootToPartialReductions[lhs]) {
+            for (const auto &rhsReduction : rootToPartialReductions[rhs]) {
+              if (PartialReduction::canJoin(lhsReduction, rhsReduction,
+                                            opName)) {
+                canJoin = true;
+                addPartialReduction(PartialReduction::join(
+                    lhsReduction, rhsReduction, newRoot, opName));
+              }
+            }
+          }
         }
 
-        // This is inefficient, but what can we do better here? I suspect a
-        // better approach may be to identify cases in which only one of these
-        // reductions needs to be kept because it's "the best" according to
-        // some metric (e.g., it monotonically increases the number of indices
-        // and all else stays the same). But for now even on the
-        // box_blur_64x64 example this is far from the bottleneck.
-        for (const auto &lhsReduction : rootToPartialReductions[lhs]) {
-          for (const auto &rhsReduction : rootToPartialReductions[rhs]) {
-            if (PartialReduction::canJoin(lhsReduction, rhsReduction,
-                                          opName)) {
-              addPartialReduction(PartialReduction::join(
-                  lhsReduction, rhsReduction, newRoot, opName));
+        // If the two sides cannot be joined, try saving the non-reduction
+        // operand on one side.
+        if (!canJoin && rootToPartialReductions.contains(lhs)) {
+          for (const auto &lhsReduction : rootToPartialReductions[lhs]) {
+            if (PartialReduction::canSave(lhsReduction, rhs, opName)) {
+              addPartialReduction(
+                  PartialReduction::save(lhsReduction, rhs, newRoot, opName));
+            }
+          }
+        }
+
+        if (!canJoin && rootToPartialReductions.contains(rhs)) {
+          for (const auto &rhsReduction : rootToPartialReductions[rhs]) {
+            if (PartialReduction::canSave(rhsReduction, lhs, opName)) {
+              addPartialReduction(
+                  PartialReduction::save(rhsReduction, lhs, newRoot, opName));
             }
           }
         }
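To illustrate the case this change targets, consider a hypothetical reduction tree in which one addend is a plain scalar constant rather than another partial reduction (an illustrative sketch, not taken from the repository's tests; %t and the index values are assumed to be defined elsewhere):

    %c1 = arith.constant 1 : i16
    %e0 = tensor.extract %t[%i0] : tensor<32xi16>
    %e1 = tensor.extract %t[%i1] : tensor<32xi16>
    %p  = arith.addi %e0, %e1 : i16
    %r  = arith.addi %p, %c1 : i16

Previously the analysis gave up (the removed early return) as soon as one operand of an arith op was not itself the root of a partial reduction. With this change, the scalar %c1 is instead recorded as a saved value on the reduction rooted at %r and re-applied after the rotate-and-reduce rewrite.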
79 changes: 78 additions & 1 deletion lib/Analysis/RotationAnalysis/RotationAnalysis.h
@@ -57,13 +57,19 @@ class PartialReduction {
 
   Value getRoot() const { return root; }
 
+  const SmallVector<Value> &getSavedValues() const { return savedValues; }
+
   void print(raw_ostream &os) const {
     os << "{ opName: " << (opName.has_value() ? opName->getStringRef() : "None")
        << "; " << " tensor: " << tensor << "; " << "rotations: [";
     for (auto index : accessedIndices) {
       os << index << ", ";
     }
-    os << "]; root: " << root << "; }";
+    os << "]; root: " << root << "; savedValues: [";
+    for (auto value : savedValues) {
+      os << value << ", ";
+    }
+    os << "]; }";
   }

// Construct a "leaf" of a reduction, i.e., a PartialReduction that represents
@@ -89,6 +95,11 @@ class PartialReduction {
// like {1, 2, 3, ...} rather than {1, 1, 1, ...}
static PartialReduction rotate(const PartialReduction &lhs,
const int64_t shift, Value result) {
// Only a tensor can be rotated; a reduction that carries saved values cannot be.
assert(lhs.savedValues.empty() &&
"Internal state of RotationAnalysis is broken; tensor having saved "
"value should be impossible");

LLVM_DEBUG({
llvm::dbgs() << "Rotating\n\t";
lhs.print(llvm::dbgs());
@@ -178,6 +189,12 @@ class PartialReduction {
for (auto index : rhs.accessedIndices) {
merged.addRotation(index);
}
for (auto value : lhs.savedValues) {
merged.savedValues.push_back(value);
}
for (auto value : rhs.savedValues) {
merged.savedValues.push_back(value);
}
LLVM_DEBUG({
llvm::dbgs() << "Joining\n\t";
lhs.print(llvm::dbgs());
@@ -190,6 +207,58 @@ class PartialReduction {
return merged;
}

// Determine if a Value can be saved into the partial reduction lhs at an op
// whose OperationName is given.
static bool canSave(const PartialReduction &lhs, Value rhs,
OperationName opName) {
// If the lhs op is not set, then any op is legal.
if (lhs.opName.has_value() && *lhs.opName != opName) {
return false;
}
// Only saving a scalar value is supported.
// If the saved rhs were a tensor, it might get rotated alongside
// the reduction tree later.
//
// TODO(#522): a tensor could also be saved if no rotation follows;
// this could be implemented by a check in a canRotate method.
//
// Note that for the full-rotation case, the new PartialReduction
// created from the result tensor during analysis suffices.
if (mlir::isa<RankedTensorType>(rhs.getType())) {
return false;
}
return true;
}

// Save value within a partial reduction. This assumes the lhs and rhs have
// already been checked to have compatible opNames via canSave.
static PartialReduction save(const PartialReduction &lhs, Value rhs,
Value newRoot, OperationName opName) {
assert(!lhs.accessedIndices.empty() &&
"Internal state of RotationAnalysis is broken; empty rotation sets "
"should be impossible");

PartialReduction merged;
merged.tensor = lhs.tensor;
merged.root = newRoot;
merged.opName = opName;
for (auto index : lhs.accessedIndices) {
merged.addRotation(index);
}
merged.savedValues = lhs.savedValues;
merged.savedValues.push_back(rhs);
LLVM_DEBUG({
llvm::dbgs() << "Saving\n\t";
rhs.print(llvm::dbgs());
llvm::dbgs() << " inside\n\t";
lhs.print(llvm::dbgs());
llvm::dbgs() << " to get\n\t";
merged.print(llvm::dbgs());
llvm::dbgs() << "\n";
});
return merged;
}

private:
// The SSA value being reduced
Value tensor;
@@ -214,6 +283,14 @@ class PartialReduction {
// For now we use std::set which is implemented as a binary tree and ordered
// by the index values.
std::set<int64_t> accessedIndices;

// The list of constant Values encountered by the reduction so far.
//
// Constant Values in the reduction tree are saved and applied later to the
// final reduced result.
// A SmallVector is used instead of std::set because the same Value may be
// saved repeatedly.
SmallVector<Value> savedValues;
};

inline raw_ostream &operator<<(raw_ostream &os, const PartialReduction &v) {
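With the new savedValues field and getSavedValues accessor, a printed PartialReduction looks roughly like this (the angle-bracketed placeholders stand for the SSA values the analysis would actually print):

    { opName: arith.addi;  tensor: <tensor SSA value>; rotations: [0, 1, 2, ]; root: <root SSA value>; savedValues: [<saved scalar SSA value>, ]; }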
4 changes: 3 additions & 1 deletion lib/Dialect/Secret/IR/SecretPatterns.cpp
@@ -560,7 +560,9 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
     if (isa<YieldOp>(op)) {
       return false;
     }
-    // complex op
+    // Conservatively preserve a complex op with a nested region.
+    // This could be replaced with a recursive call to check that all of the
+    // regions' operations can be hoisted.
     if (op.getNumRegions() != 0) {
       return false;
     }
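For example, an op that owns a region, such as the hypothetical loop below, is conservatively left inside the enclosing secret.generic body rather than hoisted, even though its body only uses plaintext values (%a and %b are assumed plaintext):

    affine.for %i = 0 to 4 {
      %x = arith.addi %a, %b : i16  // plaintext-only body, but the loop op has a region and is kept
    }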
8 changes: 5 additions & 3 deletions lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -65,10 +65,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
       auto extractOp = b.create<tensor::ExtractOp>(
           finalOp->getResult(0),
           b.create<arith::ConstantIndexOp>(0).getResult());
-      op->replaceAllUsesWith(extractOp);
-    } else {
-      op->replaceAllUsesWith(finalOp);
+      finalOp = extractOp;
     }
+    for (auto value : reduction.getSavedValues()) {
+      finalOp = b.create<ArithOp>(finalOp->getResult(0), value);
+    }
+    op->replaceAllUsesWith(finalOp);
     LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
   }

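With this change, the saved values are re-applied to the reduced result after the rotation network is emitted. For a full 32-element sum whose tree also added a saved scalar %c1, the rewritten IR ends roughly as follows (a sketch; names and the earlier rotate-and-add steps are elided):

    %sum  = arith.addi %rotated, %partial : tensor<32xi16>  // last of the log2(32) = 5 rotate-and-add steps
    %zero = arith.constant 0 : index
    %s    = tensor.extract %sum[%zero] : tensor<32xi16>
    %res  = arith.addi %s, %c1 : i16                         // the saved scalar is applied last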
31 changes: 30 additions & 1 deletion tests/heir_simd_vectorizer/simple_sum.mlir
@@ -1,4 +1,6 @@
-// RUN: heir-opt --secretize=entry-function=simple_sum --wrap-generic --canonicalize --cse \
+// RUN: heir-opt --secretize=entry-function=simple_sum \
+// RUN:   --secretize=entry-function=simple_sum_nested \
+// RUN:   --wrap-generic --canonicalize --cse \
 // RUN:   --heir-simd-vectorizer %s | FileCheck %s

// Sum all entries of a tensor into a single scalar
@@ -16,3 +18,30 @@ func.func @simple_sum(%arg0: tensor<32xi16>) -> i16 {
}
return %0 : i16
}

// Sum all entries of 4 tensors into a single scalar
// CHECK-LABEL: @simple_sum_nested
// CHECK: secret.generic
// CHECK-COUNT-20: tensor_ext.rotate
// CHECK-NOT: tensor_ext.rotate
// CHECK: tensor.extract
// CHECK-NOT: tensor.extract
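// (20 rotations expected: each 32-element sum lowers to log2(32) = 5 rotations, and there are four such sums.)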
func.func @simple_sum_nested(%arg0: tensor<32xi16>, %arg1: tensor<32xi16>, %arg2: tensor<32xi16>, %arg3: tensor<32xi16>) -> i16 {
%c0_i16 = arith.constant 0 : i16
%expanded = tensor.expand_shape %arg0 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
%expanded_0 = tensor.expand_shape %arg1 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
%expanded_1 = tensor.expand_shape %arg2 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
%expanded_2 = tensor.expand_shape %arg3 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
%concat = tensor.concat dim(0) %expanded, %expanded_0, %expanded_1, %expanded_2 : (tensor<1x32xi16>, tensor<1x32xi16>, tensor<1x32xi16>, tensor<1x32xi16>) -> tensor<4x32xi16>
%0 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %c0_i16) -> (i16) {
%extracted_slice = tensor.extract_slice %concat[%arg4, 0] [1, 32] [1, 1] : tensor<4x32xi16> to tensor<32xi16>
%1 = affine.for %arg6 = 0 to 32 iter_args(%arg7 = %c0_i16) -> (i16) {
%extracted = tensor.extract %extracted_slice[%arg6] : tensor<32xi16>
%3 = arith.addi %extracted, %arg7 : i16
affine.yield %3 : i16
}
%2 = arith.addi %1, %arg5 : i16
affine.yield %2 : i16
}
return %0 : i16
}
