[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `CopyOp` (PR #151206)
Jaden Angella
llvmlistbot at llvm.org
Fri Aug 8 14:25:14 PDT 2025
https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/151206
>From 5a54bdf27d2619982c876773b7ae3a9477eeeefb Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 25 Jul 2025 16:43:57 +0000
Subject: [PATCH 1/9] Add lowering for CopyOp
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 60 +++++++++++++++++++
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 25 ++++++++
2 files changed, 85 insertions(+)
create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..34ea4989c8156 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,6 +87,66 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = copyOp.getLoc();
+ auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+ auto targetMemrefType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+ if (!srcMemrefType || !targetMemrefType) {
+ return failure();
+ }
+
+ // 1. Cast source memref to a pointer.
+ auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType());
+ auto srcArrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ auto stcArrayPtr =
+ emitc::PointerType::get(srcArrayValue.getType().getElementType());
+ auto srcPtr = rewriter.create<emitc::CastOp>(loc, srcPtrType,
+ stcArrayPtr.getPointee());
+
+ // 2. Cast target memref to a pointer.
+ auto targetPtrType =
+ emitc::PointerType::get(targetMemrefType.getElementType());
+
+ auto arrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+
+ // Cast the target memref value to a pointer type.
+ auto targetPtr =
+ rewriter.create<emitc::CastOp>(loc, targetPtrType, arrayValue);
+
+ // 3. Calculate the size in bytes of the memref.
+ auto elementSize = rewriter.create<emitc::CallOpaqueOp>(
+ loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"),
+ mlir::ValueRange{},
+ mlir::ArrayAttr::get(
+ rewriter.getContext(),
+ {mlir::TypeAttr::get(srcMemrefType.getElementType())}));
+
+ auto numElements = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIntegerAttr(rewriter.getIndexType(),
+ srcMemrefType.getNumElements()));
+ auto byteSize = rewriter.create<emitc::MulOp>(loc, rewriter.getIndexType(),
+ elementSize.getResult(0),
+ numElements.getResult());
+
+ // 4. Emit the memcpy call.
+ rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, "memcpy",
+ ValueRange{targetPtr.getResult(),
+ srcPtr.getResult(),
+ byteSize.getResult()});
+
+ return success();
+ }
+};
+
Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
Type resultTy;
if (opTy.getRank() == 0) {
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
new file mode 100644
index 0000000000000..d031d60508df2
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+func.func @copying(%arg0 : memref<2x4xf32>) {
+ memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+ return
+}
+
+// func.func @copying_memcpy(%arg_0: !emitc.ptr<f32>) {
+// %size = "emitc.constant"() <{value = 8 : index}> :() -> index
+// %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
+// %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
+
+// emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
+// return
+// }
+
+// CHECK-LABEL: copying_memcpy
+// CHECK-SAME: %arg_0: !emitc.ptr<f32>
+// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index
+// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
+// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
+// CHECK-NEXT: emitc.call_opaque "memcpy"
+// CHECK-SAME: (%arg_0, %arg_0, %total_bytes)
+// CHECK-NEXT: : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
+// CHECK-NEXT: return
\ No newline at end of file
>From b5b7637c91ec064933fb5ab94a04956fe3d4783e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 29 Jul 2025 18:14:39 +0000
Subject: [PATCH 2/9] use subscript and apply ops
---
.../Conversion/MemRefToEmitC/MemRefToEmitC.h | 2 +
.../MemRefToEmitC/MemRefToEmitC.cpp | 137 ++++++++++--------
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 27 +++-
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 35 +++--
.../MemRefToEmitC/memref-to-emitc-failed.mlir | 8 -
5 files changed, 114 insertions(+), 95 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index b595b6a308bea..4ea6649d64a92 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -10,8 +10,10 @@
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
constexpr const char *mallocFunctionName = "malloc";
+constexpr const char *memcpyFunctionName = "memcpy";
constexpr const char *cppStandardLibraryHeader = "cstdlib";
constexpr const char *cStandardLibraryHeader = "stdlib.h";
+constexpr const char *stringLibraryHeader = "string.h";
namespace mlir {
class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 34ea4989c8156..adb0eb77fdf35 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,66 +87,6 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
-struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = copyOp.getLoc();
- auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
- auto targetMemrefType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
-
- if (!srcMemrefType || !targetMemrefType) {
- return failure();
- }
-
- // 1. Cast source memref to a pointer.
- auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType());
- auto srcArrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
- auto stcArrayPtr =
- emitc::PointerType::get(srcArrayValue.getType().getElementType());
- auto srcPtr = rewriter.create<emitc::CastOp>(loc, srcPtrType,
- stcArrayPtr.getPointee());
-
- // 2. Cast target memref to a pointer.
- auto targetPtrType =
- emitc::PointerType::get(targetMemrefType.getElementType());
-
- auto arrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-
- // Cast the target memref value to a pointer type.
- auto targetPtr =
- rewriter.create<emitc::CastOp>(loc, targetPtrType, arrayValue);
-
- // 3. Calculate the size in bytes of the memref.
- auto elementSize = rewriter.create<emitc::CallOpaqueOp>(
- loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"),
- mlir::ValueRange{},
- mlir::ArrayAttr::get(
- rewriter.getContext(),
- {mlir::TypeAttr::get(srcMemrefType.getElementType())}));
-
- auto numElements = rewriter.create<emitc::ConstantOp>(
- loc, rewriter.getIndexType(),
- rewriter.getIntegerAttr(rewriter.getIndexType(),
- srcMemrefType.getNumElements()));
- auto byteSize = rewriter.create<emitc::MulOp>(loc, rewriter.getIndexType(),
- elementSize.getResult(0),
- numElements.getResult());
-
- // 4. Emit the memcpy call.
- rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, "memcpy",
- ValueRange{targetPtr.getResult(),
- srcPtr.getResult(),
- byteSize.getResult()});
-
- return success();
- }
-};
-
Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
Type resultTy;
if (opTy.getRank() == 0) {
@@ -157,6 +97,29 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+ ConversionPatternRewriter &rewriter) {
+ emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
+ loc, emitc::SizeTType::get(rewriter.getContext()),
+ rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(),
+ {TypeAttr::get(memrefType.getElementType())}));
+
+ IndexType indexType = rewriter.getIndexType();
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+ return totalSizeBytes.getResult();
+}
+
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -219,6 +182,55 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
}
};
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = copyOp.getLoc();
+ MemRefType srcMemrefType =
+ dyn_cast<MemRefType>(copyOp.getSource().getType());
+ MemRefType targetMemrefType =
+ dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+ if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
+ !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+ auto srcArrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+
+ emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
+ loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
+ emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+ loc, emitc::PointerType::get(srcMemrefType.getElementType()),
+ rewriter.getStringAttr("&"), srcSubPtr);
+
+ auto arrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+ emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
+ loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
+ emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
+ loc, emitc::PointerType::get(targetMemrefType.getElementType()),
+ rewriter.getStringAttr("&"), targetSubPtr);
+
+ emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc, TypeRange{}, "memcpy",
+ ValueRange{
+ targetPtr.getResult(), srcPtr.getResult(),
+ calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+ rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -380,6 +392,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+ ConvertGetGlobal, ConvertLoad, ConvertStore>(
+ converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76d6e256..8e965b42f1043 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringRef.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +28,15 @@ namespace mlir {
using namespace mlir;
namespace {
+
+emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
+ StringRef headerName) {
+ StringAttr includeAttr = builder.getStringAttr(headerName);
+ return builder.create<emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+}
+
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
@@ -57,7 +67,8 @@ struct ConvertMemRefToEmitCPass
mlir::ModuleOp module = getOperation();
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
if (callOp.getCallee() != alignedAllocFunctionName &&
- callOp.getCallee() != mallocFunctionName) {
+ callOp.getCallee() != mallocFunctionName &&
+ callOp.getCallee() != memcpyFunctionName) {
return mlir::WalkResult::advance();
}
@@ -76,12 +87,14 @@ struct ConvertMemRefToEmitCPass
}
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
- StringAttr includeAttr =
- builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
- : cStandardLibraryHeader);
- builder.create<mlir::emitc::IncludeOp>(
- module.getLoc(), includeAttr,
- /*is_standard_include=*/builder.getUnitAttr());
+ StringRef headerName;
+ if (callOp.getCallee() == memcpyFunctionName)
+ headerName = stringLibraryHeader;
+ else
+ headerName = options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader;
+
+ addHeader(builder, module, headerName);
return mlir::WalkResult::interrupt();
});
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index d031d60508df2..4b6eb50807513 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,25 +1,24 @@
-// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s
func.func @copying(%arg0 : memref<2x4xf32>) {
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
return
}
-// func.func @copying_memcpy(%arg_0: !emitc.ptr<f32>) {
-// %size = "emitc.constant"() <{value = 8 : index}> :() -> index
-// %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
-// %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
-
-// emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
-// return
-// }
+// CHECK: module {
+// CHECK-NEXT: emitc.include <"string.h">
+// CHECK-LABEL: copying
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
+// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
+// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK-NEXT:}
-// CHECK-LABEL: copying_memcpy
-// CHECK-SAME: %arg_0: !emitc.ptr<f32>
-// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index
-// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
-// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
-// CHECK-NEXT: emitc.call_opaque "memcpy"
-// CHECK-SAME: (%arg_0, %arg_0, %total_bytes)
-// CHECK-NEXT: : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
-// CHECK-NEXT: return
\ No newline at end of file
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index fda01974d3fc8..b6eccfc8f0050 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -1,13 +1,5 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
-func.func @memref_op(%arg0 : memref<2x4xf32>) {
- // expected-error at +1 {{failed to legalize operation 'memref.copy'}}
- memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
- return
-}
-
-// -----
-
func.func @alloca_with_dynamic_shape() {
%0 = index.constant 1
// expected-error at +1 {{failed to legalize operation 'memref.alloca'}}
>From c17cf73dbcd269f2de62f6efd473d40caff5805c Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 4 Aug 2025 18:11:17 +0000
Subject: [PATCH 3/9] allow for multi-dimensional arrays
---
.../Conversion/MemRefToEmitC/MemRefToEmitC.h | 3 +-
.../MemRefToEmitC/MemRefToEmitC.cpp | 88 ++++++++++++-------
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 19 ++--
3 files changed, 68 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 4ea6649d64a92..5abfb3d7e72dd 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -13,7 +13,8 @@ constexpr const char *mallocFunctionName = "malloc";
constexpr const char *memcpyFunctionName = "memcpy";
constexpr const char *cppStandardLibraryHeader = "cstdlib";
constexpr const char *cStandardLibraryHeader = "stdlib.h";
-constexpr const char *stringLibraryHeader = "string.h";
+constexpr const char *cppStringLibraryHeader = "cstring";
+constexpr const char *cStringLibraryHeader = "string.h";
namespace mlir {
class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index adb0eb77fdf35..cabbfac4a1dca 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
@@ -98,23 +99,25 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
}
Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
- ConversionPatternRewriter &rewriter) {
- emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
- loc, emitc::SizeTType::get(rewriter.getContext()),
- rewriter.getStringAttr("sizeof"), ValueRange{},
- ArrayAttr::get(rewriter.getContext(),
+ OpBuilder &builder) {
+ assert(isMemRefTypeLegalForEmitC(memrefType) &&
+ "incompatible memref type for EmitC conversion");
+ emitc::CallOpaqueOp elementSize = builder.create<emitc::CallOpaqueOp>(
+ loc, emitc::SizeTType::get(builder.getContext()),
+ builder.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(builder.getContext(),
{TypeAttr::get(memrefType.getElementType())}));
- IndexType indexType = rewriter.getIndexType();
+ IndexType indexType = builder.getIndexType();
int64_t numElements = 1;
for (int64_t dimSize : memrefType.getShape()) {
numElements *= dimSize;
}
- emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
- loc, indexType, rewriter.getIndexAttr(numElements));
+ emitc::ConstantOp numElementsValue = builder.create<emitc::ConstantOp>(
+ loc, indexType, builder.getIndexAttr(numElements));
- Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
- emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+ Type sizeTType = emitc::SizeTType::get(builder.getContext());
+ emitc::MulOp totalSizeBytes = builder.create<emitc::MulOp>(
loc, sizeTType, elementSize.getResult(0), numElementsValue);
return totalSizeBytes.getResult();
@@ -189,41 +192,64 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = copyOp.getLoc();
- MemRefType srcMemrefType =
- dyn_cast<MemRefType>(copyOp.getSource().getType());
+ MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
MemRefType targetMemrefType =
- dyn_cast<MemRefType>(copyOp.getTarget().getType());
+ cast<MemRefType>(copyOp.getTarget().getType());
- if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
- !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+ if (!isMemRefTypeLegalForEmitC(srcMemrefType)) {
return rewriter.notifyMatchFailure(
- loc, "incompatible memref type for EmitC conversion");
+ loc, "incompatible source memref type for EmitC conversion");
+ }
+ if (!isMemRefTypeLegalForEmitC(targetMemrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible target memref type for EmitC conversion");
}
+ auto createPointerFromEmitcArray =
+ [&](mlir::Location loc, mlir::OpBuilder &rewriter,
+ mlir::TypedValue<emitc::ArrayType> arrayValue,
+ mlir::MemRefType memrefType,
+ emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
+ // Get the rank of the array to create the correct number of zero indices.
+ int64_t rank = arrayValue.getType().getRank();
+ llvm::SmallVector<mlir::Value> indices;
+ for (int i = 0; i < rank; ++i) {
+ indices.push_back(zeroIndex);
+ }
+
+ // Create a subscript operation to get the element at index [0, 0, ...,
+ // 0].
+ emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
+ loc, arrayValue, mlir::ValueRange(indices));
+
+ // Create an apply operation to take the address of the subscripted
+ // element.
+ emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
+ loc, emitc::PointerType::get(memrefType.getElementType()),
+ rewriter.getStringAttr("&"), subPtr);
+
+ return ptr;
+ };
+
+ // Create a constant zero index.
emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+
auto srcArrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
+ loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
- emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
- loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
- emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
- loc, emitc::PointerType::get(srcMemrefType.getElementType()),
- rewriter.getStringAttr("&"), srcSubPtr);
-
- auto arrayValue =
+ auto targetArrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
- emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
- loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
- emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
- loc, emitc::PointerType::get(targetMemrefType.getElementType()),
- rewriter.getStringAttr("&"), targetSubPtr);
+ emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
+ loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
+ OpBuilder builder = rewriter;
emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
loc, TypeRange{}, "memcpy",
- ValueRange{
- targetPtr.getResult(), srcPtr.getResult(),
- calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+ ValueRange{targetPtr.getResult(), srcPtr.getResult(),
+ calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
rewriter.replaceOp(copyOp, memCpyCall.getResults());
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 8e965b42f1043..c60e7488fdb38 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -29,8 +29,8 @@ using namespace mlir;
namespace {
-emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
- StringRef headerName) {
+emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
+ StringRef headerName) {
StringAttr includeAttr = builder.getStringAttr(headerName);
return builder.create<emitc::IncludeOp>(
module.getLoc(), includeAttr,
@@ -68,33 +68,32 @@ struct ConvertMemRefToEmitCPass
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
if (callOp.getCallee() != alignedAllocFunctionName &&
callOp.getCallee() != mallocFunctionName &&
- callOp.getCallee() != memcpyFunctionName) {
+ callOp.getCallee() != memcpyFunctionName)
return mlir::WalkResult::advance();
- }
for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
- if (!includeOp) {
+ if (!includeOp)
continue;
- }
+
if (includeOp.getIsStandardInclude() &&
((options.lowerToCpp &&
includeOp.getInclude() == cppStandardLibraryHeader) ||
(!options.lowerToCpp &&
- includeOp.getInclude() == cStandardLibraryHeader))) {
+ includeOp.getInclude() == cStandardLibraryHeader)))
return mlir::WalkResult::interrupt();
- }
}
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
StringRef headerName;
if (callOp.getCallee() == memcpyFunctionName)
- headerName = stringLibraryHeader;
+ headerName =
+ options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
else
headerName = options.lowerToCpp ? cppStandardLibraryHeader
: cStandardLibraryHeader;
- addHeader(builder, module, headerName);
+ addStandardHeader(builder, module, headerName);
return mlir::WalkResult::interrupt();
});
}
>From a7a2d01ed0e2e2b9751d22af6cc329d0852ec80d Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 4 Aug 2025 18:19:32 +0000
Subject: [PATCH 4/9] update test file
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 7 -----
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 26 ++++++++++---------
2 files changed, 14 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index cabbfac4a1dca..4130e9be88a89 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -210,20 +210,14 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
mlir::TypedValue<emitc::ArrayType> arrayValue,
mlir::MemRefType memrefType,
emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
- // Get the rank of the array to create the correct number of zero indices.
int64_t rank = arrayValue.getType().getRank();
llvm::SmallVector<mlir::Value> indices;
for (int i = 0; i < rank; ++i) {
indices.push_back(zeroIndex);
}
- // Create a subscript operation to get the element at index [0, 0, ...,
- // 0].
emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
loc, arrayValue, mlir::ValueRange(indices));
-
- // Create an apply operation to take the address of the subscripted
- // element.
emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
loc, emitc::PointerType::get(memrefType.getElementType()),
rewriter.getStringAttr("&"), subPtr);
@@ -231,7 +225,6 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
return ptr;
};
- // Create a constant zero index.
emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 4b6eb50807513..88325e57762d3 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,23 +1,25 @@
// RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s
-func.func @copying(%arg0 : memref<2x4xf32>) {
- memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
+ memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
return
}
// CHECK: module {
// CHECK-NEXT: emitc.include <"string.h">
// CHECK-LABEL: copying
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
-// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
-// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT:}
>From 12e967e9312bdffa74393ae4ec357e25a5df22dd Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 5 Aug 2025 17:44:17 +0000
Subject: [PATCH 5/9] test cpp output
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 37 ++++++------
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 15 +++--
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 57 ++++++++++++-------
3 files changed, 65 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 4130e9be88a89..c8124d2f16943 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -196,20 +196,20 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
MemRefType targetMemrefType =
cast<MemRefType>(copyOp.getTarget().getType());
- if (!isMemRefTypeLegalForEmitC(srcMemrefType)) {
+ if (!isMemRefTypeLegalForEmitC(srcMemrefType))
return rewriter.notifyMatchFailure(
loc, "incompatible source memref type for EmitC conversion");
- }
- if (!isMemRefTypeLegalForEmitC(targetMemrefType)) {
+
+ if (!isMemRefTypeLegalForEmitC(targetMemrefType))
return rewriter.notifyMatchFailure(
loc, "incompatible target memref type for EmitC conversion");
- }
+
+ emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
auto createPointerFromEmitcArray =
- [&](mlir::Location loc, mlir::OpBuilder &rewriter,
- mlir::TypedValue<emitc::ArrayType> arrayValue,
- mlir::MemRefType memrefType,
- emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
+ [loc, &rewriter, &zeroIndex](
+ mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
int64_t rank = arrayValue.getType().getRank();
llvm::SmallVector<mlir::Value> indices;
for (int i = 0; i < rank; ++i) {
@@ -219,30 +219,25 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
loc, arrayValue, mlir::ValueRange(indices));
emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
- loc, emitc::PointerType::get(memrefType.getElementType()),
+ loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
rewriter.getStringAttr("&"), subPtr);
return ptr;
};
- emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
- loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
-
auto srcArrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
- emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
- loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
+ cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ emitc::ApplyOp srcPtr = createPointerFromEmitcArray(srcArrayValue);
auto targetArrayValue =
- dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
- emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
- loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
+ cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+ emitc::ApplyOp targetPtr = createPointerFromEmitcArray(targetArrayValue);
- OpBuilder builder = rewriter;
emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
loc, TypeRange{}, "memcpy",
- ValueRange{targetPtr.getResult(), srcPtr.getResult(),
- calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
+ ValueRange{
+ targetPtr.getResult(), srcPtr.getResult(),
+ calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
rewriter.replaceOp(copyOp, memCpyCall.getResults());
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index c60e7488fdb38..3ffff9fca106a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -37,6 +37,16 @@ emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
/*is_standard_include=*/builder.getUnitAttr());
}
+bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options,
+ emitc::IncludeOp includeOp) {
+ return ((options.lowerToCpp &&
+ (includeOp.getInclude() == cppStandardLibraryHeader ||
+ includeOp.getInclude() == cppStringLibraryHeader)) ||
+ (!options.lowerToCpp &&
+ (includeOp.getInclude() == cStandardLibraryHeader ||
+ includeOp.getInclude() == cStringLibraryHeader)));
+}
+
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
@@ -77,10 +87,7 @@ struct ConvertMemRefToEmitCPass
continue;
if (includeOp.getIsStandardInclude() &&
- ((options.lowerToCpp &&
- includeOp.getInclude() == cppStandardLibraryHeader) ||
- (!options.lowerToCpp &&
- includeOp.getInclude() == cStandardLibraryHeader)))
+ isExpectedStandardInclude(options, includeOp))
return mlir::WalkResult::interrupt();
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 88325e57762d3..1b515ba02dd46 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,26 +1,45 @@
-// RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
return
}
-// CHECK: module {
-// CHECK-NEXT: emitc.include <"string.h">
-// CHECK-LABEL: copying
-// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-// CHECK-NEXT:}
+// NOCPP: module {
+// NOCPP-NEXT: emitc.include <"string.h">
+// NOCPP-LABEL: copying
+// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// NOCPP-NEXT: return
+// NOCPP-NEXT: }
+// NOCPP-NEXT:}
+// CPP: module {
+// CPP-NEXT: emitc.include <"cstring">
+// CPP-LABEL: copying
+// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CPP-NEXT: return
+// CPP-NEXT: }
+// CPP-NEXT:}
>From 5870e6b3fb21e0f644bd237cbb6b8f30f28cdb73 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 7 Aug 2025 04:23:15 +0000
Subject: [PATCH 6/9] ensure both headers are added
---
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 57 ++++++++-----------
.../memref-to-emitc-alloc-copy.mlir | 30 ++++++++++
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 48 ++++++----------
3 files changed, 70 insertions(+), 65 deletions(-)
create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 3ffff9fca106a..5469949311879 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -37,16 +37,6 @@ emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
/*is_standard_include=*/builder.getUnitAttr());
}
-bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options,
- emitc::IncludeOp includeOp) {
- return ((options.lowerToCpp &&
- (includeOp.getInclude() == cppStandardLibraryHeader ||
- includeOp.getInclude() == cppStringLibraryHeader)) ||
- (!options.lowerToCpp &&
- (includeOp.getInclude() == cStandardLibraryHeader ||
- includeOp.getInclude() == cStringLibraryHeader)));
-}
-
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
@@ -75,34 +65,33 @@ struct ConvertMemRefToEmitCPass
return signalPassFailure();
mlir::ModuleOp module = getOperation();
+ llvm::SmallVector<StringRef> requiredHeaders;
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
- if (callOp.getCallee() != alignedAllocFunctionName &&
- callOp.getCallee() != mallocFunctionName &&
- callOp.getCallee() != memcpyFunctionName)
- return mlir::WalkResult::advance();
-
- for (auto &op : *module.getBody()) {
- emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
- if (!includeOp)
- continue;
-
- if (includeOp.getIsStandardInclude() &&
- isExpectedStandardInclude(options, includeOp))
- return mlir::WalkResult::interrupt();
- }
-
- mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
- StringRef headerName;
- if (callOp.getCallee() == memcpyFunctionName)
- headerName =
+ StringRef expectedHeader;
+ if (callOp.getCallee() == alignedAllocFunctionName ||
+ callOp.getCallee() == mallocFunctionName)
+ expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader;
+ else if (callOp.getCallee() == memcpyFunctionName)
+ expectedHeader =
options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
else
- headerName = options.lowerToCpp ? cppStandardLibraryHeader
- : cStandardLibraryHeader;
-
- addStandardHeader(builder, module, headerName);
- return mlir::WalkResult::interrupt();
+ return mlir::WalkResult::advance();
+ requiredHeaders.push_back(expectedHeader);
+ return mlir::WalkResult::advance();
});
+ for (StringRef expectedHeader : requiredHeaders) {
+ bool headerFound = llvm::any_of(*module.getBody(), [&](Operation &op) {
+ auto includeOp = dyn_cast<mlir::emitc::IncludeOp>(op);
+ return includeOp && includeOp.getIsStandardInclude() &&
+ (includeOp.getInclude() == expectedHeader);
+ });
+
+ if (!headerFound) {
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ addStandardHeader(builder, module, expectedHeader);
+ }
+ }
}
};
} // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
new file mode 100644
index 0000000000000..2e0ff63715355
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+
+
+func.func @alloc_copy(%arg0: memref<999xi32>) {
+ %alloc = memref.alloc() : memref<999xi32>
+ memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
+ return
+}
+
+// NOCPP: module {
+// NOCPP-NEXT: emitc.include <"string.h">
+// NOCPP-NEXT: emitc.include <"stdlib.h">
+
+// CPP: module {
+// CPP-NEXT: emitc.include <"cstring">
+// CHECK-NEXT: emitc.include <"cstdlib">
+// CHECK-LABEL: alloc_copy
+// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 1b515ba02dd46..615dbfeb461a2 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -8,38 +8,24 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
// NOCPP: module {
// NOCPP-NEXT: emitc.include <"string.h">
-// NOCPP-LABEL: copying
-// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// NOCPP-NEXT: return
-// NOCPP-NEXT: }
-// NOCPP-NEXT:}
+
// CPP: module {
// CPP-NEXT: emitc.include <"cstring">
// CPP-LABEL: copying
-// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// CPP-NEXT: return
-// CPP-NEXT: }
-// CPP-NEXT:}
+// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK-NEXT:}
+
>From 3663062675a375ef4ab8c7a9d050494c7d5203e0 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 7 Aug 2025 04:26:59 +0000
Subject: [PATCH 7/9] update test file
---
.../Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
index 2e0ff63715355..a1fa58803e15a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -14,7 +14,7 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
// CPP: module {
// CPP-NEXT: emitc.include <"cstring">
-// CHECK-NEXT: emitc.include <"cstdlib">
+// CPP-NEXT: emitc.include <"cstdlib">
// CHECK-LABEL: alloc_copy
// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
>From ec252fcc01afe0f819caa9025270fe520187f3f4 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 8 Aug 2025 00:15:33 +0000
Subject: [PATCH 8/9] ensure no duplicate headers
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 17 ++++------
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 26 +++++++-------
.../memref-to-emitc-alloc-copy.mlir | 34 +++++++++++++++----
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 9 +++--
4 files changed, 50 insertions(+), 36 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index c8124d2f16943..4e86188af719e 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
+#include <numeric>
using namespace mlir;
@@ -98,8 +99,8 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
-Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
- OpBuilder &builder) {
+static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+ OpBuilder &builder) {
assert(isMemRefTypeLegalForEmitC(memrefType) &&
"incompatible memref type for EmitC conversion");
emitc::CallOpaqueOp elementSize = builder.create<emitc::CallOpaqueOp>(
@@ -109,10 +110,9 @@ Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
{TypeAttr::get(memrefType.getElementType())}));
IndexType indexType = builder.getIndexType();
- int64_t numElements = 1;
- for (int64_t dimSize : memrefType.getShape()) {
- numElements *= dimSize;
- }
+ int64_t numElements = std::accumulate(memrefType.getShape().begin(),
+ memrefType.getShape().end(), int64_t{1},
+ std::multiplies<int64_t>());
emitc::ConstantOp numElementsValue = builder.create<emitc::ConstantOp>(
loc, indexType, builder.getIndexAttr(numElements));
@@ -211,10 +211,7 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
[loc, &rewriter, &zeroIndex](
mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
int64_t rank = arrayValue.getType().getRank();
- llvm::SmallVector<mlir::Value> indices;
- for (int i = 0; i < rank; ++i) {
- indices.push_back(zeroIndex);
- }
+ llvm::SmallVector<mlir::Value> indices(rank, zeroIndex);
emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
loc, arrayValue, mlir::ValueRange(indices));
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 5469949311879..a51890248271f 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
@@ -65,7 +66,13 @@ struct ConvertMemRefToEmitCPass
return signalPassFailure();
mlir::ModuleOp module = getOperation();
- llvm::SmallVector<StringRef> requiredHeaders;
+ llvm::SmallSet<StringRef, 4> existingHeaders;
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ module.walk([&](mlir::emitc::IncludeOp includeOp) {
+ if (includeOp.getIsStandardInclude())
+ existingHeaders.insert(includeOp.getInclude());
+ });
+
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
StringRef expectedHeader;
if (callOp.getCallee() == alignedAllocFunctionName ||
@@ -77,21 +84,12 @@ struct ConvertMemRefToEmitCPass
options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
else
return mlir::WalkResult::advance();
- requiredHeaders.push_back(expectedHeader);
- return mlir::WalkResult::advance();
- });
- for (StringRef expectedHeader : requiredHeaders) {
- bool headerFound = llvm::any_of(*module.getBody(), [&](Operation &op) {
- auto includeOp = dyn_cast<mlir::emitc::IncludeOp>(op);
- return includeOp && includeOp.getIsStandardInclude() &&
- (includeOp.getInclude() == expectedHeader);
- });
-
- if (!headerFound) {
- mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ if (!existingHeaders.contains(expectedHeader)) {
addStandardHeader(builder, module, expectedHeader);
+ existingHeaders.insert(expectedHeader);
}
- }
+ return mlir::WalkResult::advance();
+ });
}
};
} // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
index a1fa58803e15a..c1627a0d4d023 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -1,23 +1,24 @@
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
-
func.func @alloc_copy(%arg0: memref<999xi32>) {
%alloc = memref.alloc() : memref<999xi32>
memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
+ %alloc_1 = memref.alloc() : memref<999xi32>
+ memref.copy %arg0, %alloc_1 : memref<999xi32> to memref<999xi32>
return
}
-// NOCPP: module {
+// CHECK: module {
+// NOCPP: emitc.include <"stdlib.h">
// NOCPP-NEXT: emitc.include <"string.h">
-// NOCPP-NEXT: emitc.include <"stdlib.h">
-// CPP: module {
+// CPP: emitc.include <"cstdlib">
// CPP-NEXT: emitc.include <"cstring">
-// CPP-NEXT: emitc.include <"cstdlib">
+
// CHECK-LABEL: alloc_copy
// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t
@@ -27,4 +28,23 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
-
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK-NEXT: return
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 615dbfeb461a2..6eb2c2db6a0b0 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -6,13 +6,12 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
return
}
-// NOCPP: module {
-// NOCPP-NEXT: emitc.include <"string.h">
+// CHECK: module {
+// NOCPP: emitc.include <"string.h">
+// CPP: emitc.include <"cstring">
-// CPP: module {
-// CPP-NEXT: emitc.include <"cstring">
-// CPP-LABEL: copying
+// CHECK-LABEL: copying
// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
>From de3c63d248fcb7d6ac25c93cd2b61aea1b527b96 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 8 Aug 2025 21:24:51 +0000
Subject: [PATCH 9/9] refactoring
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 43 ++++++++++---------
.../MemRefToEmitC/memref-to-emitc-copy.mlir | 1 -
2 files changed, 23 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 4e86188af719e..c3b937fa41431 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -123,6 +123,25 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
return totalSizeBytes.getResult();
}
+static emitc::ApplyOp
+createPointerFromEmitcArray(Location loc, OpBuilder &builder,
+ TypedValue<emitc::ArrayType> arrayValue) {
+
+ emitc::ConstantOp zeroIndex = builder.create<emitc::ConstantOp>(
+ loc, builder.getIndexType(), builder.getIndexAttr(0));
+
+ int64_t rank = arrayValue.getType().getRank();
+ llvm::SmallVector<mlir::Value> indices(rank, zeroIndex);
+
+ emitc::SubscriptOp subPtr =
+ builder.create<emitc::SubscriptOp>(loc, arrayValue, ValueRange(indices));
+ emitc::ApplyOp ptr = builder.create<emitc::ApplyOp>(
+ loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
+ builder.getStringAttr("&"), subPtr);
+
+ return ptr;
+}
+
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -204,31 +223,15 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
return rewriter.notifyMatchFailure(
loc, "incompatible target memref type for EmitC conversion");
- emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
- loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
-
- auto createPointerFromEmitcArray =
- [loc, &rewriter, &zeroIndex](
- mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
- int64_t rank = arrayValue.getType().getRank();
- llvm::SmallVector<mlir::Value> indices(rank, zeroIndex);
-
- emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
- loc, arrayValue, mlir::ValueRange(indices));
- emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
- loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
- rewriter.getStringAttr("&"), subPtr);
-
- return ptr;
- };
-
auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
- emitc::ApplyOp srcPtr = createPointerFromEmitcArray(srcArrayValue);
+ emitc::ApplyOp srcPtr =
+ createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
- emitc::ApplyOp targetPtr = createPointerFromEmitcArray(targetArrayValue);
+ emitc::ApplyOp targetPtr =
+ createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
loc, TypeRange{}, "memcpy",
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 6eb2c2db6a0b0..d151d1bd53458 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -8,7 +8,6 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
// CHECK: module {
// NOCPP: emitc.include <"string.h">
-
// CPP: emitc.include <"cstring">
// CHECK-LABEL: copying
More information about the Mlir-commits
mailing list