[Mlir-commits] [mlir] Expand the MemRefToEmitC pass - Lowering `AllocOp` (PR #148257)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 17 22:15:56 PDT 2025
https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/148257
>From b62e8f9ac11a7e4d2dd70ab32258e58ac58e37bf Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 11 Jul 2025 00:23:11 +0000
Subject: [PATCH 1/3] allocop to emitc malloc
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 36 +++++++++++++++++--
.../MemRefToEmitC/memref-to-emitc.mlir | 8 +++++
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..382affaff429f 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -77,6 +77,38 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::Location loc = allocOp.getLoc();
+ auto memrefType = allocOp.getType();
+ if (!memrefType.hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
+
+ int64_t totalSize =
+ memrefType.getNumElements() * memrefType.getElementTypeBitWidth() / 8;
+ auto alignment = allocOp.getAlignment();
+ if (alignment) {
+ int64_t alignVal = alignment.value();
+ totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
+ }
+ mlir::Value sizeBytes = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
+ auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
+ memrefType.getElementType());
+ auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc, mallocPtrType, rewriter.getStringAttr("malloc"),
+ mlir::ValueRange{sizeBytes});
+
+ rewriter.replaceOp(allocOp, mallocCall);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -222,6 +254,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..23e1c20670f8c 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -8,6 +8,14 @@ func.func @alloca() {
return
}
+// CHECK-LABEL: alloc()
+func.func @alloc() {
+ // CHECK-NEXT: %0 = "emitc.constant"() <{value = 3996 : index}> : () -> index
+ // CHECK-NEXT: %1 = emitc.call_opaque "malloc"(%0) : (index) -> !emitc.ptr<i32>
+ %alloc = memref.alloc() : memref<999xi32>
+ return
+}
+
// -----
// CHECK-LABEL: memref_store
>From a36e0a3e1dbc4d746073d82d1bf868b1e3403284 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 14 Jul 2025 20:24:46 +0000
Subject: [PATCH 2/3] Specific TODOs
---
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 382affaff429f..ee6b7d89a76a6 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -85,13 +85,18 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
mlir::Location loc = allocOp.getLoc();
auto memrefType = allocOp.getType();
if (!memrefType.hasStaticShape())
+ // TODO: Handle Dynamic shapes in the future. If the size
+ // of the allocation is the result of some function, we could
+ // potentially evaluate the function and use the result in the call to
+ // allocate.
return rewriter.notifyMatchFailure(
allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
- int64_t totalSize =
- memrefType.getNumElements() * memrefType.getElementTypeBitWidth() / 8;
- auto alignment = allocOp.getAlignment();
- if (alignment) {
+ // TODO: Is there a better API to determine the number of bits in a byte in
+ // MLIR?
+ int64_t totalSize = memrefType.getNumElements() *
+ memrefType.getElementTypeBitWidth() / CHAR_BIT;
+ if (auto alignment = allocOp.getAlignment()) {
int64_t alignVal = alignment.value();
totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
}
>From 69e5a982e826af1eb7dc12a49d0ff7902d3ea732 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 18 Jul 2025 05:14:35 +0000
Subject: [PATCH 3/3] Add the size computations to emitc output
---
mlir/include/mlir/Conversion/Passes.td | 2 +-
.../MemRefToEmitC/MemRefToEmitC.cpp | 74 ++++++++++++++-----
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 10 +++
.../MemRefToEmitC/memref-to-emitc.mlir | 8 +-
4 files changed, 72 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 76e751243a12c..4660758f35b04 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -841,7 +841,7 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
// MemRefToEmitC
//===----------------------------------------------------------------------===//
-def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> {
let summary = "Convert MemRef dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"];
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index ee6b7d89a76a6..d8ceb4b205a55 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -84,34 +85,79 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = allocOp.getLoc();
auto memrefType = allocOp.getType();
- if (!memrefType.hasStaticShape())
+ if (!memrefType.hasStaticShape()) {
// TODO: Handle Dynamic shapes in the future. If the size
// of the allocation is the result of some function, we could
// potentially evaluate the function and use the result in the call to
// allocate.
return rewriter.notifyMatchFailure(
- allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
-
- // TODO: Is there a better API to determine the number of bits in a byte in
- // MLIR?
- int64_t totalSize = memrefType.getNumElements() *
- memrefType.getElementTypeBitWidth() / CHAR_BIT;
- if (auto alignment = allocOp.getAlignment()) {
- int64_t alignVal = alignment.value();
- totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
+ loc, "cannot transform alloc with dynamic shape");
}
- mlir::Value sizeBytes = rewriter.create<emitc::ConstantOp>(
- loc, rewriter.getIndexType(),
- rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
- auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
- memrefType.getElementType());
- auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
- loc, mallocPtrType, rewriter.getStringAttr("malloc"),
- mlir::ValueRange{sizeBytes});
- rewriter.replaceOp(allocOp, mallocCall);
+ // TODO: `allocOp.getAlignment()` is not honored yet: `malloc` only
+ // guarantees alignment suitable for fundamental types. Lower aligned
+ // allocations, e.g. to `aligned_alloc`.
+ Type elementType = memrefType.getElementType();
+ std::optional<StringRef> elementTypeName = getCTypeName(elementType);
+ if (!elementTypeName)
+ return rewriter.notifyMatchFailure(
+ loc, "cannot transform alloc with unsupported element type");
+
+ // Byte size of one element, evaluated by the C compiler as
+ // `sizeof(<C element type>)`.
+ mlir::Type sizeTType = mlir::emitc::SizeTType::get(rewriter.getContext());
+ mlir::Value elementTypeLiteral = rewriter.create<emitc::LiteralOp>(
+ loc, mlir::emitc::OpaqueType::get(rewriter.getContext(), "type"),
+ rewriter.getStringAttr(*elementTypeName));
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"),
+ mlir::ValueRange{elementTypeLiteral});
+ mlir::Value sizeofElement = sizeofElementOp.getResult(0);
+
+ // Number of elements in the (static) shape: an index-typed count that is
+ // independent of the element type and its bit width.
+ mlir::Value numElements = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIndexAttr(memrefType.getNumElements()));
+ mlir::Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElement, numElements);
+
+ // `malloc` returns `void *`; cast it to a pointer to the element type.
+ auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ rewriter.getContext(),
+ mlir::emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ rewriter.getStringAttr("malloc"), mlir::ValueRange{totalSizeBytes});
+ auto targetPointerType =
+ emitc::PointerType::get(rewriter.getContext(), elementType);
+ auto castOp = rewriter.create<emitc::CastOp>(loc, targetPointerType,
+ mallocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
return success();
}
+
+private:
+ /// Returns the C spelling of `type` used as the `sizeof` operand, or
+ /// std::nullopt if this lowering has no C spelling for it.
+ std::optional<StringRef> getCTypeName(mlir::Type type) const {
+ if (type.isF32())
+ return "float";
+ if (type.isF64())
+ return "double";
+ if (type.isInteger(8))
+ return "int8_t";
+ if (type.isInteger(16))
+ return "int16_t";
+ if (type.isInteger(32))
+ return "int32_t";
+ if (type.isInteger(64))
+ return "int64_t";
+ if (type.isIndex())
+ return "size_t";
+ return std::nullopt;
+ }
};
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09a2c2f3..d7544007718eb 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -30,6 +30,22 @@ struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
TypeConverter converter;
+ mlir::ModuleOp module = getOperation();
+ // memref.alloc is lowered to a `malloc` call, which <stdlib.h> declares.
+ // Add that include once, and only if the module does not already have it.
+ bool hasAlloc = false;
+ bool hasStdlibInclude = false;
+ module.walk([&](mlir::memref::AllocOp) { hasAlloc = true; });
+ module.walk([&](emitc::IncludeOp includeOp) {
+ if (includeOp.getInclude() == "stdlib.h")
+ hasStdlibInclude = true;
+ });
+ if (hasAlloc && !hasStdlibInclude) {
+ OpBuilder builder(module.getBody(), module.getBody()->begin());
+ builder.create<emitc::IncludeOp>(module.getLoc(),
+ builder.getStringAttr("stdlib.h"),
+ /*is_standard_include=*/builder.getUnitAttr());
+ }
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 23e1c20670f8c..ad401c1a604ef 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -10,8 +10,12 @@ func.func @alloca() {
// CHECK-LABEL: alloc()
func.func @alloc() {
- // CHECK-NEXT: %0 = "emitc.constant"() <{value = 3996 : index}> : () -> index
- // CHECK-NEXT: %1 = emitc.call_opaque "malloc"(%0) : (index) -> !emitc.ptr<i32>
+ // CHECK-NEXT: %0 = emitc.literal "int32_t" : !emitc.opaque<"type">
+ // CHECK-NEXT: %1 = emitc.call_opaque "sizeof"(%0) : (!emitc.opaque<"type">) -> !emitc.size_t
+ // CHECK-NEXT: %2 = "emitc.constant"() <{value = 999 : index}> : () -> index
+ // CHECK-NEXT: %3 = emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
+ // CHECK-NEXT: %4 = emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+ // CHECK-NEXT: %5 = emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
%alloc = memref.alloc() : memref<999xi32>
return
}
More information about the Mlir-commits
mailing list