[Mlir-commits] [mlir] Expand the MemRefToEmitC pass - Lowering `AllocOp` (PR #148257)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 17 22:15:56 PDT 2025


https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/148257

>From b62e8f9ac11a7e4d2dd70ab32258e58ac58e37bf Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 11 Jul 2025 00:23:11 +0000
Subject: [PATCH 1/3] allocop to emitc malloc

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 36 +++++++++++++++++--
 .../MemRefToEmitC/memref-to-emitc.mlir        |  8 +++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..382affaff429f 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -77,6 +77,38 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = allocOp.getLoc();
+    auto memrefType = allocOp.getType();
+    if (!memrefType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
+
+    int64_t totalSize =
+        memrefType.getNumElements() * memrefType.getElementTypeBitWidth() / 8;
+    auto alignment = allocOp.getAlignment();
+    if (alignment) {
+      int64_t alignVal = alignment.value();
+      totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
+    }
+    mlir::Value sizeBytes = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
+    auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
+                                                 memrefType.getElementType());
+    auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc, mallocPtrType, rewriter.getStringAttr("malloc"),
+        mlir::ValueRange{sizeBytes});
+
+    rewriter.replaceOp(allocOp, mallocCall);
+    return success();
+  }
+};
+
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -222,6 +254,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
-               ConvertStore>(converter, patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+               ConvertLoad, ConvertStore>(converter, patterns.getContext());
 }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..23e1c20670f8c 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -8,6 +8,14 @@ func.func @alloca() {
   return
 }
 
+// CHECK-LABEL: alloc()
+func.func @alloc() {
+  // CHECK-NEXT:  %0 = "emitc.constant"() <{value = 3996 : index}> : () -> index
+  // CHECK-NEXT:  %1 = emitc.call_opaque "malloc"(%0) : (index) -> !emitc.ptr<i32>
+  %alloc = memref.alloc() : memref<999xi32>
+  return
+}
+
 // -----
 
 // CHECK-LABEL: memref_store

>From a36e0a3e1dbc4d746073d82d1bf868b1e3403284 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 14 Jul 2025 20:24:46 +0000
Subject: [PATCH 2/3] Specific TODOs

---
 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 382affaff429f..ee6b7d89a76a6 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -85,13 +85,18 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
     mlir::Location loc = allocOp.getLoc();
     auto memrefType = allocOp.getType();
     if (!memrefType.hasStaticShape())
