[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `CopyOp` (PR #151206)

Jaden Angella llvmlistbot at llvm.org
Fri Aug 8 14:25:14 PDT 2025


https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/151206

>From 5a54bdf27d2619982c876773b7ae3a9477eeeefb Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 25 Jul 2025 16:43:57 +0000
Subject: [PATCH 1/9] Adding lowering for copyop

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 60 +++++++++++++++++++
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   | 25 ++++++++
 2 files changed, 85 insertions(+)
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..34ea4989c8156 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,6 +87,66 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = copyOp.getLoc();
+    auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto targetMemrefType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+    if (!srcMemrefType || !targetMemrefType) {
+      return failure();
+    }
+
+    // 1. Cast source memref to a pointer.
+    auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType());
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    auto stcArrayPtr =
+        emitc::PointerType::get(srcArrayValue.getType().getElementType());
+    auto srcPtr = rewriter.create<emitc::CastOp>(loc, srcPtrType,
+                                                 stcArrayPtr.getPointee());
+
+    // 2. Cast target memref to a pointer.
+    auto targetPtrType =
+        emitc::PointerType::get(targetMemrefType.getElementType());
+
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+
+    // Cast the target memref value to a pointer type.
+    auto targetPtr =
+        rewriter.create<emitc::CastOp>(loc, targetPtrType, arrayValue);
+
+    // 3. Calculate the size in bytes of the memref.
+    auto elementSize = rewriter.create<emitc::CallOpaqueOp>(
+        loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"),
+        mlir::ValueRange{},
+        mlir::ArrayAttr::get(
+            rewriter.getContext(),
+            {mlir::TypeAttr::get(srcMemrefType.getElementType())}));
+
+    auto numElements = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(),
+        rewriter.getIntegerAttr(rewriter.getIndexType(),
+                                srcMemrefType.getNumElements()));
+    auto byteSize = rewriter.create<emitc::MulOp>(loc, rewriter.getIndexType(),
+                                                  elementSize.getResult(0),
+                                                  numElements.getResult());
+
+    // 4. Emit the memcpy call.
+    rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, "memcpy",
+                                         ValueRange{targetPtr.getResult(),
+                                                    srcPtr.getResult(),
+                                                    byteSize.getResult()});
+
+    return success();
+  }
+};
+
 Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   Type resultTy;
   if (opTy.getRank() == 0) {
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
new file mode 100644
index 0000000000000..d031d60508df2
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+func.func @copying(%arg0 : memref<2x4xf32>) {
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// func.func @copying_memcpy(%arg_0: !emitc.ptr<f32>) {
+//   %size = "emitc.constant"() <{value = 8 : index}> :() -> index
+//   %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
+//   %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
+  
+//   emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
+//   return
+// }
+
+// CHECK-LABEL: copying_memcpy
+// CHECK-SAME: %arg_0: !emitc.ptr<f32>
+// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index
+// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
+// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
+// CHECK-NEXT: emitc.call_opaque "memcpy"
+// CHECK-SAME: (%arg_0, %arg_0, %total_bytes)
+// CHECK-NEXT: : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
+// CHECK-NEXT: return
\ No newline at end of file

>From b5b7637c91ec064933fb5ab94a04956fe3d4783e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 29 Jul 2025 18:14:39 +0000
Subject: [PATCH 2/9] use subscript and apply ops

---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.h  |   2 +
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 137 ++++++++++--------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       |  27 +++-
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   |  35 +++--
 .../MemRefToEmitC/memref-to-emitc-failed.mlir |   8 -
 5 files changed, 114 insertions(+), 95 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index b595b6a308bea..4ea6649d64a92 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -10,8 +10,10 @@
 
 constexpr const char *alignedAllocFunctionName = "aligned_alloc";
 constexpr const char *mallocFunctionName = "malloc";
+constexpr const char *memcpyFunctionName = "memcpy";
 constexpr const char *cppStandardLibraryHeader = "cstdlib";
 constexpr const char *cStandardLibraryHeader = "stdlib.h";
+constexpr const char *stringLibraryHeader = "string.h";
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 34ea4989c8156..adb0eb77fdf35 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,66 +87,6 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
-struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = copyOp.getLoc();
-    auto srcMemrefType = dyn_cast<MemRefType>(copyOp.getSource().getType());
-    auto targetMemrefType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
-
-    if (!srcMemrefType || !targetMemrefType) {
-      return failure();
-    }
-
-    // 1. Cast source memref to a pointer.
-    auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType());
-    auto srcArrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
-    auto stcArrayPtr =
-        emitc::PointerType::get(srcArrayValue.getType().getElementType());
-    auto srcPtr = rewriter.create<emitc::CastOp>(loc, srcPtrType,
-                                                 stcArrayPtr.getPointee());
-
-    // 2. Cast target memref to a pointer.
-    auto targetPtrType =
-        emitc::PointerType::get(targetMemrefType.getElementType());
-
-    auto arrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-
-    // Cast the target memref value to a pointer type.
-    auto targetPtr =
-        rewriter.create<emitc::CastOp>(loc, targetPtrType, arrayValue);
-
-    // 3. Calculate the size in bytes of the memref.
-    auto elementSize = rewriter.create<emitc::CallOpaqueOp>(
-        loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"),
-        mlir::ValueRange{},
-        mlir::ArrayAttr::get(
-            rewriter.getContext(),
-            {mlir::TypeAttr::get(srcMemrefType.getElementType())}));
-
-    auto numElements = rewriter.create<emitc::ConstantOp>(
-        loc, rewriter.getIndexType(),
-        rewriter.getIntegerAttr(rewriter.getIndexType(),
-                                srcMemrefType.getNumElements()));
-    auto byteSize = rewriter.create<emitc::MulOp>(loc, rewriter.getIndexType(),
-                                                  elementSize.getResult(0),
-                                                  numElements.getResult());
-
-    // 4. Emit the memcpy call.
-    rewriter.create<emitc::CallOpaqueOp>(loc, TypeRange{}, "memcpy",
-                                         ValueRange{targetPtr.getResult(),
-                                                    srcPtr.getResult(),
-                                                    byteSize.getResult()});
-
-    return success();
-  }
-};
-
 Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   Type resultTy;
   if (opTy.getRank() == 0) {
@@ -157,6 +97,29 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   return resultTy;
 }
 
+Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+                                    ConversionPatternRewriter &rewriter) {
+  emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
+      loc, emitc::SizeTType::get(rewriter.getContext()),
+      rewriter.getStringAttr("sizeof"), ValueRange{},
+      ArrayAttr::get(rewriter.getContext(),
+                     {TypeAttr::get(memrefType.getElementType())}));
+
+  IndexType indexType = rewriter.getIndexType();
+  int64_t numElements = 1;
+  for (int64_t dimSize : memrefType.getShape()) {
+    numElements *= dimSize;
+  }
+  emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
+      loc, indexType, rewriter.getIndexAttr(numElements));
+
+  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+  emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+      loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+  return totalSizeBytes.getResult();
+}
+
 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -219,6 +182,55 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   }
 };
 
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = copyOp.getLoc();
+    MemRefType srcMemrefType =
+        dyn_cast<MemRefType>(copyOp.getSource().getType());
+    MemRefType targetMemrefType =
+        dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+    if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
+        !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible memref type for EmitC conversion");
+    }
+
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+
+    emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
+        loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
+    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(srcMemrefType.getElementType()),
+        rewriter.getStringAttr("&"), srcSubPtr);
+
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+    emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
+        loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
+    emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(targetMemrefType.getElementType()),
+        rewriter.getStringAttr("&"), targetSubPtr);
+
+    emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "memcpy",
+        ValueRange{
+            targetPtr.getResult(), srcPtr.getResult(),
+            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+    rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+    return success();
+  }
+};
+
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -380,6 +392,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
-               ConvertLoad, ConvertStore>(converter, patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+               ConvertGetGlobal, ConvertLoad, ConvertStore>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76d6e256..8e965b42f1043 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +28,15 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+
+emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
+                           StringRef headerName) {
+  StringAttr includeAttr = builder.getStringAttr(headerName);
+  return builder.create<emitc::IncludeOp>(
+      module.getLoc(), includeAttr,
+      /*is_standard_include=*/builder.getUnitAttr());
+}
+
 struct ConvertMemRefToEmitCPass
     : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
   using Base::Base;
@@ -57,7 +67,8 @@ struct ConvertMemRefToEmitCPass
     mlir::ModuleOp module = getOperation();
     module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
       if (callOp.getCallee() != alignedAllocFunctionName &&
-          callOp.getCallee() != mallocFunctionName) {
+          callOp.getCallee() != mallocFunctionName &&
+          callOp.getCallee() != memcpyFunctionName) {
         return mlir::WalkResult::advance();
       }
 
@@ -76,12 +87,14 @@ struct ConvertMemRefToEmitCPass
       }
 
       mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
-      StringAttr includeAttr =
-          builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
-                                                   : cStandardLibraryHeader);
-      builder.create<mlir::emitc::IncludeOp>(
-          module.getLoc(), includeAttr,
-          /*is_standard_include=*/builder.getUnitAttr());
+      StringRef headerName;
+      if (callOp.getCallee() == memcpyFunctionName)
+        headerName = stringLibraryHeader;
+      else
+        headerName = options.lowerToCpp ? cppStandardLibraryHeader
+                                        : cStandardLibraryHeader;
+
+      addHeader(builder, module, headerName);
       return mlir::WalkResult::interrupt();
     });
   }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index d031d60508df2..4b6eb50807513 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,25 +1,24 @@
-// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -convert-memref-to-emitc %s  | FileCheck %s
 
 func.func @copying(%arg0 : memref<2x4xf32>) {
   memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
   return
 }
 
