[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `CopyOp` (PR #151206)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 4 11:19:56 PDT 2025


https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/151206

>From 3768def1fc42b74efd8517add59f8b32b00e895a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 25 Jul 2025 16:43:57 +0000
Subject: [PATCH 1/4] Add lowering for memref.copy

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 60 +++++++++++++++++++
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   | 25 ++++++++
 2 files changed, 85 insertions(+)
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..34ea4989c8156 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,6 +87,66 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = copyOp.getLoc();
+    auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto targetMemrefType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+    if (!srcMemrefType || !targetMemrefType) {
+      return failure();
+    }
+
+    // 1. Cast source memref to a pointer.
+    auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType());
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    auto stcArrayPtr =
+        emitc::PointerType::get(srcArrayValue.getType().getElementType());
+    auto srcPtr = rewriter.create<emitc::CastOp>(loc, srcPtrType,
+                                                 stcArrayPtr.getPointee());
+
+    // 2. Cast target memref to a pointer.
+    auto targetPtrType =
+        emitc::PointerType::get(targetMemrefType.getElementType());
+
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+
+    // Cast the target memref value to a pointer type.
+    auto targetPtr =
+        rewriter.create<emitc::CastOp>(loc, targetPtrType, arrayValue);
+
+    // 3. Calculate the size in bytes of the memref.
+    auto elementSize = rewriter.create<emitc::CallOpaqueOp>(
+        loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"),
+        mlir::ValueRange{},
+        mlir::ArrayAttr::get(
+            rewriter.getContext(),
+            {mlir::TypeAttr::get(srcMemrefType.getElementType())}));
+
+    auto numElements = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(),
+                                srcMemrefType.getNumElements()));
+    auto byteSize = rewriter.create<emitc::MulOp>(loc, rewriter.getIndexType(),
+                                                  elementSize.getResult(0),
+                                                  numElements.getResult());
+
+    // 4. Emit the memcpy call.
+    rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, "memcpy",
+                                         ValueRange{targetPtr.getResult(),
+                                                    srcPtr.getResult(),
+                                                    byteSize.getResult()});
+
+    return success();
+  }
+};
+
 Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   Type resultTy;
   if (opTy.getRank() == 0) {
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
new file mode 100644
index 0000000000000..d031d60508df2
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+func.func @copying(%arg0 : memref<2x4xf32>) {
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// func.func @copying_memcpy(%arg_0: !emitc.ptr<f32>) {
+//   %size = "emitc.constant"() <{value = 8 : index}> :() -> index
+//   %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
+//   %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
+  
+//   emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
+//   return
+// }
+
+// CHECK-LABEL: copying_memcpy
+// CHECK-SAME: %arg_0: !emitc.ptr<f32>
+// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index
+// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
+// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
+// CHECK-NEXT: emitc.call_opaque "memcpy"
+// CHECK-SAME: (%arg_0, %arg_0, %total_bytes)
+// CHECK-NEXT: : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
+// CHECK-NEXT: return
\ No newline at end of file

>From bc4dd4fafacfdee0cd3d2bf3f4b8a9c2831cb8fa Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 29 Jul 2025 18:14:39 +0000
Subject: [PATCH 2/4] use subscript and apply ops

---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.h  |   2 +
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 137 ++++++++++--------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       |  27 +++-
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   |  35 +++--
 .../MemRefToEmitC/memref-to-emitc-failed.mlir |   8 -
 5 files changed, 114 insertions(+), 95 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index b595b6a308bea..4ea6649d64a92 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -10,8 +10,10 @@
 
 constexpr const char *alignedAllocFunctionName = "aligned_alloc";
 constexpr const char *mallocFunctionName = "malloc";
+constexpr const char *memcpyFunctionName = "memcpy";
 constexpr const char *cppStandardLibraryHeader = "cstdlib";
 constexpr const char *cStandardLibraryHeader = "stdlib.h";
+constexpr const char *stringLibraryHeader = "string.h";
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 34ea4989c8156..adb0eb77fdf35 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,66 +87,6 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
-struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = copyOp.getLoc();
-    auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
-    auto targetMemrefType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
-
-    if (!srcMemrefType || !targetMemrefType) {
-      return failure();
-    }
-
-    // 1. Cast source memref to a pointer.
-    auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType());
-    auto srcArrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
-    auto stcArrayPtr =
-        emitc::PointerType::get(srcArrayValue.getType().getElementType());
-    auto srcPtr = rewriter.create<emitc::CastOp>(loc, srcPtrType,
-                                                 stcArrayPtr.getPointee());
-
-    // 2. Cast target memref to a pointer.
-    auto targetPtrType =
-        emitc::PointerType::get(targetMemrefType.getElementType());
-
-    auto arrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-
-    // Cast the target memref value to a pointer type.
-    auto targetPtr =
-        rewriter.create<emitc::CastOp>(loc, targetPtrType, arrayValue);
-
-    // 3. Calculate the size in bytes of the memref.
-    auto elementSize = rewriter.create<emitc::CallOpaqueOp>(
-        loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"),
-        mlir::ValueRange{},
-        mlir::ArrayAttr::get(
-            rewriter.getContext(),
-            {mlir::TypeAttr::get(srcMemrefType.getElementType())}));
-
-    auto numElements = rewriter.create<emitc::ConstantOp>(
-        loc, rewriter.getIndexType(),
-        rewriter.getIntegerAttr(rewriter.getIndexType(),
-                                srcMemrefType.getNumElements()));
-    auto byteSize = rewriter.create<emitc::MulOp>(loc, rewriter.getIndexType(),
-                                                  elementSize.getResult(0),
-                                                  numElements.getResult());
-
-    // 4. Emit the memcpy call.
-    rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, "memcpy",
-                                         ValueRange{targetPtr.getResult(),
-                                                    srcPtr.getResult(),
-                                                    byteSize.getResult()});
-
-    return success();
-  }
-};
-
 Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   Type resultTy;
   if (opTy.getRank() == 0) {
@@ -157,6 +97,29 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   return resultTy;
 }
 
+Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+                                    ConversionPatternRewriter &rewriter) {
+  emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
+      loc, emitc::SizeTType::get(rewriter.getContext()),
+      rewriter.getStringAttr("sizeof"), ValueRange{},
+      ArrayAttr::get(rewriter.getContext(),
+                     {TypeAttr::get(memrefType.getElementType())}));
+
+  IndexType indexType = rewriter.getIndexType();
+  int64_t numElements = 1;
+  for (int64_t dimSize : memrefType.getShape()) {
+    numElements *= dimSize;
+  }
+  emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
+      loc, indexType, rewriter.getIndexAttr(numElements));
+
+  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+  emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+      loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+  return totalSizeBytes.getResult();
+}
+
 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -219,6 +182,55 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   }
 };
 
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = copyOp.getLoc();
+    MemRefType srcMemrefType =
+        dyn_cast<MemRefType>(copyOp.getSource().getType());
+    MemRefType targetMemrefType =
+        dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+    if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
+        !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible memref type for EmitC conversion");
+    }
+
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+
+    emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
+        loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
+    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(srcMemrefType.getElementType()),
+        rewriter.getStringAttr("&"), srcSubPtr);
+
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+    emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
+        loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
+    emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(targetMemrefType.getElementType()),
+        rewriter.getStringAttr("&"), targetSubPtr);
+
+    emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "memcpy",
+        ValueRange{
+            targetPtr.getResult(), srcPtr.getResult(),
+            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+    rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+    return success();
+  }
+};
+
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -380,6 +392,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
-               ConvertLoad, ConvertStore>(converter, patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+               ConvertGetGlobal, ConvertLoad, ConvertStore>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76d6e256..8e965b42f1043 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +28,15 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+
+emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
+                           StringRef headerName) {
+  StringAttr includeAttr = builder.getStringAttr(headerName);
+  return builder.create<emitc::IncludeOp>(
+      module.getLoc(), includeAttr,
+      /*is_standard_include=*/builder.getUnitAttr());
+}
+
 struct ConvertMemRefToEmitCPass
     : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
   using Base::Base;
@@ -57,7 +67,8 @@ struct ConvertMemRefToEmitCPass
     mlir::ModuleOp module = getOperation();
     module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
       if (callOp.getCallee() != alignedAllocFunctionName &&
-          callOp.getCallee() != mallocFunctionName) {
+          callOp.getCallee() != mallocFunctionName &&
+          callOp.getCallee() != memcpyFunctionName) {
         return mlir::WalkResult::advance();
       }
 
@@ -76,12 +87,14 @@ struct ConvertMemRefToEmitCPass
       }
 
       mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
-      StringAttr includeAttr =
-          builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
-                                                   : cStandardLibraryHeader);
-      builder.create<mlir::emitc::IncludeOp>(
-          module.getLoc(), includeAttr,
-          /*is_standard_include=*/builder.getUnitAttr());
+      StringRef headerName;
+      if (callOp.getCallee() == memcpyFunctionName)
+        headerName = stringLibraryHeader;
+      else
+        headerName = options.lowerToCpp ? cppStandardLibraryHeader
+                                        : cStandardLibraryHeader;
+
+      addHeader(builder, module, headerName);
       return mlir::WalkResult::interrupt();
     });
   }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index d031d60508df2..4b6eb50807513 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,25 +1,24 @@
-// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -convert-memref-to-emitc %s  | FileCheck %s
 
 func.func @copying(%arg0 : memref<2x4xf32>) {
   memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
   return
 }
 
-// func.func @copying_memcpy(%arg_0: !emitc.ptr<f32>) {
-//   %size = "emitc.constant"() <{value = 8 : index}> :() -> index
-//   %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
-//   %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
-  
-//   emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
-//   return
-// }
+// CHECK: module {
+// CHECK-NEXT:  emitc.include <"string.h">
+// CHECK-LABEL:  copying
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
+// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
+// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// CHECK-NEXT:}
 
-// CHECK-LABEL: copying_memcpy
-// CHECK-SAME: %arg_0: !emitc.ptr<f32>
-// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index
-// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
-// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
-// CHECK-NEXT: emitc.call_opaque "memcpy"
-// CHECK-SAME: (%arg_0, %arg_0, %total_bytes)
-// CHECK-NEXT: : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
-// CHECK-NEXT: return
\ No newline at end of file
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index fda01974d3fc8..b6eccfc8f0050 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -1,13 +1,5 @@
 // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
 
-func.func @memref_op(%arg0 : memref<2x4xf32>) {
-  // expected-error at +1 {{failed to legalize operation 'memref.copy'}}
-  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
-  return
-}
-
-// -----
-
 func.func @alloca_with_dynamic_shape() {
   %0 = index.constant 1
   // expected-error at +1 {{failed to legalize operation 'memref.alloca'}}

>From 7e7f59c010d2f53714d629971d8d2dc3c3d17a2e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 4 Aug 2025 18:11:17 +0000
Subject: [PATCH 3/4] Allow multi-dimensional arrays

---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.h  |  3 +-
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 88 ++++++++++++-------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 19 ++--
 3 files changed, 68 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 4ea6649d64a92..5abfb3d7e72dd 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -13,7 +13,8 @@ constexpr const char *mallocFunctionName = "malloc";
 constexpr const char *memcpyFunctionName = "memcpy";
 constexpr const char *cppStandardLibraryHeader = "cstdlib";
 constexpr const char *cStandardLibraryHeader = "stdlib.h";
-constexpr const char *stringLibraryHeader = "string.h";
+constexpr const char *cppStringLibraryHeader = "cstring";
+constexpr const char *cStringLibraryHeader = "string.h";
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index adb0eb77fdf35..cabbfac4a1dca 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
@@ -98,23 +99,25 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
 }
 
 Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
-                                    ConversionPatternRewriter &rewriter) {
-  emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
-      loc, emitc::SizeTType::get(rewriter.getContext()),
-      rewriter.getStringAttr("sizeof"), ValueRange{},
-      ArrayAttr::get(rewriter.getContext(),
+                                    OpBuilder &builder) {
+  assert(isMemRefTypeLegalForEmitC(memrefType) &&
+         "incompatible memref type for EmitC conversion");
+  emitc::CallOpaqueOp elementSize = builder.create<emitc::CallOpaqueOp>(
+      loc, emitc::SizeTType::get(builder.getContext()),
+      builder.getStringAttr("sizeof"), ValueRange{},
+      ArrayAttr::get(builder.getContext(),
                      {TypeAttr::get(memrefType.getElementType())}));
 
-  IndexType indexType = rewriter.getIndexType();
+  IndexType indexType = builder.getIndexType();
   int64_t numElements = 1;
   for (int64_t dimSize : memrefType.getShape()) {
     numElements *= dimSize;
   }
-  emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
-      loc, indexType, rewriter.getIndexAttr(numElements));
+  emitc::ConstantOp numElementsValue = builder.create<emitc::ConstantOp>(
+      loc, indexType, builder.getIndexAttr(numElements));
 
-  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
-  emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+  Type sizeTType = emitc::SizeTType::get(builder.getContext());
+  emitc::MulOp totalSizeBytes = builder.create<emitc::MulOp>(
       loc, sizeTType, elementSize.getResult(0), numElementsValue);
 
   return totalSizeBytes.getResult();
@@ -189,41 +192,64 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
   matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = copyOp.getLoc();
-    MemRefType srcMemrefType =
-        dyn_cast<MemRefType>(copyOp.getSource().getType());
+    MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
     MemRefType targetMemrefType =
-        dyn_cast<MemRefType>(copyOp.getTarget().getType());
+        cast<MemRefType>(copyOp.getTarget().getType());
 
-    if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
-        !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+    if (!isMemRefTypeLegalForEmitC(srcMemrefType)) {
       return rewriter.notifyMatchFailure(
-          loc, "incompatible memref type for EmitC conversion");
+          loc, "incompatible source memref type for EmitC conversion");
+    }
+    if (!isMemRefTypeLegalForEmitC(targetMemrefType)) {
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible target memref type for EmitC conversion");
     }
 