+      // TODO: Handle Dynamic shapes in the future. If the size
+      // of the allocation is the result of some function, we could
+      // potentially evaluate the function and use the result in the call to
+      // allocate.
       return rewriter.notifyMatchFailure(
           allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
 
-    int64_t totalSize =
-        memrefType.getNumElements() * memrefType.getElementTypeBitWidth() / 8;
-    auto alignment = allocOp.getAlignment();
-    if (alignment) {
+    // TODO: Is there a better API to determine the number of bits in a byte in
+    // MLIR?
+    int64_t totalSize = memrefType.getNumElements() *
+                        memrefType.getElementTypeBitWidth() / CHAR_BIT;
+    if (auto alignment = allocOp.getAlignment()) {
       int64_t alignVal = alignment.value();
       totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
     }

>From 69e5a982e826af1eb7dc12a49d0ff7902d3ea732 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 18 Jul 2025 05:14:35 +0000
Subject: [PATCH 3/3] Add the size computations to emitc output

---
 mlir/include/mlir/Conversion/Passes.td        |  2 +-
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 74 ++++++++++++++-----
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 10 +++
 .../MemRefToEmitC/memref-to-emitc.mlir        |  8 +-
 4 files changed, 72 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 76e751243a12c..4660758f35b04 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -841,7 +841,7 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
 // MemRefToEmitC
 //===----------------------------------------------------------------------===//
 
-def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> {
   let summary = "Convert MemRef dialect to EmitC dialect";
   let dependentDialects = ["emitc::EmitCDialect"];
 }
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index ee6b7d89a76a6..d8ceb4b205a55 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -84,34 +85,64 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
                   ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = allocOp.getLoc();
     auto memrefType = allocOp.getType();
-    if (!memrefType.hasStaticShape())
+    if (!memrefType.hasStaticShape()) {
       // TODO: Handle Dynamic shapes in the future. If the size
       // of the allocation is the result of some function, we could
       // potentially evaluate the function and use the result in the call to
       // allocate.
       return rewriter.notifyMatchFailure(
-          allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
-
-    // TODO: Is there a better API to determine the number of bits in a byte in
-    // MLIR?
-    int64_t totalSize = memrefType.getNumElements() *
-                        memrefType.getElementTypeBitWidth() / CHAR_BIT;
-    if (auto alignment = allocOp.getAlignment()) {
-      int64_t alignVal = alignment.value();
-      totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
+          loc, "cannot transform alloc with dynamic shape");
     }
-    mlir::Value sizeBytes = rewriter.create<emitc::ConstantOp>(
-        loc, rewriter.getIndexType(),
-        rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
-    auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
-                                                 memrefType.getElementType());
-    auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
-        loc, mallocPtrType, rewriter.getStringAttr("malloc"),
-        mlir::ValueRange{sizeBytes});
 
-    rewriter.replaceOp(allocOp, mallocCall);
+    MLIRContext *ctx = rewriter.getContext();
+    Type sizeTType = emitc::SizeTType::get(ctx);
+    Type elementType = memrefType.getElementType();
+    // Materializes a compile-time byte count as an `emitc.size_t` constant.
+    auto sizeConstant = [&](uint64_t value) -> mlir::Value {
+      return rewriter.create<emitc::ConstantOp>(loc, sizeTType,
+                                                rewriter.getIndexAttr(value));
+    };
+
+    // sizeof(<element C type>): passing the element type as a TypeAttr lets
+    // the C/C++ emitter spell it, so no MLIR->C type-name table is needed.
+    auto sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+        loc, sizeTType, rewriter.getStringAttr("sizeof"), mlir::ValueRange{},
+        rewriter.getArrayAttr({TypeAttr::get(elementType)}));
+
+    // Scale by the number of elements in the static shape (not a bit width).
+    mlir::Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+        loc, sizeTType, sizeofElementOp.getResult(0),
+        sizeConstant(memrefType.getNumElements()));
+
+    // malloc only guarantees fundamental alignment, so an explicit non-zero
+    // `alignment` is honored with aligned_alloc(alignment, size), whose size
+    // is rounded up to a multiple of the alignment as C11 requires.
+    SmallVector<mlir::Value, 2> allocArgs;
+    StringRef allocFn = "malloc";
+    std::optional<uint64_t> alignment = allocOp.getAlignment();
+    if (alignment && *alignment != 0) {
+      mlir::Value alignValue = sizeConstant(*alignment);
+      mlir::Value padded = rewriter.create<emitc::AddOp>(
+          loc, sizeTType, totalSizeBytes, sizeConstant(*alignment - 1));
+      mlir::Value blocks =
+          rewriter.create<emitc::DivOp>(loc, sizeTType, padded, alignValue);
+      totalSizeBytes =
+          rewriter.create<emitc::MulOp>(loc, sizeTType, blocks, alignValue);
+      allocFn = "aligned_alloc";
+      allocArgs.push_back(alignValue);
+    }
+    allocArgs.push_back(totalSizeBytes);
+
+    auto allocCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc, emitc::PointerType::get(ctx, emitc::OpaqueType::get(ctx, "void")),
+        allocFn, allocArgs);
+    auto castOp = rewriter.create<emitc::CastOp>(
+        loc, emitc::PointerType::get(ctx, elementType),
+        allocCall.getResult(0));
+
+    rewriter.replaceOp(allocOp, castOp);
     return success();
   }
 };
 
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09a2c2f3..d7544007718eb 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -30,6 +30,16 @@ struct ConvertMemRefToEmitCPass
     : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
   void runOnOperation() override {
     TypeConverter converter;
+    mlir::ModuleOp module = getOperation();
+    // Lowered allocations call malloc/aligned_alloc, declared in <stdlib.h>.
+    if (module.walk([](mlir::memref::AllocOp) {
+              return mlir::WalkResult::interrupt();
+            }).wasInterrupted()) {
+      OpBuilder builder(module.getBody(), module.getBody()->begin());
+      builder.create<emitc::IncludeOp>(
+          module.getLoc(), builder.getStringAttr("stdlib.h"),
+          /*is_standard_include=*/builder.getUnitAttr());
+    }

     // Fallback for other types.
     converter.addConversion([](Type type) -> std::optional<Type> {
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 23e1c20670f8c..ad401c1a604ef 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -10,8 +10,27 @@ func.func @alloca() {

 // CHECK-LABEL: alloc()
 func.func @alloc() {
-  // CHECK-NEXT:  %0 = "emitc.constant"() <{value = 3996 : index}> : () -> index
-  // CHECK-NEXT:  %1 = emitc.call_opaque "malloc"(%0) : (index) -> !emitc.ptr<i32>
+  // CHECK-NEXT:  %0 = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+  // CHECK-NEXT:  %1 = "emitc.constant"() <{value = 999 : index}> : () -> !emitc.size_t
+  // CHECK-NEXT:  %2 = emitc.mul %0, %1 : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-NEXT:  %3 = emitc.call_opaque "malloc"(%2) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+  // CHECK-NEXT:  %4 = emitc.cast %3 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
   %alloc = memref.alloc() : memref<999xi32>
   return
 }
+
+// CHECK-LABEL: alloc_aligned()
+func.func @alloc_aligned() {
+  // CHECK-NEXT:  %[[ELT:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+  // CHECK-NEXT:  %[[NUM:.*]] = "emitc.constant"() <{value = 16 : index}> : () -> !emitc.size_t
+  // CHECK-NEXT:  %[[BYTES:.*]] = emitc.mul %[[ELT]], %[[NUM]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-NEXT:  %[[ALIGN:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
+  // CHECK-NEXT:  %[[ALIGN_M1:.*]] = "emitc.constant"() <{value = 63 : index}> : () -> !emitc.size_t
+  // CHECK-NEXT:  %[[PADDED:.*]] = emitc.add %[[BYTES]], %[[ALIGN_M1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-NEXT:  %[[BLOCKS:.*]] = emitc.div %[[PADDED]], %[[ALIGN]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-NEXT:  %[[ROUNDED:.*]] = emitc.mul %[[BLOCKS]], %[[ALIGN]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK-NEXT:  %[[PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGN]], %[[ROUNDED]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+  // CHECK-NEXT:  %{{.*}} = emitc.cast %[[PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<16xf32>
+  return
+}



More information about the Mlir-commits mailing list