Skip to content

Commit

Permalink
secret: fix CollapseSecretlessGeneric pattern for tensor semantics
Browse files Browse the repository at this point in the history
CollapseSecretlessGeneric ensures that memref allocation ops are not collapsed - this also adds in empty tensor allocations for IRs that are not using pure buffer semantics (as is the case in the RLWE pipelines)

Part of #954

PiperOrigin-RevId: 672544871
  • Loading branch information
asraa authored and copybara-github committed Sep 9, 2024
1 parent 61487a2 commit 64ad14f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/Dialect/Secret/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ cc_library(
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)

Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/Secret/IR/SecretPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
Expand Down Expand Up @@ -98,7 +99,8 @@ LogicalResult CollapseSecretlessGeneric::matchAndRewrite(
//
// There is no good way to identify an allocation op in general. Maybe we can
// upstream a trait for this?
for ([[maybe_unused]] const auto op : op.getOps<memref::AllocOp>()) {
if (!op.getOps<tensor::EmptyOp>().empty() ||
!op.getOps<memref::AllocOp>().empty()) {
return failure();
}

Expand Down
59 changes: 59 additions & 0 deletions tests/secret/canonicalize_tensor.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: heir-opt --canonicalize %s | FileCheck %s

// Regression test for issue #954 to ensure that generics wrapping tensor.empty
// operations do not collapse.

module {
// CHECK-LABEL: @tensor_empty
func.func @tensor_empty() -> !secret.secret<tensor<1x10xf32>> {
// CHECK-NEXT: %[[v1:.*]] = secret.generic
// CHECK-NEXT: %[[v0:.*]] = tensor.empty() : tensor<1x10xf32>
// CHECK-NEXT: secret.yield %[[v0]] : tensor<1x10xf32>
// CHECK: return %[[v1]] : !secret.secret<tensor<1x10xf32>>
%0 = secret.generic {
%3 = tensor.empty() : tensor<1x10xf32>
secret.yield %3 : tensor<1x10xf32>
} -> !secret.secret<tensor<1x10xf32>>
return %0 : !secret.secret<tensor<1x10xf32>>
}

// CHECK-LABEL: @main
func.func @main(%arg0: !secret.secret<tensor<28x28xf32>>, %arg1: !secret.secret<tensor<784x10xf32>>, %arg2: !secret.secret<tensor<1x10xf32>>) -> !secret.secret<tensor<1x10xf32>> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = secret.generic ins(%arg0, %arg1, %arg2 : !secret.secret<tensor<28x28xf32>>, !secret.secret<tensor<784x10xf32>>, !secret.secret<tensor<1x10xf32>>) {
^bb0(%arg3: tensor<28x28xf32>, %arg4: tensor<784x10xf32>, %arg5: tensor<1x10xf32>):
%1 = tosa.reshape %arg3 {new_shape = array<i64: 1, 1, 784>} : (tensor<28x28xf32>) -> tensor<1x1x784xf32>
%2 = tosa.reshape %arg4 {new_shape = array<i64: 1, 784, 10>} : (tensor<784x10xf32>) -> tensor<1x784x10xf32>
%3 = tensor.empty() : tensor<1x1x10xf32>
%4 = affine.for %arg6 = 0 to 10 iter_args(%arg7 = %3) -> (tensor<1x1x10xf32>) {
%inserted = tensor.insert %cst into %arg7[%c0, %c0, %arg6] : tensor<1x1x10xf32>
affine.yield %inserted : tensor<1x1x10xf32>
}
%5 = affine.for %arg6 = 0 to 10 iter_args(%arg7 = %4) -> (tensor<1x1x10xf32>) {
%9 = affine.for %arg8 = 0 to 784 iter_args(%arg9 = %arg7) -> (tensor<1x1x10xf32>) {
%extracted = tensor.extract %1[%c0, %c0, %arg8] : tensor<1x1x784xf32>
%extracted_0 = tensor.extract %2[%c0, %arg8, %arg6] : tensor<1x784x10xf32>
%extracted_1 = tensor.extract %4[%c0, %c0, %arg6] : tensor<1x1x10xf32>
%10 = arith.mulf %extracted, %extracted_0 : f32
%11 = arith.addf %extracted_1, %10 : f32
%inserted = tensor.insert %11 into %arg9[%c0, %c0, %arg6] : tensor<1x1x10xf32>
affine.yield %inserted : tensor<1x1x10xf32>
}
affine.yield %9 : tensor<1x1x10xf32>
}
%6 = tosa.reshape %5 {new_shape = array<i64: 1, 10>} : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
%7 = tensor.empty() : tensor<1x10xf32>
%8 = affine.for %arg6 = 0 to 10 iter_args(%arg7 = %7) -> (tensor<1x10xf32>) {
%extracted = tensor.extract %6[%c0, %arg6] : tensor<1x10xf32>
%extracted_0 = tensor.extract %arg5[%c0, %arg6] : tensor<1x10xf32>
%9 = arith.addf %extracted, %extracted_0 : f32
%10 = arith.maximumf %9, %cst : f32
%inserted = tensor.insert %10 into %arg7[%c0, %arg6] : tensor<1x10xf32>
affine.yield %inserted : tensor<1x10xf32>
}
secret.yield %8 : tensor<1x10xf32>
} -> !secret.secret<tensor<1x10xf32>>
return %0 : !secret.secret<tensor<1x10xf32>>
}
}

0 comments on commit 64ad14f

Please sign in to comment.