[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `reinterpret_cast` (PR #152610)

Jaden Angella llvmlistbot at llvm.org
Fri Aug 8 14:08:33 PDT 2025


https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/152610

>From c6ea99574fde268a83f20642f507aafe1ac6edf6 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 6 Aug 2025 20:18:38 +0000
Subject: [PATCH 1/3] needs improvement

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 84 ++++++++++++++++++-
 1 file changed, 83 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..60cccb0ece8f2 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -21,7 +21,9 @@
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
+#include <string>
 
 using namespace mlir;
 
@@ -269,6 +271,85 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
   }
 };
 
+struct ConvertReinterpretCastOp final
+    : public OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
+
+    MemRefType targetMemRefType =
+        cast<MemRefType>(castOp.getResult().getType());
+
+    auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
+    auto targetInEmitC =
+        convertMemRefType(targetMemRefType, getTypeConverter());
+    if (!srcInEmitC || !targetInEmitC) {
+      return rewriter.notifyMatchFailure(castOp.getLoc(),
+                                         "cannot convert memref type");
+    }
+    Location loc = castOp.getLoc();
+
+    auto srcArrayValue =
+        cast<TypedValue<emitc::ArrayType>>(adaptor.getSource());
+
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+
+    auto createPointerFromEmitcArray =
+        [loc, &rewriter, &zeroIndex](
+            mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
+      int64_t rank = arrayValue.getType().getRank();
+      llvm::SmallVector<mlir::Value> indices;
+      for (int i = 0; i < rank; ++i) {
+        indices.push_back(zeroIndex);
+      }
+
+      emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
+          loc, arrayValue, mlir::ValueRange(indices));
+      emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
+          loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
+          rewriter.getStringAttr("&"), subPtr);
+
+      return ptr;
+    };
+    auto [strides, offset] = targetMemRefType.getStridesAndOffset();
+    // Value offsetValue = rewriter.create<emitc::ConstantOp>(
+    //     loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
+
+    auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
+    // emitc::PointerType targetPointerType =
+    //     emitc::PointerType::get(srcArrayValue.getType().getElementType());
+
+    auto dimensions = targetMemRefType.getShape();
+    std::string reinterpretCastName = llvm::formatv(
+        "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType());
+    std::string dimensionsStr;
+    for (auto dim : dimensions) {
+      dimensionsStr += llvm::formatv("[{0}]", dim);
+    }
+    reinterpretCastName += llvm::formatv("{0}>", dimensionsStr);
+    reinterpretCastName += ">";
+
+    reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0));
+
+    std::string outputStr = llvm::formatv(
+        "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr);
+    auto outputType = emitc::PointerType::get(
+        emitc::OpaqueType::get(rewriter.getContext(), outputStr));
+
+    emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>(
+        loc, outputType,
+        emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName));
+
+    rewriter.replaceOp(castOp, reinterpretOp.getResult());
+    return success();
+  }
+};
+
 struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -321,5 +402,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
   patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
-               ConvertLoad, ConvertStore>(converter, patterns.getContext());
+               ConvertLoad, ConvertReinterpretCastOp, ConvertStore>(
+      converter, patterns.getContext());
 }

>From d09ccc0fbcf77ac09bafbc9b96b4bc6332c7f02a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 7 Aug 2025 23:06:46 +0000
Subject: [PATCH 2/3] almost functional option

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 41 ++++++++-----------
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 13 +++++-
 .../memref-to-emitc-reinterpret-cast.mlir     | 16 ++++++++
 3 files changed, 43 insertions(+), 27 deletions(-)
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 60cccb0ece8f2..97b926861482f 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
@@ -316,36 +317,26 @@ struct ConvertReinterpretCastOp final
 
       return ptr;
     };
-    auto [strides, offset] = targetMemRefType.getStridesAndOffset();
-    // Value offsetValue = rewriter.create<emitc::ConstantOp>(
-    //     loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
 
     auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
-    // emitc::PointerType targetPointerType =
-    //     emitc::PointerType::get(srcArrayValue.getType().getElementType());
-
-    auto dimensions = targetMemRefType.getShape();
-    std::string reinterpretCastName = llvm::formatv(
-        "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType());
-    std::string dimensionsStr;
-    for (auto dim : dimensions) {
-      dimensionsStr += llvm::formatv("[{0}]", dim);
-    }
-    reinterpretCastName += llvm::formatv("{0}>", dimensionsStr);
-    reinterpretCastName += ">";
-
-    reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0));
+    // 1. Create a TypeAttr for the target type.
+    TypeAttr targetTypeAttr =
+        TypeAttr::get(emitc::PointerType::get(targetInEmitC));
+    IntegerAttr resty = rewriter.getIndexAttr(0);
 
