From e090d58df47e0e5243207cc3dff9af085fc5b4d3 Mon Sep 17 00:00:00 2001 From: Zenithal Date: Fri, 20 Sep 2024 08:07:23 +0000 Subject: [PATCH] Save non-tensor scalar value in rotation analysis and apply later --- .../RotationAnalysis/RotationAnalysis.cpp | 51 ++-- .../RotationAnalysis/RotationAnalysis.h | 79 ++++++- lib/Dialect/Secret/IR/SecretPatterns.cpp | 4 +- .../TensorExt/Transforms/RotateAndReduce.cpp | 8 +- tests/heir_simd_vectorizer/simple_sum.mlir | 31 ++- tests/tensor_ext/rotate_and_reduce.mlir | 217 +++++++++++++++++- 6 files changed, 363 insertions(+), 27 deletions(-) diff --git a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp index e91ea06ca..11ca4ddb4 100644 --- a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp +++ b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp @@ -138,27 +138,44 @@ void RotationAnalysis::run(Operation *op) { Value rhs = arithOp.getRhs(); Value newRoot = arithOp.getResult(); OperationName opName = arithOp.getOperation()->getName(); + bool canJoin = false; + + // Both lhs/rhs are in a reduction tree and can join + if (rootToPartialReductions.contains(lhs) && + rootToPartialReductions.contains(rhs)) { + // This is inefficient, but what can we do better here? I suspect a + // better approach may be to identify cases in which only one of + // these reductions needs to be kept because it's "the best" + // according to some metric (e.g., it monotonically increases the + // number of indices and all else stays the same). But for now even + // on the box_blur_64x64 example this is far from the bottleneck. + for (const auto &lhsReduction : rootToPartialReductions[lhs]) { + for (const auto &rhsReduction : rootToPartialReductions[rhs]) { + if (PartialReduction::canJoin(lhsReduction, rhsReduction, + opName)) { + canJoin = true; + addPartialReduction(PartialReduction::join( + lhsReduction, rhsReduction, newRoot, opName)); + } + } + } + } - // TODO(#522): support these non-tensor-extract operands by - // saving the values, and applying them again to the final - // result. - if (!rootToPartialReductions.contains(lhs) || - !rootToPartialReductions.contains(rhs)) { - return; + // If can not join, try saving in one side + if (!canJoin && rootToPartialReductions.contains(lhs)) { + for (const auto &lhsReduction : rootToPartialReductions[lhs]) { + if (PartialReduction::canSave(lhsReduction, rhs, opName)) { + addPartialReduction( + PartialReduction::save(lhsReduction, rhs, newRoot, opName)); + } + } } - // This is inefficient, but what can we do better here? I suspect a - // better approach may be to identify cases in which only one of these - // reductions needs to be kept because it's "the best" according to - // some metric (e.g., it monotonically increases the number of indices - // and all else stays the same). But for now even on the - // box_blur_64x64 example this is far from the bottleneck. 
-          for (const auto &lhsReduction : rootToPartialReductions[lhs]) {
+          if (!canJoin && rootToPartialReductions.contains(rhs)) {
             for (const auto &rhsReduction : rootToPartialReductions[rhs]) {
-              if (PartialReduction::canJoin(lhsReduction, rhsReduction,
-                                            opName)) {
-                addPartialReduction(PartialReduction::join(
-                    lhsReduction, rhsReduction, newRoot, opName));
+              if (PartialReduction::canSave(rhsReduction, lhs, opName)) {
+                addPartialReduction(
+                    PartialReduction::save(rhsReduction, lhs, newRoot, opName));
               }
             }
           }
diff --git a/lib/Analysis/RotationAnalysis/RotationAnalysis.h b/lib/Analysis/RotationAnalysis/RotationAnalysis.h
index 5d861ca7c..fdddc06a4 100644
--- a/lib/Analysis/RotationAnalysis/RotationAnalysis.h
+++ b/lib/Analysis/RotationAnalysis/RotationAnalysis.h
@@ -57,13 +57,19 @@ class PartialReduction {
 
   Value getRoot() const { return root; }
 
+  const SmallVector<Value> &getSavedValues() const { return savedValues; }
+
   void print(raw_ostream &os) const {
     os << "{ opName: "
        << (opName.has_value() ? opName->getStringRef() : "None") << "; "
        << " tensor: " << tensor << "; " << "rotations: [";
     for (auto index : accessedIndices) {
       os << index << ", ";
     }
-    os << "]; root: " << root << "; }";
+    os << "]; root: " << root << "; savedValues: [";
+    for (auto value : savedValues) {
+      os << value << ", ";
+    }
+    os << "]; }";
   }
 
   // Construct a "leaf" of a reduction, i.e., a PartialReduction that represents
@@ -89,6 +95,11 @@ class PartialReduction {
   // like {1, 2, 3, ...} rather than {1, 1, 1, ...}
   static PartialReduction rotate(const PartialReduction &lhs,
                                  const int64_t shift, Value result) {
+    // Only a tensor (with no saved values) can be rotated.
+    assert(lhs.savedValues.empty() &&
+           "Internal state of RotationAnalysis is broken; a tensor having "
+           "saved values should be impossible");
+
     LLVM_DEBUG({
       llvm::dbgs() << "Rotating\n\t";
       lhs.print(llvm::dbgs());
@@ -178,6 +189,12 @@ class PartialReduction {
     for (auto index : rhs.accessedIndices) {
       merged.addRotation(index);
     }
+    for (auto value : lhs.savedValues) {
+      merged.savedValues.push_back(value);
+    }
+    for (auto value : rhs.savedValues) {
+      merged.savedValues.push_back(value);
+    }
     LLVM_DEBUG({
       llvm::dbgs() << "Joining\n\t";
       lhs.print(llvm::dbgs());
@@ -190,6 +207,58 @@ class PartialReduction {
     return merged;
   }
 
+  // Determine if a Value is legal to save at an op whose
+  // OperationName is given.
+  static bool canSave(const PartialReduction &lhs, Value rhs,
+                      OperationName opName) {
+    // If the lhs op is not set, then any op is legal.
+    if (lhs.opName.has_value() && *lhs.opName != opName) {
+      return false;
+    }
+    // Only support saving scalar values.
+    // If the saved rhs is a tensor, it might get rotated alongside
+    // the reduction tree later.
+    //
+    // TODO(#522): if there is no rotation later, then a tensor can be saved.
+    // This can be implemented via a check in a canRotate method.
+    //
+    // Note that for the full rotation case, the new PartialReduction
+    // created from the result tensor in the analysis would suffice.
+    if (mlir::isa<RankedTensorType>(rhs.getType())) {
+      return false;
+    }
+    return true;
+  }
+
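To make the join/save flow above concrete, here is a small illustrative sketch modeled on the tests added later in this patch (it is not itself part of the change). Assume %0 and %1 are tensor.extract results from %arg0, so %8 roots a partial reduction over %arg0 with accessed indices {0, 1}:

  %8 = arith.addi %0, %1 : i32
  // %c2_i32 has no partial reduction of its own, so no join is possible.
  // canSave succeeds (scalar operand, same op kind), and %9 is recorded as a
  // reduction over %arg0 with indices {0, 1} and savedValues = [%c2_i32].
  %9 = arith.addi %8, %c2_i32 : i32
  // canSave fails here: the op kind (muli) does not match the reduction's
  // recorded opName (addi), so %10 gets no partial reduction
  // (see @not_supported_mixed_op_non_tensor_operands below).
  %10 = arith.muli %9, %c2_i32 : i32

A tensor-typed operand is likewise rejected by canSave, since it might later need to be rotated along with the reduction tree (TODO(#522); see @not_supported_save_not_rotated_tensor below).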
+  // Save a value within a partial reduction. This assumes the lhs and rhs
+  // have already been checked to have compatible opNames via canSave.
+  static PartialReduction save(const PartialReduction &lhs, Value rhs,
+                               Value newRoot, OperationName opName) {
+    assert(!lhs.accessedIndices.empty() &&
+           "Internal state of RotationAnalysis is broken; empty rotation sets "
+           "should be impossible");
+
+    PartialReduction merged;
+    merged.tensor = lhs.tensor;
+    merged.root = newRoot;
+    merged.opName = opName;
+    for (auto index : lhs.accessedIndices) {
+      merged.addRotation(index);
+    }
+    merged.savedValues = lhs.savedValues;
+    merged.savedValues.push_back(rhs);
+    LLVM_DEBUG({
+      llvm::dbgs() << "Saving\n\t";
+      rhs.print(llvm::dbgs());
+      llvm::dbgs() << " inside\n\t";
+      lhs.print(llvm::dbgs());
+      llvm::dbgs() << " to get\n\t";
+      merged.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
+    return merged;
+  }
+
  private:
   // The SSA value being reduced
   Value tensor;
@@ -214,6 +283,14 @@ class PartialReduction {
   // For now we use std::set which is implemented as a binary tree and ordered
   // by the index values.
   std::set<int64_t> accessedIndices;
+
+  // The list of non-tensor (scalar) Values encountered by the reduction so
+  // far.
+  //
+  // Such Values in the reduction tree are saved and applied later to the
+  // final reduced result.
+  // Use SmallVector instead of std::set, as the same Value might be saved
+  // repeatedly.
+  SmallVector<Value> savedValues;
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, const PartialReduction &v) {
diff --git a/lib/Dialect/Secret/IR/SecretPatterns.cpp b/lib/Dialect/Secret/IR/SecretPatterns.cpp
index 4c92dfac9..97d157662 100644
--- a/lib/Dialect/Secret/IR/SecretPatterns.cpp
+++ b/lib/Dialect/Secret/IR/SecretPatterns.cpp
@@ -560,7 +560,9 @@ LogicalResult HoistPlaintextOps::matchAndRewrite(
     if (isa(op)) {
       return false;
     }
-    // complex op
+    // Conservatively preserve a complex op with a nested region.
+    // This could be replaced with a recursive call to check that all of the
+    // regions' operations can be hoisted.
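The comment above points at a possible follow-up; a hypothetical sketch of such a recursive check is shown below (illustrative only, not part of this patch; canHoistSingleOp is a made-up name standing in for the per-op checks already performed in this lambda, minus the region bail-out):

  bool canHoistRecursively(Operation &op) {
    // An op is hoistable only if every op nested in its regions is hoistable.
    for (Region &region : op.getRegions())
      for (Operation &nested : region.getOps())
        if (!canHoistRecursively(nested))
          return false;
    return canHoistSingleOp(op);  // hypothetical: the existing single-op checks
  }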
     if (op.getNumRegions() != 0) {
       return false;
     }
diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
index 17e5ac0ca..2ae9541a2 100644
--- a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
+++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -65,10 +65,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
       auto extractOp = b.create<tensor::ExtractOp>(
           finalOp->getResult(0),
           b.create<arith::ConstantIndexOp>(0).getResult());
-      op->replaceAllUsesWith(extractOp);
-    } else {
-      op->replaceAllUsesWith(finalOp);
+      finalOp = extractOp;
     }
+    for (auto value : reduction.getSavedValues()) {
+      finalOp = b.create<ArithOp>(finalOp->getResult(0), value);
+    }
+    op->replaceAllUsesWith(finalOp);
 
     LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
   }
diff --git a/tests/heir_simd_vectorizer/simple_sum.mlir b/tests/heir_simd_vectorizer/simple_sum.mlir
index 7596a87ee..59a432df1 100644
--- a/tests/heir_simd_vectorizer/simple_sum.mlir
+++ b/tests/heir_simd_vectorizer/simple_sum.mlir
@@ -1,4 +1,6 @@
-// RUN: heir-opt --secretize=entry-function=simple_sum --wrap-generic --canonicalize --cse \
+// RUN: heir-opt --secretize=entry-function=simple_sum \
+// RUN:   --secretize=entry-function=simple_sum_nested \
+// RUN:   --wrap-generic --canonicalize --cse \
 // RUN:   --heir-simd-vectorizer %s | FileCheck %s
 
 // Sum all entries of a tensor into a single scalar
@@ -16,3 +18,30 @@ func.func @simple_sum(%arg0: tensor<32xi16>) -> i16 {
   }
   return %0 : i16
 }
+
+// Sum all entries of 4 tensors into a single scalar
+// CHECK-LABEL: @simple_sum_nested
+// CHECK: secret.generic
+// CHECK-COUNT-20: tensor_ext.rotate
+// CHECK-NOT: tensor_ext.rotate
+// CHECK: tensor.extract
+// CHECK-NOT: tensor.extract
+func.func @simple_sum_nested(%arg0: tensor<32xi16>, %arg1: tensor<32xi16>, %arg2: tensor<32xi16>, %arg3: tensor<32xi16>) -> i16 {
+  %c0_i16 = arith.constant 0 : i16
+  %expanded = tensor.expand_shape %arg0 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
+  %expanded_0 = tensor.expand_shape %arg1 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
+  %expanded_1 = tensor.expand_shape %arg2 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
+  %expanded_2 = tensor.expand_shape %arg3 [[0, 1]] output_shape [1, 32] : tensor<32xi16> into tensor<1x32xi16>
+  %concat = tensor.concat dim(0) %expanded, %expanded_0, %expanded_1, %expanded_2 : (tensor<1x32xi16>, tensor<1x32xi16>, tensor<1x32xi16>, tensor<1x32xi16>) -> tensor<4x32xi16>
+  %0 = affine.for %arg4 = 0 to 4 iter_args(%arg5 = %c0_i16) -> (i16) {
+    %extracted_slice = tensor.extract_slice %concat[%arg4, 0] [1, 32] [1, 1] : tensor<4x32xi16> to tensor<32xi16>
+    %1 = affine.for %arg6 = 0 to 32 iter_args(%arg7 = %c0_i16) -> (i16) {
+      %extracted = tensor.extract %extracted_slice[%arg6] : tensor<32xi16>
+      %3 = arith.addi %extracted, %arg7 : i16
+      affine.yield %3 : i16
+    }
+    %2 = arith.addi %1, %arg5 : i16
+    affine.yield %2 : i16
+  }
+  return %0 : i16
+}
diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir
index ea4a911ca..f8bf41378 100644
--- a/tests/tensor_ext/rotate_and_reduce.mlir
+++ b/tests/tensor_ext/rotate_and_reduce.mlir
@@ -43,6 +43,35 @@ func.func @simple_sum(%arg0: tensor<8xi32>) -> i32 {
   return %14 : i32
 }
 
+// Sum all entries of two tensors into a single scalar
+// CHECK-LABEL: @simple_sum_two_tensor
+// CHECK-COUNT-2: tensor_ext.rotate
+// CHECK: tensor.extract
+// CHECK-COUNT-2: tensor_ext.rotate
+// CHECK: tensor.extract
+func.func
@simple_sum_two_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %0 = tensor.extract %arg0[%c0] : tensor<4xi32> + %1 = tensor.extract %arg0[%c1] : tensor<4xi32> + %2 = tensor.extract %arg0[%c2] : tensor<4xi32> + %3 = tensor.extract %arg0[%c3] : tensor<4xi32> + %4 = tensor.extract %arg1[%c0] : tensor<4xi32> + %5 = tensor.extract %arg1[%c1] : tensor<4xi32> + %6 = tensor.extract %arg1[%c2] : tensor<4xi32> + %7 = tensor.extract %arg1[%c3] : tensor<4xi32> + %8 = arith.addi %0, %1 : i32 + %9 = arith.addi %8, %2 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + return %14 : i32 +} + // Sum all entries of a tensor // CHECK-LABEL: @simple_sum_mixed_rotation_tensor // CHECK-SAME: (%[[arg0:.*]]: tensor<8xi32> @@ -387,10 +416,23 @@ func.func @not_supported_non_constant_index_access(%arg0: tensor<8xi32>, %arg1: return %14 : i32 } -// CHECK-LABEL: @not_supported_non_tensor_operands -// CHECK-NOT: tensor_ext.rotate -// TODO(#522): support this -func.func @not_supported_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 { +// CHECK-LABEL: @simple_sum_non_tensor_operands +// CHECK-SAME: (%[[arg0:.*]]: tensor<8xi32> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 +// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 +// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 +// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 +// CHECK-NEXT: %[[c2_i32:.*]] = arith.constant 2 +// CHECK-NEXT: %[[v0:.*]] = tensor_ext.rotate %[[arg0]], %[[c4]] +// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[arg0]], %[[v0]] +// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[v1]], %[[c2]] +// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]] +// CHECK-NEXT: %[[v4:.*]] = tensor_ext.rotate %[[v3]], %[[c1]] +// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] +// CHECK-NEXT: %[[v6:.*]] = tensor.extract %[[v5]][%[[c0]]] +// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v6]], %[[c2_i32]] +// CHECK-NEXT: return %[[v7]] +func.func @simple_sum_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index @@ -420,6 +462,173 @@ func.func @not_supported_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 { return %15 : i32 } +// CHECK-LABEL: @simple_sum_multiple_non_tensor_operands +// CHECK-SAME: (%[[arg0:.*]]: tensor<8xi32>, %[[arg1:.*]]: i32 +// CHECK-NEXT: %[[c22_i32:.*]] = arith.constant 22 +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 +// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 +// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 +// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 +// CHECK-NEXT: %[[v0:.*]] = tensor_ext.rotate %[[arg0]], %[[c4]] +// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[arg0]], %[[v0]] +// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[v1]], %[[c2]] +// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]] +// CHECK-NEXT: %[[v4:.*]] = tensor_ext.rotate %[[v3]], %[[c1]] +// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] +// CHECK-NEXT: %[[v6:.*]] = tensor.extract %[[v5]][%[[c0]]] +// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v6]], %[[c22_i32]] +// CHECK-NEXT: %[[v8:.*]] = arith.addi %[[v7]], %[[arg1]] +// CHECK-NEXT: return %[[v8]] +func.func @simple_sum_multiple_non_tensor_operands(%arg0: tensor<8xi32>, %arg1: i32) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + 
%c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 5 : i32 + %c6_i32 = arith.constant 6 : i32 + %c7_i32 = arith.addi %c2_i32, %c5_i32 : i32 + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + // next two ops use the same non-tensor operand + %9 = arith.addi %8, %c2_i32 : i32 + %10 = arith.addi %9, %c2_i32 : i32 + // next op uses another non-tensor operand + %11 = arith.addi %10, %c5_i32 : i32 + // lhs to rhs + %12 = arith.addi %c6_i32, %11 : i32 + // next op uses computed constant + %13 = arith.addi %c7_i32, %12 : i32 + // next op uses arg + %14 = arith.addi %arg1, %13 : i32 + %15 = arith.addi %14, %3 : i32 + %16 = arith.addi %15, %4 : i32 + %17 = arith.addi %16, %5 : i32 + %18 = arith.addi %17, %6 : i32 + %19 = arith.addi %18, %7 : i32 + %20 = arith.addi %19, %2 : i32 + return %20 : i32 +} + +// CHECK-LABEL: @simple_sum_operand_from_another_tensor +// CHECK-SAME: (%[[arg0:.*]]: tensor<8xi32>, %[[arg1:.*]]: tensor<8xi32> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 +// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 +// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 +// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 +// CHECK-NEXT: %[[a1:.*]] = tensor.extract %[[arg1]][%[[c0]]] +// CHECK-NEXT: %[[v0:.*]] = tensor_ext.rotate %[[arg0]], %[[c4]] +// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[arg0]], %[[v0]] +// CHECK-NEXT: %[[v2:.*]] = tensor_ext.rotate %[[v1]], %[[c2]] +// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[v1]], %[[v2]] +// CHECK-NEXT: %[[v4:.*]] = tensor_ext.rotate %[[v3]], %[[c1]] +// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] +// CHECK-NEXT: %[[v6:.*]] = tensor.extract %[[v5]][%[[c0]]] +// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v6]], %[[a1]] +// CHECK-NEXT: return %[[v7]] +func.func @simple_sum_operand_from_another_tensor(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %another = tensor.extract %arg1[%c0] : tensor<8xi32> + %8 = arith.addi %0, %another : i32 + %9 = arith.addi %8, %1 : i32 + %10 = arith.addi %9, %2 : i32 + %11 = arith.addi %10, %3 : i32 + %12 = arith.addi %11, %4 : i32 + %13 = arith.addi %12, %5 : i32 + %14 = arith.addi %13, %6 : i32 + %15 = arith.addi %14, %7 : i32 + return %15 : i32 +} + +// one tensor reduced and add another tensor +// TODO(#522): arg1 could be saved and applied later +// CHECK-LABEL: @not_supported_save_not_rotated_tensor +// CHECK-COUNT-7: tensor_ext.rotate +func.func 
@not_supported_save_not_rotated_tensor(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %0 = tensor_ext.rotate %arg0, %c1 : tensor<8xi32>, index + %1 = tensor_ext.rotate %arg0, %c2 : tensor<8xi32>, index + %2 = tensor_ext.rotate %arg0, %c3 : tensor<8xi32>, index + %3 = tensor_ext.rotate %arg0, %c4 : tensor<8xi32>, index + %4 = tensor_ext.rotate %arg0, %c5 : tensor<8xi32>, index + %5 = tensor_ext.rotate %arg0, %c6 : tensor<8xi32>, index + %6 = tensor_ext.rotate %arg0, %c7 : tensor<8xi32>, index + %7 = arith.addi %arg1, %0 : tensor<8xi32> + %8 = arith.addi %7, %1 : tensor<8xi32> + %9 = arith.addi %8, %2 : tensor<8xi32> + %10 = arith.addi %9, %3 : tensor<8xi32> + %11 = arith.addi %10, %4 : tensor<8xi32> + %12 = arith.addi %11, %5 : tensor<8xi32> + %13 = arith.addi %12, %6 : tensor<8xi32> + %14 = arith.addi %13, %arg0 : tensor<8xi32> + return %14 : tensor<8xi32> +} + +// CHECK-LABEL: @not_supported_mixed_op_non_tensor_operands +// CHECK-NOT: tensor_ext.rotate +func.func @not_supported_mixed_op_non_tensor_operands(%arg0: tensor<8xi32>) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c2_i32 = arith.constant 2 : i32 + %0 = tensor.extract %arg0[%c0] : tensor<8xi32> + %1 = tensor.extract %arg0[%c1] : tensor<8xi32> + %2 = tensor.extract %arg0[%c2] : tensor<8xi32> + %3 = tensor.extract %arg0[%c3] : tensor<8xi32> + %4 = tensor.extract %arg0[%c4] : tensor<8xi32> + %5 = tensor.extract %arg0[%c5] : tensor<8xi32> + %6 = tensor.extract %arg0[%c6] : tensor<8xi32> + %7 = tensor.extract %arg0[%c7] : tensor<8xi32> + %8 = arith.addi %0, %1 : i32 + // next not supported op uses non-tensor operand + %9 = arith.muli %8, %c2_i32 : i32 + %10 = arith.addi %9, %3 : i32 + %11 = arith.addi %10, %4 : i32 + %12 = arith.addi %11, %5 : i32 + %13 = arith.addi %12, %6 : i32 + %14 = arith.addi %13, %7 : i32 + %15 = arith.addi %14, %2 : i32 + return %15 : i32 +} + // CHECK-LABEL: @sum_of_linear_rotates // CHECK-COUNT-5: tensor_ext.rotate // CHECK-NOT: tensor_ext.rotate