-// func.func @copying_memcpy(%arg_0: !emitc.ptr<f32>) {
-//   %size = "emitc.constant"() <{value = 8 : index}> :() -> index
-//   %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
-//   %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
-  
-//   emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
-//   return
-// }
+// CHECK: module {
+// CHECK-NEXT:  emitc.include <"string.h">
+// CHECK-LABEL:  copying
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
+// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
+// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// CHECK-NEXT:}
 
-// CHECK-LABEL: copying_memcpy
-// CHECK-SAME: %arg_0: !emitc.ptr<f32>
-// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index
-// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index
-// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index
-// CHECK-NEXT: emitc.call_opaque "memcpy"
-// CHECK-SAME: (%arg_0, %arg_0, %total_bytes)
-// CHECK-NEXT: : (!emitc.ptr<f32>, !emitc.ptr<f32>, index) -> ()
-// CHECK-NEXT: return
\ No newline at end of file
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index fda01974d3fc8..b6eccfc8f0050 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -1,13 +1,5 @@
 // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
 
-func.func @memref_op(%arg0 : memref<2x4xf32>) {
-  // expected-error at +1 {{failed to legalize operation 'memref.copy'}}
-  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
-  return
-}
-
-// -----
-
 func.func @alloca_with_dynamic_shape() {
   %0 = index.constant 1
   // expected-error at +1 {{failed to legalize operation 'memref.alloca'}}

>From c17cf73dbcd269f2de62f6efd473d40caff5805c Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 4 Aug 2025 18:11:17 +0000
Subject: [PATCH 3/9] allow for multi-dimensional arrays

---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.h  |  3 +-
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 88 ++++++++++++-------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 19 ++--
 3 files changed, 68 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 4ea6649d64a92..5abfb3d7e72dd 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -13,7 +13,8 @@ constexpr const char *mallocFunctionName = "malloc";
 constexpr const char *memcpyFunctionName = "memcpy";
 constexpr const char *cppStandardLibraryHeader = "cstdlib";
 constexpr const char *cStandardLibraryHeader = "stdlib.h";
-constexpr const char *stringLibraryHeader = "string.h";
+constexpr const char *cppStringLibraryHeader = "cstring";
+constexpr const char *cStringLibraryHeader = "string.h";
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index adb0eb77fdf35..cabbfac4a1dca 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
@@ -98,23 +99,25 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
 }
 
 Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
-                                    ConversionPatternRewriter &rewriter) {
-  emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
-      loc, emitc::SizeTType::get(rewriter.getContext()),
-      rewriter.getStringAttr("sizeof"), ValueRange{},
-      ArrayAttr::get(rewriter.getContext(),
+                                    OpBuilder &builder) {
+  assert(isMemRefTypeLegalForEmitC(memrefType) &&
+         "incompatible memref type for EmitC conversion");
+  emitc::CallOpaqueOp elementSize = builder.create<emitc::CallOpaqueOp>(
+      loc, emitc::SizeTType::get(builder.getContext()),
+      builder.getStringAttr("sizeof"), ValueRange{},
+      ArrayAttr::get(builder.getContext(),
                      {TypeAttr::get(memrefType.getElementType())}));
 
-  IndexType indexType = rewriter.getIndexType();
+  IndexType indexType = builder.getIndexType();
   int64_t numElements = 1;
   for (int64_t dimSize : memrefType.getShape()) {
     numElements *= dimSize;
   }
-  emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
-      loc, indexType, rewriter.getIndexAttr(numElements));
+  emitc::ConstantOp numElementsValue = builder.create<emitc::ConstantOp>(
+      loc, indexType, builder.getIndexAttr(numElements));
 
-  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
-  emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+  Type sizeTType = emitc::SizeTType::get(builder.getContext());
+  emitc::MulOp totalSizeBytes = builder.create<emitc::MulOp>(
       loc, sizeTType, elementSize.getResult(0), numElementsValue);
 
   return totalSizeBytes.getResult();
@@ -189,41 +192,64 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
   matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = copyOp.getLoc();
-    MemRefType srcMemrefType =
-        dyn_cast<MemRefType>(copyOp.getSource().getType());
+    MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
     MemRefType targetMemrefType =
-        dyn_cast<MemRefType>(copyOp.getTarget().getType());
+        cast<MemRefType>(copyOp.getTarget().getType());
 
-    if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
-        !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+    if (!isMemRefTypeLegalForEmitC(srcMemrefType)) {
       return rewriter.notifyMatchFailure(
-          loc, "incompatible memref type for EmitC conversion");
+          loc, "incompatible source memref type for EmitC conversion");
+    }
+    if (!isMemRefTypeLegalForEmitC(targetMemrefType)) {
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible target memref type for EmitC conversion");
     }
 
