Skip to content

Commit

Permalink
[CIR] Use only cir:: instead of mlir::cir:: to match latest ClangIR
Browse files Browse the repository at this point in the history
Follow up the change made up-stream with
llvm/clangir#1084
  • Loading branch information
keryell committed Nov 13, 2024
1 parent 53f36cd commit abce593
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 51 deletions.
4 changes: 2 additions & 2 deletions include/aie/CIR/CIRToAIEPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def CIRToAIEPrepare : Pass<"cir-to-aie-prepare", "mlir::ModuleOp"> {

let constructor = "xilinx::AIE::CIR::createCIRToAIEPreparePass()";
let dependentDialects = [
"mlir::cir::CIRDialect",
"cir::CIRDialect",
"xilinx::AIE::AIEDialect",
];
}
Expand All @@ -40,7 +40,7 @@ def CIRToAIE : Pass<"cir-to-aie", "mlir::ModuleOp"> {

let constructor = "xilinx::AIE::CIR::createCIRToAIEPass()";
let dependentDialects = [
"mlir::cir::CIRDialect",
"cir::CIRDialect",
"xilinx::AIE::AIEDialect",
];
}
Expand Down
89 changes: 42 additions & 47 deletions lib/CIR/CIRToAIEPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ class CIRToAIETypesAnalysis {
llvm::Regex{"^(aie::buffer)<([^,]+), ([^>]+)>$"}};