-    std::string outputStr = llvm::formatv(
-        "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr);
-    auto outputType = emitc::PointerType::get(
-        emitc::OpaqueType::get(rewriter.getContext(), outputStr));
+    // 2. Create an ArrayAttr with the TypeAttr. This will be the
+    // templateArgsAttr.
+    ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr});
 
-    emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>(
-        loc, outputType,
-        emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName));
+    auto reinterpretCastCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc,
+        /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)},
+        /*callee=*/"reinterpret_cast",
+        /*args*/ rewriter.getArrayAttr({resty}),
+        /*template_args=*/templateArgsAttr,
+        /*operands=*/ValueRange{srcPtr.getResult()});
 
-    rewriter.replaceOp(castOp, reinterpretOp.getResult());
+    rewriter.replaceOp(castOp, reinterpretCastCall.getResults());
     return success();
   }
 };
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 8e83e455d1a7f..3d47da8c439de 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1848,8 +1848,17 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
   if (auto lType = dyn_cast<emitc::LValueType>(type))
     return emitType(loc, lType.getValueType());
   if (auto pType = dyn_cast<emitc::PointerType>(type)) {
-    if (isa<ArrayType>(pType.getPointee()))
-      return emitError(loc, "cannot emit pointer to array type ") << type;
+    // Check if the pointee is an array type.
+    if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) {
+      // Handle pointer to array: `element_type (*)[dim]`.
+      if (failed(emitType(loc, aType.getElementType())))
+        return failure();
+      os << "(*)";
+      for (auto dim : aType.getShape())
+        os << "[" << dim << "]";
+      return success();
+    }
+    // Handle standard pointer: `element_type*`.
     if (failed(emitType(loc, pType.getPointee())))
       return failure();
     os << "*";
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir
new file mode 100644
index 0000000000000..5d610fe1ae642
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+func.func @casting(%arg0: memref<999xi32>) {
+  %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1] : memref<999xi32> to memref<1x1x999xi32>
+  return
+}
+
+//CHECK: module {
+//CHECK-NEXT:   func.func @casting(%arg0: memref<999xi32>) {
+//CHECK-NEXT:     %0 = builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32>
+//CHECK-NEXT:     %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
+//CHECK-NEXT:     %2 = emitc.subscript %0[%1] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+//CHECK-NEXT:     %3 = emitc.apply "&"(%2) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+//CHECK-NEXT:     %4 = emitc.call_opaque "reinterpret_cast"(%3) {args = [0 : index], template_args = [!emitc.ptr<!emitc.array<1x1x999xi32>>]} : (!emitc.ptr<i32>) -> !emitc.ptr<!emitc.array<1x1x999xi32>>
+//CHECK-NEXT:     return
+

>From 19e23d560b00008e664c2e85df9840ae71f51ee1 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 8 Aug 2025 21:08:08 +0000
Subject: [PATCH 3/3] cpp output that compiles

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 19 +++----------------
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 97b926861482f..18bd79de2ec2a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -319,24 +319,11 @@ struct ConvertReinterpretCastOp final
     };
 
     auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
-    // 1. Create a TypeAttr for the target type.
-    TypeAttr targetTypeAttr =
-        TypeAttr::get(emitc::PointerType::get(targetInEmitC));
-    IntegerAttr resty = rewriter.getIndexAttr(0);
 
-    // 2. Create an ArrayAttr with the TypeAttr. This will be the
-    // templateArgsAttr.
-    ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr});
+    auto castCall = rewriter.create<emitc::CastOp>(
+        loc, emitc::PointerType::get(targetInEmitC), srcPtr.getResult());
 
-    auto reinterpretCastCall = rewriter.create<emitc::CallOpaqueOp>(
-        loc,
-        /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)},
-        /*callee=*/"reinterpret_cast",
-        /*args*/ rewriter.getArrayAttr({resty}),
-        /*template_args=*/templateArgsAttr,
-        /*operands=*/ValueRange{srcPtr.getResult()});
-
-    rewriter.replaceOp(castOp, reinterpretCastCall.getResults());
+    rewriter.replaceOp(castOp, castCall);
     return success();
   }
 };
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3d47da8c439de..6ec1c070fde83 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1756,6 +1756,20 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
 
 LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
                                                   StringRef name) {
+  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
+    if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) {
+      if (failed(emitType(loc, aType.getElementType())))
+        return failure();
+      os << " (*" << name << ")";
+      for (auto dim : aType.getShape())
+        os << "[" << dim << "]";
+      return success();
+    }
+    if (failed(emitType(loc, pType.getPointee())))
+      return failure();
+    os << " *" << name;
+    return success();
+  }
   if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
     if (failed(emitType(loc, arrType.getElementType())))
       return failure();



More information about the Mlir-commits mailing list