+    auto createPointerFromEmitcArray =
+        [&](mlir::Location loc, mlir::OpBuilder &rewriter,
+            mlir::TypedValue<emitc::ArrayType> arrayValue,
+            mlir::MemRefType memrefType,
+            emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
+      // Get the rank of the array to create the correct number of zero indices.
+      int64_t rank = arrayValue.getType().getRank();
+      llvm::SmallVector<mlir::Value> indices;
+      for (int i = 0; i < rank; ++i) {
+        indices.push_back(zeroIndex);
+      }
+
+      // Create a subscript operation to get the element at index [0, 0, ...,
+      // 0].
+      emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
+          loc, arrayValue, mlir::ValueRange(indices));
+
+      // Create an apply operation to take the address of the subscripted
+      // element.
+      emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
+          loc, emitc::PointerType::get(memrefType.getElementType()),
+          rewriter.getStringAttr("&"), subPtr);
+
+      return ptr;
+    };
+
+    // Create a constant zero index.
     emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
         loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+
     auto srcArrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
+        loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
 
-    emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
-        loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
-    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
-        loc, emitc::PointerType::get(srcMemrefType.getElementType()),
-        rewriter.getStringAttr("&"), srcSubPtr);
-
-    auto arrayValue =
+    auto targetArrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-    emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
-        loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
-    emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
-        loc, emitc::PointerType::get(targetMemrefType.getElementType()),
-        rewriter.getStringAttr("&"), targetSubPtr);
+    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
+        loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
 
+    OpBuilder builder = rewriter;
     emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
         loc, TypeRange{}, "memcpy",
-        ValueRange{
-            targetPtr.getResult(), srcPtr.getResult(),
-            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+        ValueRange{targetPtr.getResult(), srcPtr.getResult(),
+                   calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
 
     rewriter.replaceOp(copyOp, memCpyCall.getResults());
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 8e965b42f1043..c60e7488fdb38 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -29,8 +29,8 @@ using namespace mlir;
 
 namespace {
 
-emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
-                           StringRef headerName) {
+emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
+                                   StringRef headerName) {
   StringAttr includeAttr = builder.getStringAttr(headerName);
   return builder.create<emitc::IncludeOp>(
       module.getLoc(), includeAttr,
@@ -68,33 +68,32 @@ struct ConvertMemRefToEmitCPass
     module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
       if (callOp.getCallee() != alignedAllocFunctionName &&
           callOp.getCallee() != mallocFunctionName &&
-          callOp.getCallee() != memcpyFunctionName) {
+          callOp.getCallee() != memcpyFunctionName)
         return mlir::WalkResult::advance();
-      }
 
       for (auto &op : *module.getBody()) {
         emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
-        if (!includeOp) {
+        if (!includeOp)
           continue;
-        }
+
         if (includeOp.getIsStandardInclude() &&
             ((options.lowerToCpp &&
               includeOp.getInclude() == cppStandardLibraryHeader) ||
              (!options.lowerToCpp &&
-              includeOp.getInclude() == cStandardLibraryHeader))) {
+              includeOp.getInclude() == cStandardLibraryHeader)))
           return mlir::WalkResult::interrupt();
-        }
       }
 
       mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
       StringRef headerName;
       if (callOp.getCallee() == memcpyFunctionName)
-        headerName = stringLibraryHeader;
+        headerName =
+            options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
       else
         headerName = options.lowerToCpp ? cppStandardLibraryHeader
                                         : cStandardLibraryHeader;
 
-      addHeader(builder, module, headerName);
+      addStandardHeader(builder, module, headerName);
       return mlir::WalkResult::interrupt();
     });
   }

>From a7a2d01ed0e2e2b9751d22af6cc329d0852ec80d Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 4 Aug 2025 18:19:32 +0000
Subject: [PATCH 4/9] update test file

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           |  7 -----
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   | 26 ++++++++++---------
 2 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index cabbfac4a1dca..4130e9be88a89 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -210,20 +210,14 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
             mlir::TypedValue<emitc::ArrayType> arrayValue,
             mlir::MemRefType memrefType,
             emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
-      // Get the rank of the array to create the correct number of zero indices.
       int64_t rank = arrayValue.getType().getRank();
       llvm::SmallVector<mlir::Value> indices;
       for (int i = 0; i < rank; ++i) {
         indices.push_back(zeroIndex);
       }
 
-      // Create a subscript operation to get the element at index [0, 0, ...,
-      // 0].
       emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
           loc, arrayValue, mlir::ValueRange(indices));
-
-      // Create an apply operation to take the address of the subscripted
-      // element.
       emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
           loc, emitc::PointerType::get(memrefType.getElementType()),
           rewriter.getStringAttr("&"), subPtr);
@@ -231,7 +225,6 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
       return ptr;
     };
 
-    // Create a constant zero index.
     emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
         loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
 
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 4b6eb50807513..88325e57762d3 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,23 +1,25 @@
 // RUN: mlir-opt -convert-memref-to-emitc %s  | FileCheck %s
 
