Skip to content

Commit

Permalink
Merge pull request #1058 from ahmedshakill:secret_to_cggi
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687473368
  • Loading branch information
copybara-github committed Oct 19, 2024
2 parents ec41b1f + 88740c0 commit 87412a4
Show file tree
Hide file tree
Showing 21 changed files with 119 additions and 99 deletions.
4 changes: 2 additions & 2 deletions docs/content/en/docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ input files. You can also access the underlying binary at

```bash
bazel run //tools:heir-opt -- \
--comb-to-cggi -cse \
$PWD/tests/comb_to_cggi/add_one.mlir
--secret-to-cggi -cse \
$PWD/tests/Dialect/Secret/Conversions/secret_to_cggi/add_one.mlir
```

To convert an existing lit test to a `bazel run` command for manual tweaking and
Expand Down
27 changes: 0 additions & 27 deletions lib/Dialect/Comb/Conversions/CombToCGGI/CMakeLists.txt

This file was deleted.

16 changes: 0 additions & 16 deletions lib/Dialect/Comb/Conversions/CombToCGGI/CombToCGGI.h

This file was deleted.

20 changes: 0 additions & 20 deletions lib/Dialect/Comb/Conversions/CombToCGGI/CombToCGGI.td

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ package(
)

cc_library(
name = "CombToCGGI",
srcs = ["CombToCGGI.cpp"],
name = "SecretToCGGI",
srcs = ["SecretToCGGI.cpp"],
hdrs = [
"CombToCGGI.h",
"SecretToCGGI.h",
],
deps = [
":pass_inc_gen",
Expand Down Expand Up @@ -38,17 +38,17 @@ gentbl_cc_library(
(
[
"-gen-pass-decls",
"-name=CombToCGGI",
"-name=SecretToCGGI",
],
"CombToCGGI.h.inc",
"SecretToCGGI.h.inc",
),
(
["-gen-pass-doc"],
"CombToCGGI.md",
"SecretToCGGI.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "CombToCGGI.td",
td_file = "SecretToCGGI.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
add_heir_pass(SecretToCGGI)

add_mlir_dialect_library(HEIRSecretToCGGI
SecretToCGGI.cpp

ADDITIONAL_HEADER_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/IR

DEPENDS
HEIRSecretToCGGIIncGen

LINK_LIBS PUBLIC

MLIRHEIRUtils
MLIRCGGI
MLIRComb
MLIRLWE
MLIRSecret

LLVMSupport
MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRSupport
MLIRTransformUtils
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "lib/Dialect/Comb/Conversions/CombToCGGI/CombToCGGI.h"
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h"

#include <cassert>
#include <cstdint>
Expand Down Expand Up @@ -39,10 +39,10 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir::comb {
namespace mlir::heir {

#define GEN_PASS_DEF_COMBTOCGGI
#include "lib/Dialect/Comb/Conversions/CombToCGGI/CombToCGGI.h.inc"
#define GEN_PASS_DEF_SECRETTOCGGI
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h.inc"

namespace {

Expand Down Expand Up @@ -414,14 +414,14 @@ class SecretGenericOpMemRefStoreConversion
};

// ConvertTruthTableOp converts truth table ops with fully plaintext values.
struct ConvertTruthTableOp : public OpConversionPattern<TruthTableOp> {
struct ConvertTruthTableOp : public OpConversionPattern<comb::TruthTableOp> {
ConvertTruthTableOp(mlir::MLIRContext *context)
: OpConversionPattern<TruthTableOp>(context) {}
: OpConversionPattern<comb::TruthTableOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
TruthTableOp op, OpAdaptor adaptor,
comb::TruthTableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getNumOperands() != 3) {
op->emitError() << "expected 3 truth table arguments to lower to CGGI";
Expand Down Expand Up @@ -538,7 +538,7 @@ struct ConvertSecretCastOp : public OpConversionPattern<secret::CastOp> {
int findLUTSize(MLIRContext *context, Operation *module) {
int max_int_size = 0;
auto processOperation = [&](Operation *op) {
if (isa<CombDialect>(op->getDialect())) {
if (isa<comb::CombDialect>(op->getDialect())) {
int current_size = 0;
if (dyn_cast<comb::TruthTableOp>(op))
current_size = 3;
Expand All @@ -555,7 +555,7 @@ int findLUTSize(MLIRContext *context, Operation *module) {
return max_int_size;
}

struct CombToCGGI : public impl::CombToCGGIBase<CombToCGGI> {
struct SecretToCGGI : public impl::SecretToCGGIBase<SecretToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto *module = getOperation();
Expand All @@ -581,7 +581,8 @@ struct CombToCGGI : public impl::CombToCGGIBase<CombToCGGI> {
SecretGenericOpOrConversion, SecretGenericOpXNorConversion,
SecretGenericOpXorConversion, ConvertSecretCastOp>(typeConverter,
context);
target.addIllegalOp<TruthTableOp, secret::CastOp, secret::GenericOp>();
target
.addIllegalOp<comb::TruthTableOp, secret::CastOp, secret::GenericOp>();
target.addDynamicallyLegalOp<memref::StoreOp>([&](memref::StoreOp op) {
// Legal only when the memref element type matches the stored
// type.
Expand Down Expand Up @@ -614,4 +615,4 @@ struct CombToCGGI : public impl::CombToCGGIBase<CombToCGGI> {
}
};

} // namespace mlir::heir::comb
} // namespace mlir::heir
16 changes: 16 additions & 0 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef LIB_DIALECT_SECRET_CONVERSIONS_SECRETTOCGGI_SECRETTOCGGI_H_
#define LIB_DIALECT_SECRET_CONVERSIONS_SECRETTOCGGI_SECRETTOCGGI_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir {

#define GEN_PASS_DECL
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h.inc"

} // namespace mlir::heir

#endif // LIB_DIALECT_SECRET_CONVERSIONS_SECRETTOCGGI_SECRETTOCGGI_H_
20 changes: 20 additions & 0 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef LIB_DIALECT_SECRET_CONVERSIONS_SECRETTOCGGI_SECRETTOCGGI_TD_
#define LIB_DIALECT_SECRET_CONVERSIONS_SECRETTOCGGI_SECRETTOCGGI_TD_

include "mlir/Pass/PassBase.td"

def SecretToCGGI : Pass<"secret-to-cggi"> {
let summary = "Lower `secret` to `cggi` dialect.";

let description = [{
This pass lowers the `secret` dialect to `cggi` dialect.
}];

let dependentDialects = [
"mlir::heir::comb::CombDialect",
"mlir::heir::cggi::CGGIDialect",
"mlir::memref::MemRefDialect",
];
}

#endif // LIB_DIALECT_SECRET_CONVERSIONS_SECRETTOCGGI_SECRETTOCGGI_TD_
18 changes: 17 additions & 1 deletion lib/Utils/ConversionUtils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,20 @@ add_mlir_library(HEIRConversionUtils
MLIRDialect
)

target_link_libraries(HEIRUtils INTERFACE HEIRConversionUtils)
add_subdirectory(ModArithToArith)
add_subdirectory(BGVToOpenfhe)
add_subdirectory(BGVToLWE)
add_subdirectory(CGGIToJaxite)
add_subdirectory(CGGIToTfheRust)
add_subdirectory(CGGIToTfheRustBool)
add_subdirectory(SecretToCGGI)
add_subdirectory(MemrefToArith)
add_subdirectory(PolynomialToStandard)
add_subdirectory(SecretToBGV)
add_subdirectory(SecretToCKKS)
add_subdirectory(LWEToPolynomial)
add_subdirectory(LinalgToTensorExt)
add_subdirectory(TosaToSecretArith)
add_subdirectory(LWEToOpenfhe)
add_subdirectory(CKKSToOpenfhe)
add_subdirectory(RlweToOpenfhe)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secret-distribute-generic --comb-to-cggi -cse %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --secret-to-cggi -cse %s | FileCheck %s

// This test was produced by running
// heir-opt --yosys-optimizer --canonicalize tests/yosys_optimizer/add_one.mlir
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secret-distribute-generic --comb-to-cggi -cse %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --secret-to-cggi -cse %s | FileCheck %s

// This test was produced by running
// heir-opt --yosys-optimizer="mode=Boolean" --canonicalize tests/yosys_optimizer/add_one.mlir
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secret-distribute-generic --split-input-file --comb-to-cggi --cse %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --split-input-file --secret-to-cggi --cse %s | FileCheck %s

// CHECK-NOT: secret
// CHECK: @boolean_gates([[ARG:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]]) -> [[LWET]]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secret-distribute-generic --comb-to-cggi %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --secret-to-cggi %s | FileCheck %s

// CHECK: module
module attributes {tf_saved_model.semantics} {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This test ensures that secret casting before and after generics lowers to CGGI properly.

// RUN: heir-opt --secret-distribute-generic --comb-to-cggi -cse --split-input-file %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --secret-to-cggi -cse --split-input-file %s | FileCheck %s

// CHECK: module
module attributes {tf_saved_model.semantics} {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secret-distribute-generic --comb-to-cggi %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --secret-to-cggi %s | FileCheck %s

module {
// CHECK-NOT: secret
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secret-distribute-generic --split-input-file --comb-to-cggi --cse %s | FileCheck %s
// RUN: heir-opt --secret-distribute-generic --split-input-file --secret-to-cggi --cse %s | FileCheck %s

// CHECK-NOT: secret
// CHECK: @truth_table_all_secret([[ARG:%.*]]: [[LWET:!lwe.lwe_ciphertext<.* = 3>>]]) -> [[LWET:!lwe.lwe_ciphertext<.* = 3>>]]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --comb-to-cggi %s | FileCheck %s
// RUN: heir-opt --secret-to-cggi %s | FileCheck %s

module {
// CHECK: func.func @types([[ARG:%.*]]: [[T:!lwe.lwe_ciphertext<.*>]]) -> [[T]]
Expand Down
2 changes: 1 addition & 1 deletion tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ cc_binary(
"@heir//lib/Dialect/CGGI/Transforms",
"@heir//lib/Dialect/CKKS/Conversions/CKKSToOpenfhe",
"@heir//lib/Dialect/CKKS/IR:Dialect",
"@heir//lib/Dialect/Comb/Conversions/CombToCGGI",
"@heir//lib/Dialect/Comb/IR:Dialect",
"@heir//lib/Dialect/Jaxite/IR:Dialect",
"@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial",
Expand All @@ -61,6 +60,7 @@ cc_binary(
"@heir//lib/Dialect/RNS/IR:Dialect",
"@heir//lib/Dialect/Random/IR:Dialect",
"@heir//lib/Dialect/Secret/Conversions/SecretToBGV",
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
"@heir//lib/Dialect/Secret/Conversions/SecretToCKKS",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/Secret/Transforms",
Expand Down
12 changes: 6 additions & 6 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "lib/Dialect/CGGI/Transforms/Passes.h"
#include "lib/Dialect/CKKS/Conversions/CKKSToOpenfhe/CKKSToOpenfhe.h"
#include "lib/Dialect/CKKS/IR/CKKSDialect.h"
#include "lib/Dialect/Comb/Conversions/CombToCGGI/CombToCGGI.h"
#include "lib/Dialect/Comb/IR/CombDialect.h"
#include "lib/Dialect/Jaxite/IR/JaxiteDialect.h"
#include "lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.h"
Expand All @@ -34,6 +33,7 @@
#include "lib/Dialect/RNS/IR/RNSTypes.h"
#include "lib/Dialect/Random/IR/RandomDialect.h"
#include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h"
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h"
#include "lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h"
#include "lib/Dialect/Secret/IR/SecretDialect.h"
#include "lib/Dialect/Secret/Transforms/BufferizableOpInterfaceImpl.h"
Expand Down Expand Up @@ -394,7 +394,7 @@ void tosaToBooleanTfhePipeline(const std::string &yosysFilesPath,

pm.addPass(mlir::createCSEPass());
pm.addPass(secret::createSecretDistributeGeneric());
pm.addPass(comb::createCombToCGGI());
pm.addPass(createSecretToCGGI());

// CGGI to Tfhe-Rust exit dialect
pm.addPass(createCGGIToTfheRust());
Expand Down Expand Up @@ -486,8 +486,8 @@ void tosaToBooleanFpgaTfhePipeline(const std::string &yosysFilesPath,
pm.addPass(createForwardStoreToLoad());
pm.addPass(mlir::createCSEPass());
pm.addPass(secret::createSecretDistributeGeneric());
pm.addPass(comb::createCombToCGGI());
// Cleanup CombToCGGI
pm.addPass(createSecretToCGGI());
// Cleanup SecretToCGGI
pm.addPass(createExpandCopyPass(
ExpandCopyPassOptions{.disableAffineLoop = true}));
pm.addPass(memref::createFoldMemRefAliasOpsPass());
Expand Down Expand Up @@ -581,7 +581,7 @@ void tosaToJaxitePipeline(const std::string &yosysFilesPath,

pm.addPass(mlir::createCSEPass());
pm.addPass(secret::createSecretDistributeGeneric());
pm.addPass(comb::createCombToCGGI());
pm.addPass(createSecretToCGGI());

// CGGI to Jaxite exit dialect
pm.addPass(createCGGIToJaxite());
Expand Down Expand Up @@ -839,7 +839,7 @@ int main(int argc, char **argv) {
bgv::registerBGVToLWEPasses();
bgv::registerBGVToOpenfhePasses();
ckks::registerCKKSToOpenfhePasses();
comb::registerCombToCGGIPasses();
registerSecretToCGGIPasses();
lwe::registerLWEToPolynomialPasses();
::mlir::heir::linalg::registerLinalgToTensorExtPasses();
::mlir::heir::polynomial::registerPolynomialToStandardPasses();
Expand Down

0 comments on commit 87412a4

Please sign in to comment.