for (auto &[type, value] : moduleTypes) {
if (auto maybePointerType = mlir::dyn_cast<mlir::cir::PointerType>(type))
if (auto maybeStructType = mlir::dyn_cast<mlir::cir::StructType>(
maybePointerType.getPointee()))
if (auto maybePointerType = mlir::dyn_cast<cir::PointerType>(type))
if (auto maybeStructType =
mlir::dyn_cast<cir::StructType>(maybePointerType.getPointee()))
for (auto &tnp : typeNamePatterns)
if (llvm::SmallVector<llvm::StringRef> matches;
tnp.match(maybeStructType.getName(), &matches)) {
Expand Down Expand Up @@ -258,14 +258,13 @@ namespace {
// Return true if the call operation calls a function with any of the given
// string annotations
bool isCallingFunctionWithAnnotation(
mlir::cir::CallOp op, llvm::ArrayRef<llvm::StringRef> anyAnnotations) {
if (auto calledFunc =
mlir::SymbolTable::lookupNearestSymbolFrom<mlir::cir::FuncOp>(
op, op.getCalleeAttr())) {
cir::CallOp op, llvm::ArrayRef<llvm::StringRef> anyAnnotations) {
if (auto calledFunc = mlir::SymbolTable::lookupNearestSymbolFrom<cir::FuncOp>(
op, op.getCalleeAttr())) {
if (auto annnotations = calledFunc.getAnnotationsAttr())
for (auto a : calledFunc.getAnnotationsAttr()) {
for (auto one : anyAnnotations)
if (mlir::cast<mlir::cir::AnnotationAttr>(a).getName() == one)
if (mlir::cast<cir::AnnotationAttr>(a).getName() == one)
return true;
}
}
Expand All @@ -288,18 +287,17 @@ bool isUnrealizedConversionCastWithAnnotation(
mlir::MemRefType bufferMemrefType(mlir::Type buffer) {
static mlir::TypeConverter typeConverter = cir::prepareTypeConverter();
LLVM_DEBUG(buffer.dump());
if (auto p = mlir::dyn_cast<mlir::cir::PointerType>(buffer)) {
if (auto bufferType =
mlir::dyn_cast<mlir::cir::StructType>(p.getPointee())) {
if (auto p = mlir::dyn_cast<cir::PointerType>(buffer)) {
if (auto bufferType = mlir::dyn_cast<cir::StructType>(p.getPointee())) {
LLVM_DEBUG(bufferType.dump());
// For now the aie::buffer is implemented as a std::array in the buffer
// struct
auto members = bufferType.getMembers();
if (auto stdArrayType =
mlir::dyn_cast<mlir::cir::StructType>(members.front())) {
mlir::dyn_cast<cir::StructType>(members.front())) {
LLVM_DEBUG(stdArrayType.dump());
// Access the array inside the std::array struct
if (auto arrayType = mlir::dyn_cast<mlir::cir::ArrayType>(
if (auto arrayType = mlir::dyn_cast<cir::ArrayType>(
stdArrayType.getMembers().front())) {
LLVM_DEBUG(arrayType.dump());
auto memref = mlir::dyn_cast<mlir::MemRefType>(
Expand Down Expand Up @@ -367,17 +365,16 @@ void cloneReferencedSymbolsIntoDevice(xilinx::AIE::DeviceOp device) {

// Lower C++ code like \code aie::device<aie::npu1> into an \code
// aie.device(npu1){} operation
struct PrepareDeviceLowering
: public mlir::OpConversionPattern<mlir::cir::AllocaOp> {
using mlir::OpConversionPattern<mlir::cir::AllocaOp>::OpConversionPattern;
struct PrepareDeviceLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
using mlir::OpConversionPattern<cir::AllocaOp>::OpConversionPattern;

// \todo Find a less ugly way to access the analysis. How is it possible for a
// pattern to access some contextual information?
// It should be OK since it is a module pass, so no parallelism here.
static inline CIRToAIETypesAnalysis *cat;

mlir::LogicalResult
matchAndRewrite(mlir::cir::AllocaOp op, OpAdaptor adaptor,
matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const final {
// The struct has a name like "aie::device<aie::npu1>" and the "npu1"
// is used directly for the MLIR aie.device attribute
Expand Down Expand Up @@ -415,25 +412,25 @@ struct PrepareDeviceLowering
// %2 = builtin.unrealized_conversion_cast %1 : !cir.ptr<!ty_aie3A3Adevice3Caie3A3Anpu13E> to !cir.ptr<!ty_aie3A3Atile3C12C_43E> {"aie::tile" = ["1", "4"]}
// clang-format on
struct PrepareTileBufferLowering
: public mlir::OpConversionPattern<mlir::cir::CallOp> {
using mlir::OpConversionPattern<mlir::cir::CallOp>::OpConversionPattern;
: public mlir::OpConversionPattern<cir::CallOp> {
using mlir::OpConversionPattern<cir::CallOp>::OpConversionPattern;

// \todo Find a less ugly way to access the analysis. How is it possible for a
// pattern to access some contextual information?
// It should be OK since it is a module pass, so no parallelism here.
static inline CIRToAIETypesAnalysis *cat;

mlir::LogicalResult
matchAndRewrite(mlir::cir::CallOp op, OpAdaptor adaptor,
matchAndRewrite(cir::CallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const final {
if (isCallingFunctionWithAnnotation(
op, {"aie.device.tile", "aie.tile.buffer"})) {
auto device = op.getOperand(0);
auto user = op.getResult().getUsers().begin();
// Track the alloca where the tiled is stored
auto store = mlir::dyn_cast<mlir::cir::StoreOp>(*user);
auto alloca = mlir::dyn_cast<mlir::cir::AllocaOp>(
store.getOperand(1).getDefiningOp());
auto store = mlir::dyn_cast<cir::StoreOp>(*user);
auto alloca =
mlir::dyn_cast<cir::AllocaOp>(store.getOperand(1).getDefiningOp());
auto aieLike = cat->getTypeDetail(alloca.getResult().getType());
// Replace the alloca by a conversion to be replaced later in
// another pass.
Expand Down Expand Up @@ -471,30 +468,29 @@ struct PrepareTileBufferLowering
capture by the direct def/use forwarding
*/
struct PrepareCoreLowering
: public mlir::OpConversionPattern<mlir::cir::CallOp> {
using mlir::OpConversionPattern<mlir::cir::CallOp>::OpConversionPattern;
struct PrepareCoreLowering : public mlir::OpConversionPattern<cir::CallOp> {
using mlir::OpConversionPattern<cir::CallOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::CallOp op, OpAdaptor adaptor,
matchAndRewrite(cir::CallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const final {
if (isCallingFunctionWithAnnotation(op, {"aie.tile.program"})) {
// Get tile::program() member function
if (auto calledFunc =
mlir::SymbolTable::lookupNearestSymbolFrom<mlir::cir::FuncOp>(
mlir::SymbolTable::lookupNearestSymbolFrom<cir::FuncOp>(
op, op.getCalleeAttr())) {
// The last function instruction is cir.return and the one before
// is the call to the lambda
// calledFunc.getBlocks().front().back().dump();
auto lambdaCall = mlir::dyn_cast<mlir::cir::CallOp>(
auto lambdaCall = mlir::dyn_cast<cir::CallOp>(
*std::next(calledFunc.getBlocks().front().rbegin()));
// lambdaCall.dump();
if (auto lambdaFunc =
mlir::SymbolTable::lookupNearestSymbolFrom<mlir::cir::FuncOp>(
mlir::SymbolTable::lookupNearestSymbolFrom<cir::FuncOp>(
lambdaCall, lambdaCall.getCalleeAttr())) {
// lambdaFunc.dump();
assert(lambdaFunc.getLambda());
// auto scopeOp = op->getParentOfType<mlir::cir::ScopeOp>();
// auto scopeOp = op->getParentOfType<cir::ScopeOp>();
// scopeOp.dump();
// The aie++ tile value
rewriter.setInsertionPoint(op);
Expand Down Expand Up @@ -526,14 +522,13 @@ struct CIRToAIEPrepare : CIRToAIEPrepareBase<CIRToAIEPrepare> {
mlir::ConversionTarget target{getContext()};
target.addLegalDialect<xilinx::AIE::AIEDialect>();
target.addLegalOp<mlir::UnrealizedConversionCastOp>();
target.addDynamicallyLegalOp<mlir::cir::AllocaOp>(
[&](mlir::cir::AllocaOp op) {
// If the struct has a name like "aie::device<aie::npu1>", mark
// the operation illegal so it has to be rewritten
auto aieLike = cat.getOptionalTypeDetail(op.getType());
return !(aieLike && aieLike->base == "aie::device");
});
target.addDynamicallyLegalOp<mlir::cir::CallOp>([](mlir::cir::CallOp op) {
target.addDynamicallyLegalOp<cir::AllocaOp>([&](cir::AllocaOp op) {
// If the struct has a name like "aie::device<aie::npu1>", mark
// the operation illegal so it has to be rewritten
auto aieLike = cat.getOptionalTypeDetail(op.getType());
return !(aieLike && aieLike->base == "aie::device");
});
target.addDynamicallyLegalOp<cir::CallOp>([](cir::CallOp op) {
return !isCallingFunctionWithAnnotation(
op, {"aie.device.tile", "aie.tile.buffer"});
});
Expand Down Expand Up @@ -628,27 +623,27 @@ struct CIRToAIE : CIRToAIEBase<CIRToAIE> {

// Lower aie::tile::program(<tile code>) to aie.core
bool tryTileProgramLowering(mlir::Operation *op, mlir::OpBuilder &b) {
if (auto callOp = mlir::dyn_cast<mlir::cir::CallOp>(op)) {
if (auto callOp = mlir::dyn_cast<cir::CallOp>(op)) {
LLVM_DEBUG(
callOp.emitRemark("tryTileProgramLowering: CallOp using a tile"));
if (isCallingFunctionWithAnnotation(callOp, {"aie.tile.program"})) {
LLVM_DEBUG(
callOp.emitRemark("tryTileProgramLowering: CallOp using a tile"));
if (auto calledFunc =
mlir::SymbolTable::lookupNearestSymbolFrom<mlir::cir::FuncOp>(
mlir::SymbolTable::lookupNearestSymbolFrom<cir::FuncOp>(
callOp, callOp.getCalleeAttr())) {
// The last function instruction is cir.return and the one before
// is the call to the lambda
if (auto lambdaCall = mlir::dyn_cast<mlir::cir::CallOp>(
if (auto lambdaCall = mlir::dyn_cast<cir::CallOp>(
*std::next(calledFunc.getBlocks().front().rbegin()))) {
LLVM_DEBUG(lambdaCall.emitRemark("lambdaCall"));
if (auto lambdaFunc = mlir::SymbolTable::lookupNearestSymbolFrom<
mlir::cir::FuncOp>(lambdaCall,
lambdaCall.getCalleeAttr())) {
if (auto lambdaFunc =
mlir::SymbolTable::lookupNearestSymbolFrom<cir::FuncOp>(
lambdaCall, lambdaCall.getCalleeAttr())) {
LLVM_DEBUG(lambdaFunc.emitRemark(
"tryTileProgramLowering: Tile core lambda"));
assert(lambdaFunc.getLambda());
auto scopeOp = callOp->getParentOfType<mlir::cir::ScopeOp>();
auto scopeOp = callOp->getParentOfType<cir::ScopeOp>();
LLVM_DEBUG(scopeOp.emitRemark("tryTileProgramLowering: Scope"));
// \todo outline
auto tileDetail =
Expand Down Expand Up @@ -699,7 +694,7 @@ struct CIRToAIE : CIRToAIEBase<CIRToAIE> {
resolveSomeDeviceToTileAfterCloning(clone);
LLVM_DEBUG(
coreOp.emitRemark("tryTileProgramLowering: Stuffed core");
coreOp->getParentOfType<mlir::cir::FuncOp>().emitRemark(
coreOp->getParentOfType<cir::FuncOp>().emitRemark(
"tryTileProgramLowering: Top function"));
}
}
Expand Down
2 changes: 1 addition & 1 deletion tools/aie-lsp-server/aie-lsp-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ int main(int argc, char **argv) {
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
xilinx::registerAllDialects(registry);
registry.insert<mlir::cir::CIRDialect>();
registry.insert<cir::CIRDialect>();
return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry));
}
2 changes: 1 addition & 1 deletion tools/aie-opt/aie-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ int main(int argc, char **argv) {
llvm::cl::AddExtraVersionPrinter(versionPrinter);

// ClangIR dialect
registry.insert<mlir::cir::CIRDialect>();
registry.insert<cir::CIRDialect>();

// ClangIR-specific passes
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
Expand Down

0 comments on commit abce593

Please sign in to comment.