-func.func @copying(%arg0 : memref<2x4xf32>) {
-  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
+  memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
   return
 }
 
 // CHECK: module {
 // CHECK-NEXT:  emitc.include <"string.h">
 // CHECK-LABEL:  copying
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
-// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
-// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
 // CHECK-NEXT:    return
 // CHECK-NEXT:  }
 // CHECK-NEXT:}

>From 12e967e9312bdffa74393ae4ec357e25a5df22dd Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 5 Aug 2025 17:44:17 +0000
Subject: [PATCH 5/9] Test lower-to-cpp output; simplify ConvertCopy lowering

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 37 ++++++------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 15 +++--
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   | 57 ++++++++++++-------
 3 files changed, 65 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 4130e9be88a89..c8124d2f16943 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -196,20 +196,20 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
     MemRefType targetMemrefType =
         cast<MemRefType>(copyOp.getTarget().getType());
 
-    if (!isMemRefTypeLegalForEmitC(srcMemrefType)) {
+    if (!isMemRefTypeLegalForEmitC(srcMemrefType))
       return rewriter.notifyMatchFailure(
           loc, "incompatible source memref type for EmitC conversion");
-    }
-    if (!isMemRefTypeLegalForEmitC(targetMemrefType)) {
+
+    if (!isMemRefTypeLegalForEmitC(targetMemrefType))
       return rewriter.notifyMatchFailure(
           loc, "incompatible target memref type for EmitC conversion");
-    }
+
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
 
     auto createPointerFromEmitcArray =
-        [&](mlir::Location loc, mlir::OpBuilder &rewriter,
-            mlir::TypedValue<emitc::ArrayType> arrayValue,
-            mlir::MemRefType memrefType,
-            emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
+        [loc, &rewriter, &zeroIndex](
+            mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
       int64_t rank = arrayValue.getType().getRank();
       llvm::SmallVector<mlir::Value> indices;
       for (int i = 0; i < rank; ++i) {
@@ -219,30 +219,25 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
       emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
           loc, arrayValue, mlir::ValueRange(indices));
       emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
-          loc, emitc::PointerType::get(memrefType.getElementType()),
+          loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
           rewriter.getStringAttr("&"), subPtr);
 
       return ptr;
     };
 
-    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
-        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
-
     auto srcArrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
-    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(
-        loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);
+        cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(srcArrayValue);
 
     auto targetArrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(
-        loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex);
+        cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(targetArrayValue);
 
-    OpBuilder builder = rewriter;
     emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
         loc, TypeRange{}, "memcpy",
-        ValueRange{targetPtr.getResult(), srcPtr.getResult(),
-                   calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
+        ValueRange{
+            targetPtr.getResult(), srcPtr.getResult(),
+            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
 
     rewriter.replaceOp(copyOp, memCpyCall.getResults());
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index c60e7488fdb38..3ffff9fca106a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -37,6 +37,16 @@ emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
       /*is_standard_include=*/builder.getUnitAttr());
 }
 
+bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options,
+                               emitc::IncludeOp includeOp) {
+  return ((options.lowerToCpp &&
+           (includeOp.getInclude() == cppStandardLibraryHeader ||
+            includeOp.getInclude() == cppStringLibraryHeader)) ||
+          (!options.lowerToCpp &&
+           (includeOp.getInclude() == cStandardLibraryHeader ||
+            includeOp.getInclude() == cStringLibraryHeader)));
+}
+
 struct ConvertMemRefToEmitCPass
     : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
   using Base::Base;
@@ -77,10 +87,7 @@ struct ConvertMemRefToEmitCPass
           continue;
 
         if (includeOp.getIsStandardInclude() &&
-            ((options.lowerToCpp &&
-              includeOp.getInclude() == cppStandardLibraryHeader) ||
-             (!options.lowerToCpp &&
-              includeOp.getInclude() == cStandardLibraryHeader)))
+            isExpectedStandardInclude(options, includeOp))
           return mlir::WalkResult::interrupt();
       }
 
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 88325e57762d3..1b515ba02dd46 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -1,26 +1,45 @@
-// RUN: mlir-opt -convert-memref-to-emitc %s  | FileCheck %s
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
 
 func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
   memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
   return
 }
 
-// CHECK: module {
-// CHECK-NEXT:  emitc.include <"string.h">
-// CHECK-LABEL:  copying
-// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
-// CHECK-NEXT:}
+// NOCPP: module {
+// NOCPP-NEXT:  emitc.include <"string.h">
+// NOCPP-LABEL:  copying
+// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// NOCPP-NEXT:    return
+// NOCPP-NEXT:  }
+// NOCPP-NEXT:}
 
+// CPP: module {
+// CPP-NEXT:  emitc.include <"cstring">
+// CPP-LABEL:  copying
+// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CPP-NEXT:    return
+// CPP-NEXT:  }
+// CPP-NEXT:}

>From 5870e6b3fb21e0f644bd237cbb6b8f30f28cdb73 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 7 Aug 2025 04:23:15 +0000
Subject: [PATCH 6/9] ensure both headers are added