+    auto createPointerFromEmitcArray =
+        [&](mlir::Location loc, mlir::OpBuilder &rewriter,
+            mlir::TypedValue<emitc::ArrayType> arrayValue,
+            mlir::MemRefType memrefType,
+            emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
+      // Get the rank of the array to create the correct number of zero indices.
+      int64_t rank = arrayValue.getType().getRank();
+      llvm::SmallVector<mlir::Value> indices;
+      for (int i = 0; i < rank; ++i) {
+        indices.push_back(zeroIndex);
+      }
+
+      // Create a subscript operation to get the element at index [0, 0, ...,
+      // 0].
+      emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
+          loc, arrayValue, mlir::ValueRange(indices));
+
+      // Create an apply operation to take the address of the subscripted
+      // element.
+      emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
+          loc, emitc::PointerType::get(memrefType.getElementType()),
+          rewriter.getStringAttr("&"), subPtr);
+
+      return ptr;
+    };
+
+    // Create a constant zero index.
     emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
         loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+
     auto srcArrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
+        loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
 
-    emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
-        loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
-    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
-        loc, emitc::PointerType::get(srcMemrefType.getElementType()),
-        rewriter.getStringAttr("&"), srcSubPtr);
-
-    auto arrayValue =
+    auto targetArrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-    emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
-        loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
-    emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
-        loc, emitc::PointerType::get(targetMemrefType.getElementType()),
-        rewriter.getStringAttr("&"), targetSubPtr);
+    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
+        loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
 
+    OpBuilder builder = rewriter;
     emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
         loc, TypeRange{}, "memcpy",
-        ValueRange{
-            targetPtr.getResult(), srcPtr.getResult(),
-            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+        ValueRange{targetPtr.getResult(), srcPtr.getResult(),
+                   calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
 
     rewriter.replaceOp(copyOp, memCpyCall.getResults());
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 8e965b42f1043..c60e7488fdb38 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -29,8 +29,8 @@ using namespace mlir;

 namespace {

-emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
-                           StringRef headerName) {
+emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
+                                   StringRef headerName) {
   StringAttr includeAttr = builder.getStringAttr(headerName);
   return builder.create<emitc::IncludeOp>(
       module.getLoc(), includeAttr,
@@ -67,34 +67,32 @@ struct ConvertMemRefToEmitCPass
     mlir::ModuleOp module = getOperation();
-    module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
-      if (callOp.getCallee() != alignedAllocFunctionName &&
-          callOp.getCallee() != mallocFunctionName &&
-          callOp.getCallee() != memcpyFunctionName) {
-        return mlir::WalkResult::advance();
-      }
-
-      for (auto &op : *module.getBody()) {
-        emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
-        if (!includeOp) {
-          continue;
-        }
-        if (includeOp.getIsStandardInclude() &&
-            ((options.lowerToCpp &&
-              includeOp.getInclude() == cppStandardLibraryHeader) ||
-             (!options.lowerToCpp &&
-              includeOp.getInclude() == cStandardLibraryHeader))) {
-          return mlir::WalkResult::interrupt();
-        }
-      }
-
-      mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
-      StringRef headerName;
-      if (callOp.getCallee() == memcpyFunctionName)
-        headerName = stringLibraryHeader;
-      else
-        headerName = options.lowerToCpp ? cppStandardLibraryHeader
-                                        : cStandardLibraryHeader;
-
-      addHeader(builder, module, headerName);
-      return mlir::WalkResult::interrupt();
-    });
+    // Collect the standard headers the module already includes so that none
+    // of them is emitted a second time.
+    llvm::SmallVector<StringRef> existingHeaders;
+    for (emitc::IncludeOp includeOp :
+         module.getBody()->getOps<emitc::IncludeOp>())
+      if (includeOp.getIsStandardInclude())
+        existingHeaders.push_back(includeOp.getInclude());
+
+    // Every lowered library call needs its own header (malloc/aligned_alloc:
+    // stdlib.h or cstdlib; memcpy: string.h or cstring). Visit all calls
+    // instead of stopping at the first one, so that a module which both
+    // allocates and copies receives both headers.
+    mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+    module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+      StringRef headerName;
+      if (callOp.getCallee() == alignedAllocFunctionName ||
+          callOp.getCallee() == mallocFunctionName)
+        headerName = options.lowerToCpp ? cppStandardLibraryHeader
+                                        : cStandardLibraryHeader;
+      else if (callOp.getCallee() == memcpyFunctionName)
+        headerName =
+            options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
+      else
+        return;
+
+      if (llvm::is_contained(existingHeaders, headerName))
+        return;
+      addStandardHeader(builder, module, headerName);
+      existingHeaders.push_back(headerName);
+    });
   }