---
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 57 ++++++++-----------
 .../memref-to-emitc-alloc-copy.mlir           | 30 ++++++++++
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   | 48 ++++++----------
 3 files changed, 70 insertions(+), 65 deletions(-)
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 3ffff9fca106a..5469949311879 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -37,16 +37,6 @@ emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
       /*is_standard_include=*/builder.getUnitAttr());
 }
 
-bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options,
-                               emitc::IncludeOp includeOp) {
-  return ((options.lowerToCpp &&
-           (includeOp.getInclude() == cppStandardLibraryHeader ||
-            includeOp.getInclude() == cppStringLibraryHeader)) ||
-          (!options.lowerToCpp &&
-           (includeOp.getInclude() == cStandardLibraryHeader ||
-            includeOp.getInclude() == cStringLibraryHeader)));
-}
-
 struct ConvertMemRefToEmitCPass
     : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
   using Base::Base;
@@ -75,34 +65,33 @@ struct ConvertMemRefToEmitCPass
       return signalPassFailure();
 
     mlir::ModuleOp module = getOperation();
+    llvm::SmallVector<StringRef> requiredHeaders;
     module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
-      if (callOp.getCallee() != alignedAllocFunctionName &&
-          callOp.getCallee() != mallocFunctionName &&
-          callOp.getCallee() != memcpyFunctionName)
-        return mlir::WalkResult::advance();
-
-      for (auto &op : *module.getBody()) {
-        emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
-        if (!includeOp)
-          continue;
-
-        if (includeOp.getIsStandardInclude() &&
-            isExpectedStandardInclude(options, includeOp))
-          return mlir::WalkResult::interrupt();
-      }
-
-      mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
-      StringRef headerName;
-      if (callOp.getCallee() == memcpyFunctionName)
-        headerName =
+      StringRef expectedHeader;
+      if (callOp.getCallee() == alignedAllocFunctionName ||
+          callOp.getCallee() == mallocFunctionName)
+        expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
+                                            : cStandardLibraryHeader;
+      else if (callOp.getCallee() == memcpyFunctionName)
+        expectedHeader =
             options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
       else
-        headerName = options.lowerToCpp ? cppStandardLibraryHeader
-                                        : cStandardLibraryHeader;
-
-      addStandardHeader(builder, module, headerName);
-      return mlir::WalkResult::interrupt();
+        return mlir::WalkResult::advance();
+      requiredHeaders.push_back(expectedHeader);
+      return mlir::WalkResult::advance();
     });
+    for (StringRef expectedHeader : requiredHeaders) {
+      bool headerFound = llvm::any_of(*module.getBody(), [&](Operation &op) {
+        auto includeOp = dyn_cast<mlir::emitc::IncludeOp>(op);
+        return includeOp && includeOp.getIsStandardInclude() &&
+               (includeOp.getInclude() == expectedHeader);
+      });
+
+      if (!headerFound) {
+        mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+        addStandardHeader(builder, module, expectedHeader);
+      }
+    }
   }
 };
 } // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
new file mode 100644
index 0000000000000..2e0ff63715355
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
+// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
+
+
+func.func @alloc_copy(%arg0: memref<999xi32>) {
+  %alloc = memref.alloc() : memref<999xi32>
+  memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
+  return
+} 
+
+// NOCPP: module {
+// NOCPP-NEXT:  emitc.include <"string.h">
+// NOCPP-NEXT:  emitc.include <"stdlib.h">
+
+// CPP: module {
+// CPP-NEXT:  emitc.include <"cstring">
+// CHECK-NEXT: emitc.include <"cstdlib">
+// CHECK-LABEL: alloc_copy
+// CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
+// CHECK-NEXT:  %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> 
+// CHECK-NEXT:  emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t 
+// CHECK-NEXT:  "emitc.constant"() <{value = 999 : index}> : () -> index 
+// CHECK-NEXT:  emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t 
+// CHECK-NEXT:  emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> 
+// CHECK-NEXT:  emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> 
+// CHECK-NEXT:  builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32> 
+// CHECK-NEXT:  "emitc.constant"() <{value = 0 : index}> : () -> index 
+// CHECK-NEXT:  emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> 
+// CHECK-NEXT:  emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> 
+
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 1b515ba02dd46..615dbfeb461a2 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -8,38 +8,24 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
 
 // NOCPP: module {
 // NOCPP-NEXT:  emitc.include <"string.h">
-// NOCPP-LABEL:  copying
-// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// NOCPP-NEXT:    return
-// NOCPP-NEXT:  }
-// NOCPP-NEXT:}
+
 
 // CPP: module {
 // CPP-NEXT:  emitc.include <"cstring">
 // CPP-LABEL:  copying
-// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
-// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
-// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
-// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
-// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
-// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
-// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
-// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
-// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
-// CPP-NEXT:    return
-// CPP-NEXT:  }
-// CPP-NEXT:}
+// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
+// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index
+// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// CHECK-NEXT:}
+

>From 3663062675a375ef4ab8c7a9d050494c7d5203e0 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 7 Aug 2025 04:26:59 +0000
Subject: [PATCH 7/9] Fix check prefix for cstdlib include in alloc-copy test

---
 .../Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir    | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
index 2e0ff63715355..a1fa58803e15a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -14,7 +14,7 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
 
 // CPP: module {
 // CPP-NEXT:  emitc.include <"cstring">
-// CHECK-NEXT: emitc.include <"cstdlib">
+// CPP-NEXT:  emitc.include <"cstdlib">
 // CHECK-LABEL: alloc_copy
 // CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
 // CHECK-NEXT:  %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> 

>From ec252fcc01afe0f819caa9025270fe520187f3f4 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 8 Aug 2025 00:15:33 +0000
Subject: [PATCH 8/9] ensure no duplicate headers

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 17 ++++------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 26 +++++++-------
 .../memref-to-emitc-alloc-copy.mlir           | 34 +++++++++++++++----
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   |  9 +++--
 4 files changed, 50 insertions(+), 36 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index c8124d2f16943..4e86188af719e 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <cstdint>
+#include <numeric>
 
 using namespace mlir;
 
@@ -98,8 +99,8 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   return resultTy;
 }
 
-Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
-                                    OpBuilder &builder) {
+static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+                                           OpBuilder &builder) {
   assert(isMemRefTypeLegalForEmitC(memrefType) &&
          "incompatible memref type for EmitC conversion");
   emitc::CallOpaqueOp elementSize = builder.create<emitc::CallOpaqueOp>(
@@ -109,10 +110,9 @@ Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
                      {TypeAttr::get(memrefType.getElementType())}));
 
   IndexType indexType = builder.getIndexType();
-  int64_t numElements = 1;
-  for (int64_t dimSize : memrefType.getShape()) {
-    numElements *= dimSize;
-  }
+  int64_t numElements = std::accumulate(memrefType.getShape().begin(),
+                                        memrefType.getShape().end(), int64_t{1},
+                                        std::multiplies<int64_t>());
   emitc::ConstantOp numElementsValue = builder.create<emitc::ConstantOp>(
       loc, indexType, builder.getIndexAttr(numElements));
 
@@ -211,10 +211,7 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
         [loc, &rewriter, &zeroIndex](
             mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
       int64_t rank = arrayValue.getType().getRank();
-      llvm::SmallVector<mlir::Value> indices;
-      for (int i = 0; i < rank; ++i) {
-        indices.push_back(zeroIndex);
-      }
+      llvm::SmallVector<mlir::Value> indices(rank, zeroIndex);
 
       emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
           loc, arrayValue, mlir::ValueRange(indices));
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 5469949311879..a51890248271f 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace mlir {
@@ -65,7 +66,13 @@ struct ConvertMemRefToEmitCPass
       return signalPassFailure();
 
     mlir::ModuleOp module = getOperation();
-    llvm::SmallVector<StringRef> requiredHeaders;
+    llvm::SmallSet<StringRef, 4> existingHeaders;
+    mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+    module.walk([&](mlir::emitc::IncludeOp includeOp) {
+      if (includeOp.getIsStandardInclude())
+        existingHeaders.insert(includeOp.getInclude());
+    });
+
     module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
       StringRef expectedHeader;
       if (callOp.getCallee() == alignedAllocFunctionName ||
@@ -77,21 +84,12 @@ struct ConvertMemRefToEmitCPass
             options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
       else
         return mlir::WalkResult::advance();
-      requiredHeaders.push_back(expectedHeader);
-      return mlir::WalkResult::advance();
-    });
-    for (StringRef expectedHeader : requiredHeaders) {
-      bool headerFound = llvm::any_of(*module.getBody(), [&](Operation &op) {
-        auto includeOp = dyn_cast<mlir::emitc::IncludeOp>(op);
-        return includeOp && includeOp.getIsStandardInclude() &&
-               (includeOp.getInclude() == expectedHeader);
-      });
-
-      if (!headerFound) {
-        mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+      if (!existingHeaders.contains(expectedHeader)) {
         addStandardHeader(builder, module, expectedHeader);
+        existingHeaders.insert(expectedHeader);
       }
-    }
+      return mlir::WalkResult::advance();
+    });
   }
 };
 } // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
index a1fa58803e15a..c1627a0d4d023 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir
@@ -1,23 +1,24 @@
 // RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
 // RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
 
-
 func.func @alloc_copy(%arg0: memref<999xi32>) {
   %alloc = memref.alloc() : memref<999xi32>
   memref.copy %arg0, %alloc : memref<999xi32> to memref<999xi32>
+  %alloc_1 = memref.alloc() : memref<999xi32>
+  memref.copy %arg0, %alloc_1 : memref<999xi32> to memref<999xi32>
   return
 } 
 
-// NOCPP: module {
+// CHECK: module {
+// NOCPP:  emitc.include <"stdlib.h">
 // NOCPP-NEXT:  emitc.include <"string.h">