>From 7650eaaa0c41e370a07581d8d2ccd5c9ddfc0824 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 4 Aug 2025 18:19:32 +0000
Subject: [PATCH 4/4] update test file

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           |  7 -----
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   | 26 ++++++++++---------
 2 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index cabbfac4a1dca..4130e9be88a89 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -210,46 +210,49 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
             mlir::TypedValue<emitc::ArrayType> arrayValue,
             mlir::MemRefType memrefType,
             emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
-      // Get the rank of the array to create the correct number of zero indices.
       int64_t rank = arrayValue.getType().getRank();
       llvm::SmallVector<mlir::Value> indices;
       for (int i = 0; i < rank; ++i) {
         indices.push_back(zeroIndex);
       }

-      // Create a subscript operation to get the element at index [0, 0, ...,
-      // 0].
       emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
           loc, arrayValue, mlir::ValueRange(indices));
-
-      // Create an apply operation to take the address of the subscripted
-      // element.
       emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
           loc, emitc::PointerType::get(memrefType.getElementType()),
           rewriter.getStringAttr("&"), subPtr);

       return ptr;
     };

-    // Create a constant zero index.
-    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
-        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
-
-    auto srcArrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
-    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
-        loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
-
-    auto targetArrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
-        loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
-
-    OpBuilder builder = rewriter;
-    emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
-        loc, TypeRange{}, "memcpy",
-        ValueRange{targetPtr.getResult(), srcPtr.getResult(),
-                   calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
+    // Both converted operands must be EmitC arrays. Check this before any op
+    // is created so that a failed match leaves the IR untouched.
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    if (!srcArrayValue)
+      return rewriter.notifyMatchFailure(
+          loc, "expected the converted source to be an emitc.array");
+    auto targetArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+    if (!targetArrayValue)
+      return rewriter.notifyMatchFailure(
+          loc, "expected the converted target to be an emitc.array");
+
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
+        loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
+    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
+        loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
+
+    // memcpy(target, source, sizeof(element) * numElements). memref.copy is
+    // expected to have operands of identical shape and element type, so the
+    // byte count is derived from the source type alone.
+    emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "memcpy",
+        ValueRange{
+            targetPtr.getResult(), srcPtr.getResult(),
+            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});

     rewriter.replaceOp(copyOp, memCpyCall.getResults());

diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 4b6eb50807513..88325e57762d3 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,23 +1,25 @@
-// RUN: mlir-opt -convert-memref-to-emitc %s  | FileCheck %s
+// RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s

-func.func @copying(%arg0 : memref<2x4xf32>) {
-  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
+  memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
   return
 }

 // CHECK: module {
 // CHECK-NEXT:  emitc.include <"string.h">
-// CHECK-LABEL:  copying
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
-// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
-// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-LABEL: func.func @copying
+// CHECK-SAME:    %[[ARG0:.*]]: memref<9x4x5x7xf32>, %[[ARG1:.*]]: memref<9x4x5x7xf32>
+// CHECK-NEXT:    %[[TARGET:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT:    %[[SOURCE:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT:    %[[ZERO:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT:    %[[SRC_ELEM:.*]] = emitc.subscript %[[SOURCE]][%[[ZERO]], %[[ZERO]], %[[ZERO]], %[[ZERO]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT:    %[[SRC_PTR:.*]] = emitc.apply "&"(%[[SRC_ELEM]]) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT:    %[[DST_ELEM:.*]] = emitc.subscript %[[TARGET]][%[[ZERO]], %[[ZERO]], %[[ZERO]], %[[ZERO]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT:    %[[DST_PTR:.*]] = emitc.apply "&"(%[[DST_ELEM]]) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT:    %[[ELEM_SIZE:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT:    %[[NUM_ELEMS:.*]] = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK-NEXT:    %[[BYTES:.*]] = emitc.mul %[[ELEM_SIZE]], %[[NUM_ELEMS]] : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT:    emitc.call_opaque "memcpy"(%[[DST_PTR]], %[[SRC_PTR]], %[[BYTES]]) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
 // CHECK-NEXT:}



More information about the Mlir-commits mailing list