-// NOCPP-NEXT:  emitc.include <"stdlib.h">
 
-// CPP: module {
+// CPP:  emitc.include <"cstdlib">
 // CPP-NEXT:  emitc.include <"cstring">
-// CPP-NEXT:  emitc.include <"cstdlib">
+
 // CHECK-LABEL: alloc_copy
 // CHECK-SAME: %[[arg0:.*]]: memref<999xi32>
-// CHECK-NEXT:  %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> 
+// CHECK-NEXT:  builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> 
 // CHECK-NEXT:  emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t 
 // CHECK-NEXT:  "emitc.constant"() <{value = 999 : index}> : () -> index 
 // CHECK-NEXT:  emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t 
@@ -27,4 +28,23 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
 // CHECK-NEXT:  "emitc.constant"() <{value = 0 : index}> : () -> index 
 // CHECK-NEXT:  emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> 
 // CHECK-NEXT:  emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> 
-
+// CHECK-NEXT:  emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT:  "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT:  emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT:  emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK-NEXT:  emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT:  "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT:  emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT:  emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+// CHECK-NEXT:  emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
+// CHECK-NEXT:  builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32>
+// CHECK-NEXT:  "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT:  emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT:  emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT:  emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK-NEXT:  emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK-NEXT:  emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
+// CHECK-NEXT:  "emitc.constant"() <{value = 999 : index}> : () -> index
+// CHECK-NEXT:  emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT:  emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
+// CHECK-NEXT:    return
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 615dbfeb461a2..6eb2c2db6a0b0 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -6,13 +6,12 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
   return
 }
 
-// NOCPP: module {
-// NOCPP-NEXT:  emitc.include <"string.h">
+// CHECK: module {
+// NOCPP:  emitc.include <"string.h">
 
+// CPP:  emitc.include <"cstring">
 
-// CPP: module {
-// CPP-NEXT:  emitc.include <"cstring">
-// CPP-LABEL:  copying
+// CHECK-LABEL:  copying
 // CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32>
 // CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
 // CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>

>From de3c63d248fcb7d6ac25c93cd2b61aea1b527b96 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 8 Aug 2025 21:24:51 +0000
Subject: [PATCH 9/9] refactoring

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 43 ++++++++++---------
 .../MemRefToEmitC/memref-to-emitc-copy.mlir   |  1 -
 2 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 4e86188af719e..c3b937fa41431 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -123,6 +123,25 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
   return totalSizeBytes.getResult();
 }
 
+static emitc::ApplyOp
+createPointerFromEmitcArray(Location loc, OpBuilder &builder,
+                            TypedValue<emitc::ArrayType> arrayValue) {
+  // Emits `&arrayValue[0]...[0]`: a pointer to the array's first element.
+  emitc::ConstantOp zeroIndex = builder.create<emitc::ConstantOp>(
+      loc, builder.getIndexType(), builder.getIndexAttr(0));
+
+  int64_t rank = arrayValue.getType().getRank();
+  llvm::SmallVector<mlir::Value> indices(rank, zeroIndex);
+
+  emitc::SubscriptOp subPtr =
+      builder.create<emitc::SubscriptOp>(loc, arrayValue, ValueRange(indices));
+  emitc::ApplyOp ptr = builder.create<emitc::ApplyOp>(
+      loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
+      builder.getStringAttr("&"), subPtr);
+
+  return ptr;
+}
+
 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -204,31 +223,15 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
       return rewriter.notifyMatchFailure(
           loc, "incompatible target memref type for EmitC conversion");
 
-    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
-        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
-
-    auto createPointerFromEmitcArray =
-        [loc, &rewriter, &zeroIndex](
-            mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
-      int64_t rank = arrayValue.getType().getRank();
-      llvm::SmallVector<mlir::Value> indices(rank, zeroIndex);
-
-      emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
-          loc, arrayValue, mlir::ValueRange(indices));
-      emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
-          loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
-          rewriter.getStringAttr("&"), subPtr);
-
-      return ptr;
-    };
-
     auto srcArrayValue =
         cast<TypedValue<emitc::ArrayType>>(operands.getSource());
-    emitc::ApplyOp srcPtr = createPointerFromEmitcArray(srcArrayValue);
+    emitc::ApplyOp srcPtr =
+        createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
 
     auto targetArrayValue =
         cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
-    emitc::ApplyOp targetPtr = createPointerFromEmitcArray(targetArrayValue);
+    emitc::ApplyOp targetPtr =
+        createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
 
     emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
         loc, TypeRange{}, "memcpy",
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
index 6eb2c2db6a0b0..d151d1bd53458 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -8,7 +8,6 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
 
 // CHECK: module {
 // NOCPP:  emitc.include <"string.h">
-
 // CPP:  emitc.include <"cstring">
 
 // CHECK-LABEL:  copying



More information about the Mlir-commits mailing list